Skip to content
Snippets Groups Projects
Commit 5b7d4b73 authored by Mario Sänger
Browse files

Add possibility to select target column

parent 0eb9f0ec
Branches
No related merge requests found
.idea/
*.idea/
**/_env
**/_logs
**/embeddings
*.pyc
This diff is collapsed.
......@@ -12,8 +12,12 @@ from util import LoggingMixin
class KerasUtil(object):
# NOTE(review): this span looks like a rendered diff hunk whose +/- markers and
# indentation were lost. The first def/assignment pair ("optimal_%s.h5") is
# presumably the removed version and the second pair ("%s_best.h5") the added
# one, with the final return shared as context -- confirm against the repository
# before relying on either method name.
@staticmethod
def best_model_checkpointing(model_name: str, monitor_loss: str = "loss"):
best_model_file = os.path.join(AppContext.default().output_dir, "optimal_%s.h5" % model_name)
# Variant that names the checkpoint file "<model_name>_best.h5" inside
# AppContext.default().output_dir.
def best_model_checkpointing_by_model_name(model_name: str, monitor_loss: str = "loss"):
best_model_file = os.path.join(AppContext.default().output_dir, "%s_best.h5" % model_name)
# Checkpoint callback watching `monitor_loss`, created with save_best_only=True and verbose=1.
return ModelCheckpoint(filepath=best_model_file, monitor=monitor_loss, save_best_only=True, verbose=1)
@staticmethod
def best_model_checkpointing_by_file_path(best_model_file: str, monitor_loss: str = "loss"):
    """Create a ``ModelCheckpoint`` callback writing to an explicit file path.

    Args:
        best_model_file: Path handed to ``ModelCheckpoint`` as ``filepath``.
        monitor_loss: Name of the quantity handed to ``ModelCheckpoint`` as
            ``monitor``; defaults to ``"loss"``.

    Returns:
        A ``ModelCheckpoint`` configured with ``save_best_only=True`` and
        ``verbose=1``.
    """
    checkpoint_options = {
        "filepath": best_model_file,
        "monitor": monitor_loss,
        "save_best_only": True,
        "verbose": 1,
    }
    return ModelCheckpoint(**checkpoint_options)
......@@ -59,7 +63,7 @@ class ExtendedKerasClassifier(KerasClassifier, LoggingMixin):
else:
self.logger.debug("Model wasn't re-fitted -> re-using existing model")
pass
self.logger.info("Classifer has %s classes", len(self.classes_))
return super(ExtendedKerasClassifier, self).predict(x, **kwargs)
def __getstate__(self):
......
......@@ -20,6 +20,13 @@ class DataPreparationUtil(object):
return MapFunction(column, _lower)
@staticmethod
def strip(column: str):
    """Build a ``MapFunction`` that trims surrounding whitespace.

    The mapping converts each value it receives to ``str`` and then removes
    leading and trailing whitespace.

    Args:
        column: Column name passed to ``MapFunction`` as its first argument.

    Returns:
        The ``MapFunction`` wrapping the trimming callable.
    """
    def _trim_whitespace(value):
        as_text = str(value)
        return as_text.strip()

    return MapFunction(column, _trim_whitespace)
@staticmethod
def tokenize(text_column: str, token_column: str = "tokens"):
    """Create a ``SimpleTokenizer`` for the given column pair.

    Args:
        text_column: First argument forwarded to ``SimpleTokenizer``.
        token_column: Second argument forwarded to ``SimpleTokenizer``;
            defaults to ``"tokens"``.

    Returns:
        The constructed ``SimpleTokenizer`` instance.
    """
    tokenizer = SimpleTokenizer(text_column, token_column)
    return tokenizer
......@@ -57,6 +64,20 @@ class DataPreparationUtil(object):
return MapFunction(icd10_column, _extract, target_column)
@staticmethod
def extract_icd10_section(icd10_column: str, target_column: str):
    """Build a ``MapFunction`` yielding the two-character prefix of a value.

    The mapping strips surrounding whitespace from the value, keeps its
    first two characters and lower-cases them.

    Args:
        icd10_column: Column name passed to ``MapFunction`` as the source.
        target_column: Column name passed to ``MapFunction`` as third argument.

    Returns:
        The ``MapFunction`` wrapping the prefix-extraction callable.
    """
    def _section_prefix(code):
        trimmed = code.strip()
        prefix = trimmed[:2]
        return prefix.lower()

    return MapFunction(icd10_column, _section_prefix, target_column)
@staticmethod
def extract_icd10_subsection(icd10_column: str, target_column: str):
    """Build a ``MapFunction`` yielding the three-character prefix of a value.

    The mapping strips surrounding whitespace from the value, keeps its
    first three characters and lower-cases them.

    Args:
        icd10_column: Column name passed to ``MapFunction`` as the source.
        target_column: Column name passed to ``MapFunction`` as third argument.

    Returns:
        The ``MapFunction`` wrapping the prefix-extraction callable.
    """
    def _subsection_prefix(code):
        trimmed = code.strip()
        prefix = trimmed[:3]
        return prefix.lower()

    return MapFunction(icd10_column, _subsection_prefix, target_column)
@staticmethod
def extract_icd10_subchapter(icd10_column: str, target_column: str):
def _extract(value):
......
......@@ -7,3 +7,4 @@ sklearn
tensorflow
tqdm
h5py
cython
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment