Skip to content
Snippets Groups Projects
Commit b202a46b authored by Jurica Seva's avatar Jurica Seva
Browse files

Added ICD 10 classification with attention (best results), attention with...

Added ICD 10 classification with attention (best results), attention with context (worst results) and char level classification (running atm).
parent e6a59bea
No related merge requests found
from keras.utils import to_categorical

# Shared preprocessing objects used throughout the script.
kerasTokenizer = Tokenizer()
tokenizer = TokenizePreprocessor()
# NOTE(review): rebinding `prepareData` to an instance shadows the class
# itself, which becomes unreachable after this line — confirm intended.
prepareData = prepareData()

SEED = 777

# Build the French and Italian corpora; each call returns a 2-tuple,
# unpacked here as (corpora, errors).
frCorpora, frErrors = prepareData.prepareData('FR')
itCorpora, itErrors = prepareData.prepareData('IT')
...@@ -4,6 +4,8 @@ from _layers import Attention ...@@ -4,6 +4,8 @@ from _layers import Attention
from keras.models import Model, load_model as keras_load_model from keras.models import Model, load_model as keras_load_model
from keras.layers import Input from keras.layers import Input
from loader import *
# ICD 10 STUFF # ICD 10 STUFF
icd10_model = keras_load_model('models/icd10Classification_attention.h5', custom_objects={'Attention':Attention}) icd10_model = keras_load_model('models/icd10Classification_attention.h5', custom_objects={'Attention':Attention})
with open('models/icd10_tokenizer.p', 'rb') as handle: with open('models/icd10_tokenizer.p', 'rb') as handle:
...@@ -17,10 +19,13 @@ with open('models/icd10_mappings.p', 'rb') as handle: ...@@ -17,10 +19,13 @@ with open('models/icd10_mappings.p', 'rb') as handle:
# Load the trained sequence-to-sequence model plus its source/target
# tokenizers, and build index -> word reverse lookups for decoding.
S2S_model = keras_load_model('models/s2s.h5', custom_objects={'Attention': Attention})

with open('models/s2s_source_tokenizer.p', 'rb') as handle:
    s2s_source_tokenizer = pickle.load(handle)
source_vocab = s2s_source_tokenizer.word_index
# Reverse lookup: token index -> (whitespace-stripped) word.
source_index_to_word_dict = {v: k.strip() for k, v in s2s_source_tokenizer.word_index.items()}

with open('models/s2s_target_tokenizer.p', 'rb') as handle:
    s2s_target_tokenizer = pickle.load(handle)
target_vocab = s2s_target_tokenizer.word_index
target_index_to_word_dict = {v: k.strip() for k, v in s2s_target_tokenizer.word_index.items()}
# S2S STUFF
# Rebuild separate encoder/decoder models for inference from the trained
# S2S graph (standard Keras seq2seq inference wiring).
x, state_h, state_c = S2S_model.get_layer('lstm_1').output
encoder_states = [state_h, state_c]

embed_2 = S2S_model.get_layer('embedding_2').output
# Keep layer objects (not their outputs) so they can be re-applied below
# with inference-time initial states.
decoder_LSTM = S2S_model.get_layer('lstm_2')
decoder_dense = S2S_model.get_layer('dense_1')

# Encoder inference model
encoder_model_inf = Model(encoder_input, encoder_states)

# Decoder inference model: previous decoder h/c states come in as inputs.
# assumes an LSTM hidden size of 256 — TODO confirm against lstm_1/lstm_2 config
decoder_state_input_h = Input(shape=(256,), name='inf_input1')
decoder_state_input_c = Input(shape=(256,), name='inf_input2')
decoder_input_states = [decoder_state_input_h, decoder_state_input_c]
decoder_out, decoder_h, decoder_c = decoder_LSTM(embed_2, initial_state=decoder_input_states)
...@@ -49,10 +54,14 @@ decoder_out = decoder_dense(decoder_out) ...@@ -49,10 +54,14 @@ decoder_out = decoder_dense(decoder_out)
# Assemble the decoder inference model: consumes the previous target token
# plus the carried h/c states, emits the next-token distribution and states.
decoder_model_inf = Model(inputs=[decoder_input] + decoder_input_states,
                          outputs=[decoder_out] + decoder_states)

encoder_model_inf.summary()
decoder_model_inf.summary()
def decode_seq(inp_seq): def decode_seq(inp_seq):
states_val = encoder_model_inf.predict(inp_seq) states_val = encoder_model_inf.predict(inp_seq)
target_seq = np.zeros((1, target_max_sequence)) target_seq = np.zeros((1, S2S_model.get_layer('input_2').output_shape[1]))
target_seq[0, 0] = target_vocab['sos'] target_seq[0, 0] = target_vocab['sos']
translated_sent = [] translated_sent = []
...@@ -65,27 +74,33 @@ def decode_seq(inp_seq): ...@@ -65,27 +74,33 @@ def decode_seq(inp_seq):
try: try:
sampled_fra_char = target_index_to_word_dict[max_val_index] sampled_fra_char = target_index_to_word_dict[max_val_index]
except KeyError: except KeyError:
# stop_condition = True
sampled_fra_char = 'eos' sampled_fra_char = 'eos'
translated_sent.append(sampled_fra_char) translated_sent.append(sampled_fra_char)
translated_index.append(max_val_index) translated_index.append(max_val_index)
if ((sampled_fra_char == 'eos') or (len(translated_sent) > target_max_sequence)): if ((sampled_fra_char == 'eos') or (len(translated_sent) > S2S_model.get_layer('input_2').output_shape[1])):
stop_condition = True stop_condition = True
target_seq = np.zeros((1, target_max_sequence)) target_seq = np.zeros((1, S2S_model.get_layer('input_2').output_shape[1]))
target_seq[0, 0] = max_val_index target_seq[0, 0] = max_val_index
states_val = [decoder_h, decoder_c] states_val = [decoder_h, decoder_c]
return translated_sent[:-1], translated_index[:-1] return translated_sent[:-1], translated_index[:-1]
y_true = []
y_pred = []
for seq_index in tqdm.tqdm(range(len(source_val))): # for seq_index in range(len(source_corpus)):
# for seq_index in range(10): for seq_index in range(10):
inp_seq = source_val[seq_index:seq_index+1] inp_seq = source_corpus[seq_index:seq_index + 1]
inp_seq = s2s_source_tokenizer.texts_to_sequences(inp_seq)
inp_seq = pad_sequences(inp_seq, maxlen=S2S_model.get_layer('input_1').output_shape[1], padding='post')
translated_sent, translated_index= decode_seq(inp_seq) translated_sent, translated_index= decode_seq(inp_seq)
target_seq = target_corpus[seq_index:seq_index + 1]
target_seq = s2s_target_tokenizer.texts_to_sequences(target_seq)
# PREDICT ICD10 # PREDICT ICD10
source_word_sequence = kerasTokenizer.texts_to_sequences([" ".join(translated_sent)]) source_word_sequence = kerasTokenizer.texts_to_sequences([" ".join(translated_sent)])
word_sequence = pad_sequences(source_word_sequence, maxlen=icd10_model.layers[0].input_shape[1], padding='post') word_sequence = pad_sequences(source_word_sequence, maxlen=icd10_model.layers[0].input_shape[1], padding='post')
...@@ -93,23 +108,23 @@ for seq_index in tqdm.tqdm(range(len(source_val))): ...@@ -93,23 +108,23 @@ for seq_index in tqdm.tqdm(range(len(source_val))):
# print(icd10_code_index, type(icd10_code_index)) # print(icd10_code_index, type(icd10_code_index))
max_val_index = np.argmax(icd10_code_index, axis=1)[0] max_val_index = np.argmax(icd10_code_index, axis=1)[0]
# print(max_val_index) # print(max_val_index)
icd10_label = encoded_Y.inverse_transform(max_val_index) icd10_label = icd10Encoder.inverse_transform(max_val_index)
# print('-') print('-')
# target_index = np.trim_zeros(target_val[seq_index], 'b')[1:-1] target_index = target_seq[0]
# print('Target indexes:', target_index) print('Target indexes:', target_index)
# print('Decoded indexes:', translated_index) print('Decoded indexes:', translated_index)
#
# print('Target sentence:', " ".join([target_index_to_word_dict[x] for x in target_index])) print('Target sentence:', " ".join([target_index_to_word_dict[x] for x in target_index]))
# print('Decoded sentence:', " ".join(translated_sent)) print('Decoded sentence:', " ".join(translated_sent))
#
# print('Target ICD-10:', labels_val[seq_index]) print('Target ICD-10:', labels[seq_index])
# print('Predict ICD-10:', icd10_label) print('Predict ICD-10:', icd10_label)
y_true.append(labels_val[seq_index]) y_true.append(labels[seq_index])
y_pred.append(icd10_label) y_pred.append(icd10_label)
# Aggregate per-code precision/recall/F1 over the evaluated samples and
# persist the report for later inspection.
report = classification_report(y_true, y_pred)
report_df = report_to_df(report)
report_df.to_csv('logs/classification_report_test.csv')
print(report_df)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment