From 869d45a89da3a0c68c9f071576b58009043645f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20Sa=CC=88nger?= <mario.saenger@student.hu-berlin.de> Date: Fri, 4 May 2018 23:53:19 +0200 Subject: [PATCH] Save program arguments to file --- code_mario/clef18_task1.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/code_mario/clef18_task1.py b/code_mario/clef18_task1.py index 600866b..425ca4f 100644 --- a/code_mario/clef18_task1.py +++ b/code_mario/clef18_task1.py @@ -516,6 +516,15 @@ class Clef18Task1V2(LoggingMixin): result_writer.write("%s\t%s\t%s\t%s\n" % (r.target_label, r.classifier_name, r.data_set_name, r.accuracy)) result_writer.close() + def save_arguments(self, arguments: Namespace): + arguments_file = os.path.join(AppContext.default().log_dir, "arguments.txt") + self.logger.info("Saving arguments to " + arguments_file) + + with open(arguments_file, 'w', encoding="utf8") as writer: + for key, value in arguments.__dict__.items(): + writer.write("%s=%s\n" % (str(key), str(value))) + writer.close() + def save_configuration(self, configuration: Configuration): label_encoder_file = os.path.join(AppContext.default().output_dir, "label_encoder.pk") self.logger.info("Saving label encoder to " + label_encoder_file) @@ -547,7 +556,6 @@ class Clef18Task1V2(LoggingMixin): self.logger.info("Reloading embedding model from " + emb_model_file) return k.models.load_model(args.emb_model) - def create_dnn_classifier(self, model_name, label: str, val_data: Tuple, **kwargs): if val_data is not None: monitor_loss = "val_loss" @@ -677,14 +685,16 @@ if __name__ == "__main__": ft_model = ft_embeddings.load_embeddings_by_language(args.lang) clef18_task1 = Clef18Task1V2() - neg_sampling = NegativeSampling() + clef18_task1.save_arguments(args) if args.mode == "train-emb": configuration = clef18_task1.prepare_data_set(certificates, dictionary, ft_model, args.train_ratio, args.val_ratio,args.strat_column, args.samples, args.strat_splits) - neg_sampling_strategy = neg_sampling.get_strategy_by_name(args.neg_sampling, args) clef18_task1.save_configuration(configuration) + neg_sampling = NegativeSampling() + neg_sampling_strategy = neg_sampling.get_strategy_by_name(args.neg_sampling, args) + embedding_model = clef18_task1.train_embedding_model(configuration, ft_model, neg_sampling_strategy, args.epochs, args.batch_size) elif args.mode == "eval-cl": -- GitLab