From fdb52c125fee5395e4c19fc0e546fb3e4fb59b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20Sa=CC=88nger?= <mario.saenger@student.hu-berlin.de> Date: Tue, 8 May 2018 11:24:25 +0200 Subject: [PATCH] Reduce complexity of sampling strategy --- code_mario/clef18_task1_base.py | 18 +++++++++++------- code_mario/clef18_task1_emb1.py | 18 +++++++++--------- code_mario/ft_embeddings.py | 5 ++++- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/code_mario/clef18_task1_base.py b/code_mario/clef18_task1_base.py index fa0eff5..ac7dc44 100644 --- a/code_mario/clef18_task1_base.py +++ b/code_mario/clef18_task1_base.py @@ -336,13 +336,17 @@ class NegativeSampling(LoggingMixin): if len(chapter_samples) > 0: chapter_samples = chapter_samples.sample(min(num_chapter_samples, len(chapter_samples))) - section_samples = dictionary_df.query("ICD10 != '%s' & ICD10_section == '%s'" % (icd10_code, icd10_section)) - if len(section_samples) > 0: - section_samples = section_samples.sample(min(num_section_samples, len(section_samples))) - - subsection_samples = dictionary_df.query("ICD10 != '%s' & ICD10_subsection == '%s'" % (icd10_code, icd10_subsection)) - if len(subsection_samples) > 0: - subsection_samples = subsection_samples.sample(min(num_subsection_samples, len(subsection_samples))) + # section_samples = dictionary_df.query("ICD10 != '%s' & ICD10_section == '%s'" % (icd10_code, icd10_section)) + # if len(section_samples) > 0: + # section_samples = section_samples.sample(min(num_section_samples, len(section_samples))) + # + # subsection_samples = dictionary_df.query("ICD10 != '%s' & ICD10_subsection == '%s'" % (icd10_code, icd10_subsection)) + # if len(subsection_samples) > 0: + # subsection_samples = subsection_samples.sample(min(num_subsection_samples, len(subsection_samples))) + # + + section_samples = chapter_samples.sample(1) + subsection_samples = chapter_samples.sample(1) exp_sim_samples = num_chapter_samples + num_section_samples + num_subsection_samples act_sim_samples = len(chapter_samples) + len(section_samples) + len(subsection_samples) diff --git a/code_mario/clef18_task1_emb1.py b/code_mario/clef18_task1_emb1.py index c6a41b6..8d00d18 100644 --- a/code_mario/clef18_task1_emb1.py +++ b/code_mario/clef18_task1_emb1.py @@ -88,14 +88,14 @@ class Clef18Task1Emb1(Clef18Task1Base): return model def train_embedding_model(self, config: Emb1Configuration, ft_model: FastTextModel, neg_sampling_strategy: Callable, epochs: int, batch_size: int) -> Model: - self.logger.info("Start building training pairs") - train_pair_data = self.build_pairs(config.train_cert_df, config.dict_df, neg_sampling_strategy) - self.logger.info("Label distribution:\n%s", train_pair_data["Label"].value_counts()) - self.logger.info("Start building embedding model") model = self.build_embedding_model(config.keras_tokenizer.word_index, ft_model, config.max_cert_length, config.max_dict_length) model.summary(print_fn=self.logger.info) + self.logger.info("Start building training pairs") + train_pair_data = self.build_pairs(config.train_cert_df, config.dict_df, neg_sampling_strategy) + self.logger.info("Label distribution:\n%s", train_pair_data["Label"].value_counts()) + cert_inputs = pad_sequences(train_pair_data["Cert_input"].values, maxlen=config.max_cert_length, padding="post") dict_inputs = pad_sequences(train_pair_data["Dict_input"].values, maxlen=config.max_dict_length, padding="post") labels = train_pair_data["Label"].values @@ -346,10 +346,10 @@ if __name__ == "__main__": train_emb_parser.add_argument("--neg_sampling", help="Negative sampling strategy to use", default="ext1", choices=["def", "ext1"]) train_emb_parser.add_argument("--num_neg_samples", help="Number of negative samples to use (default strategy)", default=75, type=int) - train_emb_parser.add_argument("--num_neg_cha", help="Number of negative chapter samples to use (ext1 strategy)", default=20, type=int) + train_emb_parser.add_argument("--num_neg_cha", help="Number of negative chapter samples to use (ext1 strategy)", default=10, type=int) train_emb_parser.add_argument("--num_neg_sec", help="Number of negative section samples to use (ext1 strategy)", default=20, type=int) train_emb_parser.add_argument("--num_neg_sub", help="Number of negative subsection samples to use (ext1 strategy)", default=20, type=int) - train_emb_parser.add_argument("--num_neg_oth", help="Number of negative other samples to use (ext1 strategy)", default=45, type=int) + train_emb_parser.add_argument("--num_neg_oth", help="Number of negative other samples to use (ext1 strategy)", default=10, type=int) eval_classifier_parser = subparsers.add_parser("eval-cl") eval_classifier_parser.add_argument("emb_model", help="Path to the embedding model to use") @@ -380,10 +380,10 @@ if __name__ == "__main__": if args.mode == "train-emb": ft_embeddings = FastTextEmbeddings() - # ft_model = ft_embeddings.load_embeddings_by_id(args.lang) + ft_model = ft_embeddings.load_embeddings_by_id(args.lang) - sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] - ft_model = FastTextModel("dummy", [FastText(sentences, min_count=1)]) + #sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] + #ft_model = FastTextModel("dummy", [FastText(sentences, min_count=1)]) if args.single_only: certificates = clef_data.filter_single_code_lines(certificates) diff --git a/code_mario/ft_embeddings.py b/code_mario/ft_embeddings.py index 59bc279..f2b3b7f 100644 --- a/code_mario/ft_embeddings.py +++ b/code_mario/ft_embeddings.py @@ -20,7 +20,7 @@ class FastTextModel(LoggingMixin): for ft_model in self.ft_models: try: embeddings.append(ft_model[word]) - except KeyError: + except KeyError as error: self.logger.warn("Can't create embedding for " + word) embeddings.append(np.zeros(ft_model.vector_size)) @@ -38,10 +38,13 @@ class FastTextEmbeddings(LoggingMixin): def load_embeddings_by_id(self, id: str) -> FastTextModel: if id == "it": return FastTextModel("it", [self.load_it_embeddings()]) + elif id == "hu": return FastTextModel("hu", [self.load_hu_embeddings()]) + elif id == "fr": return FastTextModel("fr", [self.load_fr_embeddings()]) + elif id == "all-con": return FastTextModel("all-con", [self.load_fr_embeddings(), self.load_it_embeddings(), -- GitLab