def eval(model, data_eval, voc_size, epoch):
    """Evaluate `model` on `data_eval` and report medication-prediction metrics.

    For every patient (one entry in `data_eval`) and every admission, the model
    predicts a medication set; per-patient Jaccard, PRAUC, precision, recall and
    F1 are accumulated and a DDI rate is computed over all predicted sets.

    Args:
        model: trained model; `model(adm)` returns per-step logits for one admission.
        data_eval: iterable of patients, each a list of admissions; `adm[2]` holds
            the ground-truth medication indices (assumed — confirm against data prep).
        voc_size: (diag, proc, med) vocabulary sizes; only `voc_size[2]` is used here.
        epoch: unused in this variant; kept for interface parity with other eval fns.

    Returns:
        (ddi_rate, mean_ja, mean_prauc, mean_p, mean_r, mean_f1, avg_med_per_visit)
    """
    model.eval()
    # One metric list per patient-level score; averaged at the end.
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    smm_record = []  # predicted label sets for all patients (input to DDI scoring)
    med_cnt, visit_cnt = 0, 0  # total predicted meds / total visits, for avg meds
    for step, input in enumerate(data_eval):
        y_gt = []
        y_pred = []
        y_pred_prob = []
        y_pred_label = []
        for adm_index, adm in enumerate(input):
            output_logits = model(adm)
            # Multi-hot ground-truth vector over the medication vocabulary.
            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)
            # prediction prod
            output_logits = output_logits.detach().cpu().numpy()
            # prediction med set (stops decoding at END/other special tokens)
            out_list, sorted_predict = sequence_output_process(
                output_logits, [voc_size[2], voc_size[2] + 1])
            y_pred_label.append(sorted(sorted_predict))
            # Mean probability per med over decode steps; last 2 columns are
            # special tokens and are dropped.
            y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))
            # prediction label (multi-hot over the med vocabulary)
            y_pred_tmp = np.zeros(voc_size[2])
            y_pred_tmp[out_list] = 1
            y_pred.append(y_pred_tmp)
            visit_cnt += 1
            med_cnt += len(sorted_predict)
        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = \
            sequence_metric(np.array(y_gt), np.array(y_pred),
                            np.array(y_pred_prob), np.array(y_pred_label))
        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))
    # ddi rate over all predicted medication sets
    ddi_rate = ddi_rate_score(smm_record,
                              path='../data/output/ddi_A_final.pkl')
    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p),
                np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt))
    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(
        avg_r), np.mean(avg_f1), med_cnt / visit_cnt
def eval(model, data_eval, voc_size, epoch):  # evaluate
    """Evaluate `model` on `data_eval`; print per-epoch metrics.

    Variant of the eval loop that logs with epoch/step context and omits the
    DDI pickle path (ddi_rate_score uses its default). Unlike the other eval
    variant, avg meds per visit is printed but NOT returned.

    Args:
        model: trained model; `model(adm)` returns per-step logits for one admission.
        data_eval: iterable of patients (lists of admissions); `adm[2]` holds the
            ground-truth medication indices (assumed — confirm against data prep).
        voc_size: (diag, proc, med) vocabulary sizes; only `voc_size[2]` is used.
        epoch: current epoch number, used only for progress logging.

    Returns:
        (ddi_rate, mean_ja, mean_prauc, mean_p, mean_r, mean_f1)
    """
    print('')
    model.eval()
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    records = []  # predicted label sets for all patients (input to DDI scoring)
    med_cnt = 0
    visit_cnt = 0
    for step, input in enumerate(data_eval):
        y_gt = []
        y_pred = []
        y_pred_prob = []
        y_pred_label = []
        for adm in input:
            # Multi-hot ground truth over the medication vocabulary.
            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)
            output_logits = model(adm)
            output_logits = output_logits.detach().cpu().numpy()
            # Decode predicted med set, stopping on the two special tokens.
            out_list, sorted_predict = sequence_output_process(
                output_logits, [voc_size[2], voc_size[2] + 1])
            y_pred_label.append(sorted(sorted_predict))
            # Mean per-med probability over decode steps; last 2 columns are
            # the special tokens and are dropped.
            y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))
            y_pred_tmp = np.zeros(voc_size[2])
            y_pred_tmp[out_list] = 1
            y_pred.append(y_pred_tmp)
            visit_cnt += 1
            med_cnt += len(sorted_predict)
        records.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = sequence_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob),
            np.array(y_pred_label))
        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rEval--Epoch: %d, Step: %d/%d' % (epoch, step,
                                                    len(data_eval)))
    # ddi rate (default DDI matrix path inside ddi_rate_score)
    ddi_rate = ddi_rate_score(records)
    llprint(
        '\tDDI Rate: %.4f, Jaccard: %.4f, PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n'
        % (ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p),
           np.mean(avg_r), np.mean(avg_f1)))
    print('avg med', med_cnt / visit_cnt)
    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(
        avg_r), np.mean(avg_f1)
def fine_tune(fine_tune_name=''):
    """Fine-tune a pre-trained Leap model with a DDI-aware reinforcement loss.

    Loads records/vocab, restores the checkpoint `saved/<model>/<fine_tune_name>`,
    then trains: for each admission, if the decoded med set contains a known
    DDI pair (K == 1) the loss pushes the policy away via
    -jaccard * K * mean(log_softmax(logits)). Saves the final weights to
    saved/<model>/final.model.

    Args:
        fine_tune_name: checkpoint file name under saved/<args.model_name>/.
    """
    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'
    device = torch.device('cpu:0')
    data = dill.load(open(data_path, 'rb'))
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']
    ddi_A = dill.load(open('../data/output/ddi_A_final.pkl', 'rb'))
    # 2/3 train split; first half of the remainder is the test set.
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    # data_eval = data[split_point+eval_len:]
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = Leap(voc_size, device=device)
    model.load_state_dict(
        torch.load(
            open(os.path.join("saved", args.model_name, fine_tune_name),
                 'rb')))
    model.to(device)
    END_TOKEN = voc_size[2] + 1
    optimizer = Adam(model.parameters(), lr=args.lr)
    ddi_rate_record = []
    EPOCH = 100
    for epoch in range(EPOCH):
        loss_record = []
        start_time = time.time()
        # Bootstrap-resample patients each epoch.
        random_train_set = [
            random.choice(data_train) for i in range(len(data_train))
        ]
        for step, input in enumerate(random_train_set):
            model.train()
            K_flag = False
            for adm in input:
                target = adm[2]
                output_logits = model(adm)
                out_list, sorted_predict = sequence_output_process(
                    output_logits.detach().cpu().numpy(),
                    [voc_size[2], voc_size[2] + 1])
                inter = set(out_list) & set(target)
                union = set(out_list) | set(target)
                # BUG FIX: the original tested `union == 0`, comparing a set to
                # an int (always False), so an empty union would raise
                # ZeroDivisionError. Test emptiness of the set instead.
                jaccard = 0 if len(union) == 0 else len(inter) / len(union)
                # K = 1 iff the predicted set contains at least one DDI pair.
                K = 0
                for i in out_list:
                    if K == 1:
                        K_flag = True
                        break
                    for j in out_list:
                        if ddi_A[i][j] == 1:
                            K = 1
                            break
                # Penalize confident predictions only when a DDI pair exists
                # (K == 1); otherwise the loss is identically 0.
                loss = -jaccard * K * torch.mean(
                    F.log_softmax(output_logits, dim=-1))
                loss_record.append(loss.item())
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()
            llprint('\rtraining step: {} / {}'.format(
                step, len(random_train_set)))
            if K_flag:
                print()
                ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
                    model, data_test, voc_size, epoch)
    # test
    torch.save(
        model.state_dict(),
        open(os.path.join('saved', args.model_name, 'final.model'), 'wb'))
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100,
               learning_rate=0.01):
    """Train the encoder/decoder for 30 epochs and evaluate after each epoch.

    Each epoch runs `n_iters` training steps over pre-sampled pairs, then
    evaluates on the global `eval_pairs`, computes sequence metrics and a DDI
    rate over the predicted med sets, appends everything to `history`, and
    checkpoints both networks.

    Args:
        encoder, decoder: the seq2seq networks to train.
        n_iters: training steps per epoch (pairs are sampled once, up front).
        print_every, plot_every: kept for interface compatibility; only
            per-epoch average loss is computed here.
        learning_rate: SGD learning rate for both optimizers.

    Relies on module-level globals: train_pairs, eval_pairs, med_voc,
    SOS_token, EOS_token, model_name, train, evaluate, tensorsFromPair,
    sequence_output_process, sequence_metric, llprint, dill.
    """
    start = time.time()
    print_loss_total = 0  # Reset every print_every
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(train_pairs))
                      for i in range(n_iters)]
    criterion = nn.CrossEntropyLoss()
    history = defaultdict(list)
    for epoch in range(30):
        for iter in range(1, n_iters + 1):
            training_pair = training_pairs[iter - 1]
            input_tensor = training_pair[0]
            target_tensor = training_pair[1]
            loss = train(input_tensor, target_tensor, encoder, decoder,
                         encoder_optimizer, decoder_optimizer, criterion)
            print_loss_total += loss
            llprint('\rTrain--Epoch: %d, Step: %d/%d' % (epoch, iter,
                                                         n_iters))
        print_loss_avg = print_loss_total / n_iters  # NOTE: currently unused
        print_loss_total = 0
        # eval
        y_gt = []
        y_pred = []
        y_pred_prob = []
        y_pred_label = []
        for pair in eval_pairs:
            # Ground truth: pair[1] is the target token sequence; drop the
            # trailing EOS and shift by the 2 special tokens (SOS/EOS).
            y_gt_tmp = np.zeros(len(med_voc.idx2word))
            y_gt_tmp[np.array(pair[1])[:-1] - 2] = 1
            y_gt.append(y_gt_tmp)
            input_tensor, output_tensor = tensorsFromPair(pair)
            output_logits = evaluate(encoder, decoder, input_tensor)
            # FIX: pass dim explicitly — implicit-dim softmax is deprecated.
            # output_logits is 2-D (steps x vocab, see the [:, 2:] slice
            # below), so dim=-1 matches the old implicit behavior.
            output_logits = F.softmax(output_logits, dim=-1)
            output_logits = output_logits.detach().cpu().numpy()
            out_list, sorted_predict = sequence_output_process(
                output_logits, [SOS_token, EOS_token])
            y_pred_label.append(np.array(sorted_predict) - 2)
            y_pred_prob.append(np.mean(output_logits[:, 2:], axis=0))
            y_pred_tmp = np.zeros(len(med_voc.idx2word))
            if len(out_list) != 0:
                y_pred_tmp[np.array(out_list) - 2] = 1
            y_pred.append(y_pred_tmp)
        ja, prauc, avg_p, avg_r, avg_f1 = sequence_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob),
            np.array(y_pred_label))
        # ddi rate: fraction of predicted med pairs that interact.
        ddi_A = dill.load(open('../data/ddi_A_final.pkl', 'rb'))
        all_cnt = 0
        dd_cnt = 0
        for adm in y_pred_label:
            med_code_set = adm
            for i, med_i in enumerate(med_code_set):
                for j, med_j in enumerate(med_code_set):
                    if j <= i:
                        continue
                    all_cnt += 1
                    if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                        dd_cnt += 1
        # FIX: guard against division by zero when no pair of meds was
        # predicted (every prediction empty or a singleton).
        ddi_rate = dd_cnt / all_cnt if all_cnt > 0 else 0
        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        llprint(
            '\n\tDDI Rate: %.4f, Jaccard: %.4f, PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n'
            % (ddi_rate, ja, prauc, avg_p, avg_r, avg_f1))
        dill.dump(history,
                  open(os.path.join('saved', model_name, 'history.pkl'),
                       'wb'))
        torch.save(
            encoder.state_dict(),
            open(
                os.path.join(
                    'saved', model_name,
                    'encoder_Epoch_%d_JA_%.4f_DDI_%.4f.model' %
                    (epoch, ja, ddi_rate)), 'wb'))
        torch.save(
            decoder.state_dict(),
            open(
                os.path.join(
                    'saved', model_name,
                    'decoder_Epoch_%d_JA_%.4f_DDI_%.4f.model' %
                    (epoch, ja, ddi_rate)), 'wb'))
MAX_LEN = len(input_seq) output_seq = list(np.array(o) + 2) output_seq.append(EOS_token) test_pairs.append((input_seq, output_seq)) for pair in test_pairs: y_gt_tmp = np.zeros(len(med_voc.idx2word)) y_gt_tmp[np.array(pair[1])[:-1] - 2] = 1 y_gt.append(y_gt_tmp) input_tensor, output_tensor = tensorsFromPair(pair) output_logits = evaluate(encoder1, decoder1, input_tensor) output_logits = F.softmax(output_logits) output_logits = output_logits.detach().cpu().numpy() out_list, sorted_predict = sequence_output_process(output_logits, [SOS_token, EOS_token]) y_pred_label.append(np.array(sorted_predict) - 2) y_pred_prob.append(np.mean(output_logits[:, 2:], axis=0)) y_pred_tmp = np.zeros(len(med_voc.idx2word)) if len(out_list) != 0: y_pred_tmp[np.array(out_list) - 2] = 1 y_pred.append(y_pred_tmp) ja, prauc, avg_p, avg_r, avg_f1 = sequence_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob), np.array(y_pred_label)) # ddi rate ddi_A = dill.load(open('../data/ddi_A_final.pkl', 'rb')) all_cnt = 0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate `model` with standard metrics plus add/delete edit distances.

    The first admission of each patient seeds `previous_set` (the prior med
    set); every later admission is predicted and compared against it to
    measure how well the model captures medications that were added to or
    deleted from the previous visit. Patients with fewer than 2 visits are
    skipped entirely.

    Args:
        model: trained model; `model(adm)` returns per-step logits.
        data_eval: iterable of patients (lists of admissions); `adm[2]` holds
            the ground-truth medication indices (assumed — confirm upstream).
        voc_size: (diag, proc, med) vocabulary sizes; only `voc_size[2]` used.
        epoch: unused; kept for interface parity with the other eval variants.

    Returns:
        (ddi_rate, mean_ja, mean_prauc, mean_p, mean_r, mean_f1,
         mean_add_distance, mean_delete_distance, avg_med_per_visit)
    """
    model.eval()
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    smm_record = []
    med_cnt, visit_cnt = 0, 0
    add_list, delete_list = [], []
    for step, input in enumerate(data_eval):
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        if len(input) < 2:
            continue  # add/delete needs at least one prior visit
        add_temp_list, delete_temp_list = [], []
        for adm_idx, adm in enumerate(input):
            if adm_idx == 0:
                previous_set = adm[2]
                continue
            output_logits = model(adm)
            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)
            # prediction prod
            output_logits = output_logits.detach().cpu().numpy()
            # prediction med set
            out_list, sorted_predict = sequence_output_process(
                output_logits, [voc_size[2], voc_size[2] + 1])
            y_pred_label.append(sorted(sorted_predict))
            y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))
            # prediction label
            y_pred_tmp = np.zeros(voc_size[2])
            y_pred_tmp[out_list] = 1
            y_pred.append(y_pred_tmp)
            visit_cnt += 1
            med_cnt += len(sorted_predict)
            #### add or delete: symmetric difference between the predicted
            #### and ground-truth changes relative to the previous visit.
            add_gt = set(np.where(y_gt_tmp == 1)[0]) - set(previous_set)
            delete_gt = set(previous_set) - set(np.where(y_gt_tmp == 1)[0])
            add_pre = set(np.where(y_pred_tmp == 1)[0]) - set(previous_set)
            delete_pre = set(previous_set) - set(np.where(y_pred_tmp == 1)[0])
            add_distance = len(set(add_pre) - set(add_gt)) + \
                len(set(add_gt) - set(add_pre))
            delete_distance = len(set(delete_pre) - set(delete_gt)) + \
                len(set(delete_gt) - set(delete_pre))
            ####
            add_temp_list.append(add_distance)
            delete_temp_list.append(delete_distance)
            # NOTE(review): assigned but never read — `previous_set` is never
            # advanced past the first visit. Looks like a latent bug in the
            # original; kept as-is to preserve published behavior.
            previous_temp_set = out_list
        if len(add_temp_list) > 1:
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))
        else:
            add_list.append(add_temp_list[0])
            delete_list.append(delete_temp_list[0])
        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = \
            sequence_metric(np.array(y_gt), np.array(y_pred),
                            np.array(y_pred_prob), np.array(y_pred_label))
        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))
    # ddi rate
    ddi_rate = ddi_rate_score(smm_record,
                              path='../data/output/ddi_A_final.pkl')
    # FIX: np.float was removed in NumPy 1.24 (deprecated since 1.20);
    # the builtin float is the documented replacement.
    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4}, AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'
        .format(float(ddi_rate), np.mean(ja), np.mean(avg_f1),
                np.mean(add_list), np.mean(delete_list),
                med_cnt / visit_cnt))
    return float(ddi_rate), np.mean(ja), np.mean(prauc), np.mean(avg_p), \
        np.mean(avg_r), np.mean(avg_f1), np.mean(add_list), \
        np.mean(delete_list), med_cnt / visit_cnt
def fine_tune(fine_tune_name=''):
    """Fine-tune a pre-trained Leap model with a DDI-aware reinforcement loss.

    CUDA variant: loads records/vocab from ../../data, restores the checkpoint
    saved/<model_name>/<fine_tune_name>, runs one fine-tuning epoch where
    admissions whose decoded med set contains a DDI pair (K == 1) are penalized
    via -jaccard * K * mean(log_softmax(logits)), saves a per-epoch checkpoint
    and finally saved/<model_name>/final.model.

    Args:
        fine_tune_name: checkpoint file name under saved/<model_name>/.
    """
    data_path = '../../data/records_final.pkl'
    voc_path = '../../data/voc_final.pkl'
    device = torch.device('cuda:0')
    data = dill.load(open(data_path, 'rb'))
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']
    ddi_A = dill.load(open('../../data/ddi_A_final.pkl', 'rb'))
    # 2/3 train split; first half of the remainder is the test set.
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    # data_eval = data[split_point+eval_len:]
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = Leap(voc_size, device=device)
    model.load_state_dict(
        torch.load(
            open(os.path.join("saved", model_name, fine_tune_name), 'rb')))
    model.to(device)
    EPOCH = 30  # used only for the remaining-time estimate below
    LR = 0.0001
    END_TOKEN = voc_size[2] + 1
    optimizer = Adam(model.parameters(), lr=LR)
    ddi_rate_record = []
    for epoch in range(1):
        loss_record = []
        start_time = time.time()
        # Bootstrap-resample patients each epoch.
        random_train_set = [
            random.choice(data_train) for i in range(len(data_train))
        ]
        for step, input in enumerate(random_train_set):
            model.train()
            K_flag = False
            for adm in input:
                target = adm[2]
                output_logits = model(adm)
                out_list, sorted_predict = sequence_output_process(
                    output_logits.detach().cpu().numpy(),
                    [voc_size[2], voc_size[2] + 1])
                inter = set(out_list) & set(target)
                union = set(out_list) | set(target)
                # BUG FIX: the original tested `union == 0`, comparing a set
                # to an int (always False), so an empty union would raise
                # ZeroDivisionError. Test emptiness of the set instead.
                jaccard = 0 if len(union) == 0 else len(inter) / len(union)
                # K = 1 iff the predicted set contains at least one DDI pair.
                K = 0
                for i in out_list:
                    if K == 1:
                        K_flag = True
                        break
                    for j in out_list:
                        if ddi_A[i][j] == 1:
                            K = 1
                            break
                # Penalize confident predictions only when a DDI pair exists.
                loss = -jaccard * K * torch.mean(
                    F.log_softmax(output_logits, dim=-1))
                loss_record.append(loss.item())
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()
            llprint('\rTrain--Epoch: %d, Step: %d/%d' % (epoch, step,
                                                         len(data_train)))
            if K_flag:
                # NOTE(review): this unpacks 6 values; ensure the eval()
                # in scope at call time returns exactly 6 — confirm which
                # eval variant this script imports/defines.
                ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(
                    model, data_test, voc_size, epoch)
        end_time = time.time()
        elapsed_time = (end_time - start_time) / 60
        llprint(
            '\tEpoch: %d, Loss1: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n'
            % (epoch, np.mean(loss_record), elapsed_time,
               elapsed_time * (EPOCH - epoch - 1) / 60))
        torch.save(
            model.state_dict(),
            open(
                os.path.join(
                    'saved', model_name,
                    'fine_Epoch_%d_JA_%.4f_DDI_%.4f.model' %
                    (epoch, ja, ddi_rate)), 'wb'))
        print('')
    # test
    torch.save(model.state_dict(),
               open(os.path.join('saved', model_name, 'final.model'), 'wb'))