Skip to content
Snippets Groups Projects
Commit 5b7d4b73 authored by Mario Sänger
Browse files

Add possibility to select target column

parent 0eb9f0ec
No related merge requests found
.idea/ *.idea/
**/_env **/_env
**/_logs
**/embeddings **/embeddings
*.pyc *.pyc
This diff is collapsed.
...@@ -12,8 +12,12 @@ from util import LoggingMixin ...@@ -12,8 +12,12 @@ from util import LoggingMixin
class KerasUtil(object): class KerasUtil(object):
@staticmethod @staticmethod
def best_model_checkpointing(model_name: str, monitor_loss: str = "loss"): def best_model_checkpointing_by_model_name(model_name: str, monitor_loss: str = "loss"):
best_model_file = os.path.join(AppContext.default().output_dir, "optimal_%s.h5" % model_name) best_model_file = os.path.join(AppContext.default().output_dir, "%s_best.h5" % model_name)
return ModelCheckpoint(filepath=best_model_file, monitor=monitor_loss, save_best_only=True, verbose=1)
@staticmethod
def best_model_checkpointing_by_file_path(best_model_file: str, monitor_loss: str = "loss"):
return ModelCheckpoint(filepath=best_model_file, monitor=monitor_loss, save_best_only=True, verbose=1) return ModelCheckpoint(filepath=best_model_file, monitor=monitor_loss, save_best_only=True, verbose=1)
...@@ -59,7 +63,7 @@ class ExtendedKerasClassifier(KerasClassifier, LoggingMixin): ...@@ -59,7 +63,7 @@ class ExtendedKerasClassifier(KerasClassifier, LoggingMixin):
else: else:
self.logger.debug("Model wasn't re-fitted -> re-using existing model") self.logger.debug("Model wasn't re-fitted -> re-using existing model")
pass pass
self.logger.info("Classifer has %s classes", len(self.classes_))
return super(ExtendedKerasClassifier, self).predict(x, **kwargs) return super(ExtendedKerasClassifier, self).predict(x, **kwargs)
def __getstate__(self): def __getstate__(self):
......
...@@ -20,6 +20,13 @@ class DataPreparationUtil(object): ...@@ -20,6 +20,13 @@ class DataPreparationUtil(object):
return MapFunction(column, _lower) return MapFunction(column, _lower)
@staticmethod
def strip(column: str):
    """Create a mapping step that trims whitespace around the values of ``column``.

    Every value is converted with ``str`` first, so non-string inputs are
    stringified before leading and trailing whitespace is removed.

    :param column: name of the column handed to ``MapFunction``
    :return: the ``MapFunction`` built from ``column`` and the trimming callable
    """
    def _trim(raw_value):
        as_text = str(raw_value)
        return as_text.strip()

    return MapFunction(column, _trim)
@staticmethod @staticmethod
def tokenize(text_column: str, token_column: str = "tokens"): def tokenize(text_column: str, token_column: str = "tokens"):
return SimpleTokenizer(text_column, token_column) return SimpleTokenizer(text_column, token_column)
...@@ -57,6 +64,20 @@ class DataPreparationUtil(object): ...@@ -57,6 +64,20 @@ class DataPreparationUtil(object):
return MapFunction(icd10_column, _extract, target_column) return MapFunction(icd10_column, _extract, target_column)
@staticmethod
def extract_icd10_section(icd10_column: str, target_column: str):
    """Create a mapping step that derives a two-character key from an ICD-10 value.

    The value is stripped of surrounding whitespace, cut down to its first
    two characters, and lower-cased (in that order).

    :param icd10_column: name of the source column handed to ``MapFunction``
    :param target_column: name of the target column handed to ``MapFunction``
    :return: the ``MapFunction`` built from both column names and the extractor
    """
    def _first_two_lowered(code):
        trimmed = code.strip()
        prefix = trimmed[:2]
        return prefix.lower()

    return MapFunction(icd10_column, _first_two_lowered, target_column)
@staticmethod
def extract_icd10_subsection(icd10_column: str, target_column: str):
    """Create a mapping step that derives a three-character key from an ICD-10 value.

    The value is stripped of surrounding whitespace, cut down to its first
    three characters, and lower-cased (in that order).

    :param icd10_column: name of the source column handed to ``MapFunction``
    :param target_column: name of the target column handed to ``MapFunction``
    :return: the ``MapFunction`` built from both column names and the extractor
    """
    def _first_three_lowered(code):
        trimmed = code.strip()
        prefix = trimmed[:3]
        return prefix.lower()

    return MapFunction(icd10_column, _first_three_lowered, target_column)
@staticmethod @staticmethod
def extract_icd10_subchapter(icd10_column: str, target_column: str): def extract_icd10_subchapter(icd10_column: str, target_column: str):
def _extract(value): def _extract(value):
......
...@@ -7,3 +7,4 @@ sklearn ...@@ -7,3 +7,4 @@ sklearn
tensorflow tensorflow
tqdm tqdm
h5py h5py
cython
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment