Skip to content
Snippets Groups Projects
Commit d0c34b11 authored by Jurica Seva
Browse files

Merge remote-tracking branch 'origin/master'

parents 39db3e78 9807cad9
Branches
No related merge requests found
.idea/ *.idea/
**/_env **/_env
**/_logs
**/embeddings **/embeddings
*.pyc *.pyc
code_mario/data
This diff is collapsed.
...@@ -12,8 +12,12 @@ from util import LoggingMixin ...@@ -12,8 +12,12 @@ from util import LoggingMixin
class KerasUtil(object): class KerasUtil(object):
@staticmethod
def best_model_checkpointing_by_model_name(model_name: str, monitor_loss: str = "loss"):
    """Build a checkpoint callback whose target file is derived from ``model_name``.

    The file is named ``<model_name>_best.h5`` and is joined onto the
    ``output_dir`` of the default ``AppContext``.

    :param model_name: name inserted into the checkpoint file name
    :param monitor_loss: value handed to ``ModelCheckpoint`` as ``monitor``
    :return: a ``ModelCheckpoint`` created with ``save_best_only=True`` and ``verbose=1``
    """
    output_dir = AppContext.default().output_dir
    checkpoint_path = os.path.join(output_dir, "%s_best.h5" % model_name)
    return ModelCheckpoint(
        filepath=checkpoint_path,
        monitor=monitor_loss,
        save_best_only=True,
        verbose=1,
    )
@staticmethod
def best_model_checkpointing_by_file_path(best_model_file: str, monitor_loss: str = "loss"):
    """Build a checkpoint callback that writes to an explicitly given file.

    :param best_model_file: path handed to ``ModelCheckpoint`` as ``filepath``
    :param monitor_loss: value handed to ``ModelCheckpoint`` as ``monitor``
    :return: a ``ModelCheckpoint`` created with ``save_best_only=True`` and ``verbose=1``
    """
    checkpoint = ModelCheckpoint(
        filepath=best_model_file,
        monitor=monitor_loss,
        save_best_only=True,
        verbose=1,
    )
    return checkpoint
...@@ -59,7 +63,7 @@ class ExtendedKerasClassifier(KerasClassifier, LoggingMixin): ...@@ -59,7 +63,7 @@ class ExtendedKerasClassifier(KerasClassifier, LoggingMixin):
else: else:
self.logger.debug("Model wasn't re-fitted -> re-using existing model") self.logger.debug("Model wasn't re-fitted -> re-using existing model")
pass pass
self.logger.info("Classifer has %s classes", len(self.classes_))
return super(ExtendedKerasClassifier, self).predict(x, **kwargs) return super(ExtendedKerasClassifier, self).predict(x, **kwargs)
def __getstate__(self): def __getstate__(self):
......
...@@ -20,6 +20,13 @@ class DataPreparationUtil(object): ...@@ -20,6 +20,13 @@ class DataPreparationUtil(object):
return MapFunction(column, _lower) return MapFunction(column, _lower)
@staticmethod
def strip(column: str):
    """Return a ``MapFunction`` for ``column`` that trims surrounding whitespace.

    Each value is converted with ``str()`` before stripping, so non-string
    values are trimmed in their string form.

    :param column: column name passed as the first ``MapFunction`` argument
    """
    def _strip_whitespace(raw_value):
        as_text = str(raw_value)
        return as_text.strip()

    return MapFunction(column, _strip_whitespace)
@staticmethod
def tokenize(text_column: str, token_column: str = "tokens"):
    """Return a ``SimpleTokenizer`` constructed from the two column names.

    :param text_column: first positional argument for ``SimpleTokenizer``
    :param token_column: second positional argument for ``SimpleTokenizer``
    """
    tokenizer = SimpleTokenizer(text_column, token_column)
    return tokenizer
...@@ -57,6 +64,20 @@ class DataPreparationUtil(object): ...@@ -57,6 +64,20 @@ class DataPreparationUtil(object):
return MapFunction(icd10_column, _extract, target_column) return MapFunction(icd10_column, _extract, target_column)
@staticmethod
def extract_icd10_section(icd10_column: str, target_column: str):
    """Return a ``MapFunction`` that derives a two-character prefix from ``icd10_column``.

    The mapped function strips surrounding whitespace from the value, keeps
    its first two characters and lower-cases them.

    :param icd10_column: column name passed as the first ``MapFunction`` argument
    :param target_column: column name passed as the third ``MapFunction`` argument
    """
    def _section_prefix(code):
        trimmed = code.strip()
        return trimmed[:2].lower()

    return MapFunction(icd10_column, _section_prefix, target_column)
@staticmethod
def extract_icd10_subsection(icd10_column: str, target_column: str):
    """Return a ``MapFunction`` that derives a three-character prefix from ``icd10_column``.

    The mapped function strips surrounding whitespace from the value, keeps
    its first three characters and lower-cases them.

    :param icd10_column: column name passed as the first ``MapFunction`` argument
    :param target_column: column name passed as the third ``MapFunction`` argument
    """
    def _subsection_prefix(code):
        trimmed = code.strip()
        return trimmed[:3].lower()

    return MapFunction(icd10_column, _subsection_prefix, target_column)
@staticmethod @staticmethod
def extract_icd10_subchapter(icd10_column: str, target_column: str): def extract_icd10_subchapter(icd10_column: str, target_column: str):
def _extract(value): def _extract(value):
......
...@@ -7,3 +7,4 @@ sklearn ...@@ -7,3 +7,4 @@ sklearn
tensorflow tensorflow
tqdm tqdm
h5py h5py
cython
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment