Skip to content
Snippets Groups Projects
Commit 2554815d authored by Mario Sänger's avatar Mario Sänger
Browse files

Minor fixes to seq2seq logging

parent 8ca1a36a
No related merge requests found
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="guppi" createEmptyFolders="true" persistUploadOnCheckin="false" autoUploadExternalChanges="true">
<serverData>
<paths name="guppi">
<serverdata>
<mappings>
<mapping deploy="/projects/clef18" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="sonic">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="vistTriton">
<serverdata>
<mappings>
<mapping deploy="/clef18" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>
</project>
\ No newline at end of file
......@@ -2,6 +2,7 @@
# experiment=Experiment(api_key="hSd9vTj0EfMu72569YnVEvtvj")
# from loader import *
import keras.backend as K
from util import *
import numpy as np
import random
......@@ -27,14 +28,17 @@ config=tf.ConfigProto()
config.gpu_options.allow_growth=True
config.gpu_options.allocator_type='BFC'
sess = tf.Session(graph=tf.get_default_graph(), config=config)
K.set_session(sess)
callbacks_list=[
EarlyStopping(
monitor='val_loss',
patience=2,
patience=50,
),
ModelCheckpoint(
filepath='models/icd10Classification.h5',
monitor='val_loss',
monitor='loss',
save_best_only=True,
),
CSVLogger(
......@@ -44,7 +48,7 @@ callbacks_list=[
]
latent_dim = 512
epochs = 100
epochs = 500
batch_size = 1000
tokenizer=TokenizePreprocessor()
......@@ -115,7 +119,8 @@ try:
batch_size=batch_size,
epochs=epochs,
callbacks=callbacks_list,
validation_split=0.25
validation_data=[X_test, Y_test]
#verbose=0
)
except Exception as e:
......
......@@ -48,10 +48,10 @@ except OSError:
from classificationICD10 import *
icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5')
with open('models/icd10_tokenizer_extended.p', 'rb') as handle:
with open('models/icd10_tokenizer.p', 'rb') as handle:
icd10Tokenizer = pickle.load(handle)
with open('models/icd10_mappings_extended.p', 'rb') as handle:
with open('models/icd10_mappings.p', 'rb') as handle:
encoded_Y = pickle.load(handle)
# LOAD ICD 10 CLASSIFICATION MODEL
......
......@@ -45,17 +45,17 @@ except OSError:
from classificationICD10 import *
icd10_model = keras_load_model('models/icd10Classification_attention_extended.h5')
with open('models/icd10_tokenizer_extended.p', 'rb') as handle:
with open('models/icd10_tokenizer.p', 'rb') as handle:
icd10Tokenizer = pickle.load(handle)
with open('models/icd10_mappings_extended.p', 'rb') as handle:
with open('models/icd10_mappings.p', 'rb') as handle:
encoded_Y = pickle.load(handle)
# LOAD ICD 10 CLASSIFICATION MODEL
callbacks_list = [
EarlyStopping(
monitor='val_loss',
patience=2,
patience=5,
# min_delta=0.001
),
ModelCheckpoint(
......@@ -71,7 +71,7 @@ callbacks_list = [
latent_dim = 256
batch_size = 400
epochs = 1
epochs = 2000
train_data_generator = KerasBatchGenerator(batch_size,
source_train,
......@@ -237,7 +237,7 @@ run_pipeline_prediction(source_val, decode_seq, icd10_model, encoded_Y, labels_v
source_kerasTokenizer, source_max_sequence_tokenizer,
icd10Tokenizer, max_icd10_length,
target_index_to_word_dict, target_val,
'logs/seq2seq')
'logs/seq2seq-att')
# y_true = []
# y_pred = []
......
......@@ -55,13 +55,13 @@ def run_pipeline_prediction(sentences: List[str], decode_seq_fnc: Callable, icd1
pred_text = " ".join(translated_sent)
pred_texts.append(pred_text)
pred_indexes = " ".join(translated_index)
pred_indexes = " ".join([str(i) for i in translated_index])
pred_ids.append(pred_indexes)
gold_indexes = np.trim_zeros(gold_target_indexes[seq_index], 'b')[1:-1]
gold_ids.append(" ".join(gold_indexes))
gold_ids.append(" ".join([str(i) for i in gold_indexes]))
gold_text = " ".join([target_index_to_word_dict[x] for x in gold_indexes])
gold_text = " ".join([target_index_to_word_dict[x] for x in gold_indexes if x in target_index_to_word_dict])
gold_texts.append(gold_text)
print('Target indexes:', gold_indexes)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment