Commit b202a46b authored by Jurica Seva

Added ICD-10 classification with attention (best results), attention with context (worst results), and char-level classification (currently running).
parent e6a59bea
@@ -11,6 +11,7 @@ from keras.utils import to_categorical
kerasTokenizer = Tokenizer()
tokenizer = TokenizePreprocessor()
prepareData = prepareData()
SEED = 777
frCorpora, frErrors = prepareData.prepareData('FR')
itCorpora, itErrors = prepareData.prepareData('IT')
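# prepareData.prepareData('XX') is assumed to return a (corpus, errors) pair
# per language code, judging by how the results are unpacked above.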
@@ -4,6 +4,8 @@ from _layers import Attention
from keras.models import Model, load_model as keras_load_model
from keras.layers import Input
from loader import *
# ICD 10 STUFF
icd10_model = keras_load_model('models/icd10Classification_attention.h5', custom_objects={'Attention':Attention})
with open('models/icd10_tokenizer.p', 'rb') as handle:
@@ -17,10 +19,13 @@ with open('models/icd10_mappings.p', 'rb') as handle:
S2S_model = keras_load_model('models/s2s.h5', custom_objects={'Attention':Attention})
with open('models/s2s_source_tokenizer.p', 'rb') as handle:
s2s_source_tokenizer = pickle.load(handle)
source_vocab = s2s_source_tokenizer.word_index
source_index_to_word_dict = {v:k.strip() for k,v in s2s_source_tokenizer.word_index.items()}
with open('models/s2s_target_tokenizer.p', 'rb') as handle:
s2s_target_tokenizer = pickle.load(handle)
target_vocab = s2s_target_tokenizer.word_index
target_index_to_word_dict = {v:k.strip() for k,v in s2s_target_tokenizer.word_index.items()}
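# word_index maps token -> integer id; the inverted dicts map ids back to
# tokens so that decoded index sequences can be printed as text further down.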
# S2S STUFF
@@ -31,15 +36,15 @@ x, state_h, state_c = S2S_model.get_layer('lstm_1').output
encoder_states = [state_h, state_c]
embed_2 = S2S_model.get_layer('embedding_2').output
-decoder_LSTM = S2S_model.get_layer('lstm_2').output
-decoder_dense = S2S_model.get_layer('dense_1').output
+decoder_LSTM = S2S_model.get_layer('lstm_2')
+decoder_dense = S2S_model.get_layer('dense_1')
# Encoder inference model
encoder_model_inf = Model(encoder_input, encoder_states)
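# The encoder inference model reuses the trained encoder weights and maps a
# padded source sequence directly to [state_h, state_c], which seed the decoder.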
# Decoder inference model
-decoder_state_input_h = Input(shape=(256,))
-decoder_state_input_c = Input(shape=(256,))
+decoder_state_input_h = Input(shape=(256,), name='inf_input1')
+decoder_state_input_c = Input(shape=(256,), name='inf_input2')
decoder_input_states = [decoder_state_input_h, decoder_state_input_c]
decoder_out, decoder_h, decoder_c = decoder_LSTM(embed_2, initial_state=decoder_input_states)
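# Note: the corrected lines above fetch the layer objects (lstm_2, dense_1)
# rather than their .output tensors, so the trained weights can be rewired
# onto the new inference-time inputs here.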
@@ -49,10 +54,14 @@ decoder_out = decoder_dense(decoder_out)
decoder_model_inf = Model(inputs=[decoder_input] + decoder_input_states,
                          outputs=[decoder_out] + decoder_states)
encoder_model_inf.summary()
decoder_model_inf.summary()
def decode_seq(inp_seq):
    states_val = encoder_model_inf.predict(inp_seq)
-    target_seq = np.zeros((1, target_max_sequence))
+    target_seq = np.zeros((1, S2S_model.get_layer('input_2').output_shape[1]))
    target_seq[0, 0] = target_vocab['sos']
    translated_sent = []
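    # Greedy decoding: seed the decoder with 'sos', then feed the argmax token
    # of each step back in as input until 'eos' or the length limit is reached.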
@@ -65,27 +74,33 @@ def decode_seq(inp_seq):
        try:
            sampled_fra_char = target_index_to_word_dict[max_val_index]
        except KeyError:
            # stop_condition = True
            sampled_fra_char = 'eos'
        translated_sent.append(sampled_fra_char)
        translated_index.append(max_val_index)
-        if ((sampled_fra_char == 'eos') or (len(translated_sent) > target_max_sequence)):
+        if ((sampled_fra_char == 'eos') or (len(translated_sent) > S2S_model.get_layer('input_2').output_shape[1])):
            stop_condition = True
-        target_seq = np.zeros((1, target_max_sequence))
+        target_seq = np.zeros((1, S2S_model.get_layer('input_2').output_shape[1]))
        target_seq[0, 0] = max_val_index
        states_val = [decoder_h, decoder_c]
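    # Slice off the trailing 'eos' token before returning.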
    return translated_sent[:-1], translated_index[:-1]
y_true = []
y_pred = []
-for seq_index in tqdm.tqdm(range(len(source_val))):
-    # for seq_index in range(10):
-    inp_seq = source_val[seq_index:seq_index+1]
+# for seq_index in range(len(source_corpus)):
+for seq_index in range(10):
+    inp_seq = source_corpus[seq_index:seq_index + 1]
    inp_seq = s2s_source_tokenizer.texts_to_sequences(inp_seq)
    inp_seq = pad_sequences(inp_seq, maxlen=S2S_model.get_layer('input_1').output_shape[1], padding='post')
    translated_sent, translated_index = decode_seq(inp_seq)
+    target_seq = target_corpus[seq_index:seq_index + 1]
+    target_seq = s2s_target_tokenizer.texts_to_sequences(target_seq)
    # PREDICT ICD10
    source_word_sequence = kerasTokenizer.texts_to_sequences([" ".join(translated_sent)])
    word_sequence = pad_sequences(source_word_sequence, maxlen=icd10_model.layers[0].input_shape[1], padding='post')
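    # The decoded sentence is re-tokenized and padded to the ICD-10 attention
    # classifier's fixed input length before predicting a code.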
@@ -93,23 +108,23 @@ for seq_index in tqdm.tqdm(range(len(source_val))):
    # print(icd10_code_index, type(icd10_code_index))
    max_val_index = np.argmax(icd10_code_index, axis=1)[0]
    # print(max_val_index)
-    icd10_label = encoded_Y.inverse_transform(max_val_index)
-    # print('-')
-    # target_index = np.trim_zeros(target_val[seq_index], 'b')[1:-1]
-    # print('Target indexes:', target_index)
-    # print('Decoded indexes:', translated_index)
-    #
-    # print('Target sentence:', " ".join([target_index_to_word_dict[x] for x in target_index]))
-    # print('Decoded sentence:', " ".join(translated_sent))
-    #
-    # print('Target ICD-10:', labels_val[seq_index])
-    # print('Predict ICD-10:', icd10_label)
-    y_true.append(labels_val[seq_index])
+    icd10_label = icd10Encoder.inverse_transform(max_val_index)
+    print('-')
+    target_index = target_seq[0]
+    print('Target indexes:', target_index)
+    print('Decoded indexes:', translated_index)
+    print('Target sentence:', " ".join([target_index_to_word_dict[x] for x in target_index]))
+    print('Decoded sentence:', " ".join(translated_sent))
+    print('Target ICD-10:', labels[seq_index])
+    print('Predict ICD-10:', icd10_label)
+    y_true.append(labels[seq_index])
    y_pred.append(icd10_label)
report = classification_report(y_true, y_pred)
report_df = report_to_df(report)
-report_df.to_csv('logs/classification_report.csv')
+report_df.to_csv('logs/classification_report_test.csv')
print(report_df)
\ No newline at end of file
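
Note for readers: report_to_df is not defined in this diff; it presumably comes in via `from loader import *`. A minimal sketch of what such a helper might look like, assuming it parses the plain-text output of sklearn's classification_report into a DataFrame (the body below is an illustrative assumption, not the repository's implementation):

import re
import pandas as pd

def report_to_df(report):
    """Parse sklearn's classification_report text into a DataFrame (sketch)."""
    rows = []
    for line in report.splitlines():
        line = line.strip()
        if not line or line.startswith('precision'):
            continue  # skip blank lines and the header row
        # Labels such as 'macro avg' contain a single space, so split on 2+ spaces.
        parts = re.split(r'\s{2,}', line)
        if len(parts) == 5:  # label, precision, recall, f1-score, support
            label, precision, recall, f1, support = parts
            rows.append([label, float(precision), float(recall),
                         float(f1), int(support)])
    return pd.DataFrame(rows, columns=['label', 'precision', 'recall',
                                       'f1-score', 'support'])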