def eval(model, data_eval, voc_size, epoch):
    """Evaluate ``model`` on ``data_eval`` and dump a per-patient case study.

    For every patient with at least two visits, each visit after the first is
    predicted from the visits that precede it (``model(patient[:i])``).
    Sigmoid outputs >= 0.3 are taken as the predicted label set.

    Args:
        model: network called as ``model(visits)``; row 0 of its output is
            used as the logits for the visit being predicted.
        data_eval: sequence of patients; each patient is a sequence of visits
            and ``visit[2]`` holds the ground-truth label indices.
        voc_size: sequence whose element ``[2]`` is the size of the label
            vector.
        epoch: epoch number, used only in the progress message.

    Returns:
        Tuple ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1)``
        where all but ``ddi_rate`` are means over the evaluated patients.

    Side effects:
        Writes ``saved/<model_name>/case_study.pkl``.
    """
    print('')
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    case_study = defaultdict(dict)
    med_cnt = 0
    visit_cnt = 0

    for step, patient in enumerate(data_eval):
        # At least one earlier visit is needed to condition the prediction on.
        if len(patient) < 2:
            continue

        y_gt = []
        y_pred = []
        y_pred_prob = []
        y_pred_label = []
        for i in range(1, len(patient)):
            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[patient[i][2]] = 1
            y_gt.append(y_gt_tmp)

            target_output1 = model(patient[:i])
            target_output1 = F.sigmoid(target_output1).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output1)

            # Binarise a copy so the raw probabilities stay available for PRAUC.
            y_pred_tmp = target_output1.copy()
            y_pred_tmp[y_pred_tmp >= 0.3] = 1
            y_pred_tmp[y_pred_tmp < 0.3] = 0
            y_pred.append(y_pred_tmp)

            y_pred_label_tmp = [
                idx for idx, value in enumerate(y_pred_tmp) if value == 1
            ]
            y_pred_label.append(y_pred_label_tmp)
            med_cnt += len(y_pred_label_tmp)
            visit_cnt += 1

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        # NOTE(review): the case study is keyed by the Jaccard value itself, so
        # two patients with the same score overwrite each other. Kept as-is
        # because readers of case_study.pkl may rely on this layout.
        case_study[adm_ja] = {
            'ja': adm_ja,
            'patient': patient,
            'y_label': y_pred_label,
        }

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rEval--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_eval)))

    # Use a context manager so the pickle is flushed and closed before returning.
    case_study_path = os.path.join('saved', model_name, 'case_study.pkl')
    with open(case_study_path, 'wb') as handle:
        dill.dump(case_study, handle)

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record)

    llprint('\tDDI Rate: %.4f, Jaccard: %.4f, PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n' % (
        ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1)
    ))
    print('avg med', med_cnt / visit_cnt)

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1)
def eval(model, data_eval, voc_size, epoch):
    """Score a sequence-decoding model on ``data_eval``.

    Each visit of each patient is passed to ``model`` on its own. The logits
    are decoded with ``sequence_output_process`` and compared against the
    label indices stored in ``visit[2]``.

    Args:
        model: network called as ``model(visit)`` returning a 2-D logits tensor.
        data_eval: sequence of patients, each a sequence of visits.
        voc_size: sequence whose element ``[2]`` is the label-vector size.
        epoch: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med)`` where
        ``avg_med`` is the mean number of decoded labels per visit.
    """
    model.eval()

    label_dim = voc_size[2]
    scores = {'ja': [], 'prauc': [], 'avg_p': [], 'avg_r': [], 'avg_f1': []}
    predicted_sets = []
    total_meds = 0
    total_visits = 0

    for step, patient in enumerate(data_eval):
        truth_rows, pred_rows, prob_rows, label_rows = [], [], [], []

        for visit in patient:
            logits = model(visit)

            truth = np.zeros(label_dim)
            truth[visit[2]] = 1
            truth_rows.append(truth)

            # prediction prod
            logits = logits.detach().cpu().numpy()

            # prediction med set
            # presumably the two extra ids are special sequence tokens; they
            # are handed to the decoder unchanged -- TODO confirm
            out_list, sorted_predict = sequence_output_process(
                logits, [label_dim, label_dim + 1])
            label_rows.append(sorted(sorted_predict))
            prob_rows.append(np.mean(logits[:, :-2], axis=0))

            # prediction label
            one_hot = np.zeros(label_dim)
            one_hot[out_list] = 1
            pred_rows.append(one_hot)

            total_visits += 1
            total_meds += len(sorted_predict)

        predicted_sets.append(label_rows)

        patient_scores = sequence_metric(
            np.array(truth_rows), np.array(pred_rows),
            np.array(prob_rows), np.array(label_rows))
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = patient_scores
        scores['ja'].append(adm_ja)
        scores['prauc'].append(adm_prauc)
        scores['avg_p'].append(adm_avg_p)
        scores['avg_r'].append(adm_avg_r)
        scores['avg_f1'].append(adm_avg_f1)

        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(predicted_sets, path='../data/output/ddi_A_final.pkl')

    summary = (
        ddi_rate,
        np.mean(scores['ja']),
        np.mean(scores['prauc']),
        np.mean(scores['avg_p']),
        np.mean(scores['avg_r']),
        np.mean(scores['avg_f1']),
        total_meds / total_visits,
    )
    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'
        .format(*summary))
    return summary
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a recurrent model that carries hidden state across visits.

    For each patient the visits are fed in order. The model receives the
    current visit, the previous visit's ground-truth labels (``None`` for the
    first visit) and the three hidden states returned by the previous call.
    Sigmoid outputs >= 0.5 form the predicted label set.

    Args:
        model: called as ``model(visit, prev_target, [h1, h2, h3])`` and
            returning ``(logits, [h1, h2, h3])``.
        data_eval: sequence of patients, each a sequence of visits with the
            label indices at ``visit[2]``.
        voc_size: sequence whose element ``[2]`` is the label-vector size.
        epoch: epoch number, used only in the progress message.

    Returns:
        Tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1)``.
    """
    # evaluate
    print('')
    model.eval()

    label_dim = voc_size[2]
    predicted_sets = []
    scores = {'ja': [], 'prauc': [], 'avg_p': [], 'avg_r': [], 'avg_f1': []}

    for step, patient in enumerate(data_eval):
        truth_rows, pred_rows, prob_rows, label_rows = [], [], [], []

        # Fresh recurrent state for every patient.
        h_in1, h_in2, h_target = None, None, None
        last_labels = None

        for visit in patient:
            logits, [h_in1, h_in2, h_target] = model(
                visit, last_labels, [h_in1, h_in2, h_target])
            last_labels = visit[2]

            truth = np.zeros(label_dim)
            truth[visit[2]] = 1
            truth_rows.append(truth)

            probs = F.sigmoid(logits).detach().cpu().numpy()[0]
            prob_rows.append(probs)

            binary = probs.copy()
            binary[binary >= 0.5] = 1
            binary[binary < 0.5] = 0
            pred_rows.append(binary)

            label_rows.append(
                [pos for pos, flag in enumerate(binary) if flag == 1])

        predicted_sets.append(label_rows)

        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(truth_rows), np.array(pred_rows), np.array(prob_rows))
        scores['ja'].append(adm_ja)
        scores['prauc'].append(adm_prauc)
        scores['avg_p'].append(adm_avg_p)
        scores['avg_r'].append(adm_avg_r)
        scores['avg_f1'].append(adm_avg_f1)

        llprint('\rEval--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(predicted_sets)

    summary = (
        ddi_rate,
        np.mean(scores['ja']),
        np.mean(scores['prauc']),
        np.mean(scores['avg_p']),
        np.mean(scores['avg_r']),
        np.mean(scores['avg_f1']),
    )
    llprint(
        '\tDDI Rate: %.4f, Jaccard: %.4f, PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n'
        % summary)
    return summary
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a stateful sequence-decoding model on ``data_eval``.

    Visits are fed in order; the three state objects returned by the model
    for one visit are passed back in for the next visit of the same patient.
    Logits are decoded with ``sequence_output_process``.

    Args:
        model: called as ``model(visit, s1, s2, s3)`` and returning
            ``(logits, s1, s2, s3)``.
        data_eval: sequence of patients, each a sequence of visits with the
            label indices at ``visit[2]``.
        voc_size: sequence whose element ``[2]`` is the label-vector size.
        epoch: epoch number, used only in the progress message.

    Returns:
        Tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1)``. The average
        number of predicted labels per visit is printed but not returned.
    """
    # evaluate
    print('')
    model.eval()

    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    records = []
    med_cnt, visit_cnt = 0, 0

    for step, patient in enumerate(data_eval):
        y_gt = []
        y_pred = []
        y_pred_prob = []
        y_pred_label = []

        # Model state is reset for every patient.
        i1_state, i2_state, i3_state = None, None, None
        for adm in patient:
            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)

            output_logits, i1_state, i2_state, i3_state = model(
                adm, i1_state, i2_state, i3_state)
            output_logits = output_logits.detach().cpu().numpy()

            out_list, sorted_predict = sequence_output_process(
                output_logits, [voc_size[2], voc_size[2] + 1])
            y_pred_label.append(sorted_predict)
            y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))

            y_pred_tmp = np.zeros(voc_size[2])
            y_pred_tmp[out_list] = 1
            y_pred.append(y_pred_tmp)

            visit_cnt += 1
            # Count the decoded labels. The previous len(y_pred_tmp) was the
            # length of the one-hot vector, i.e. always voc_size[2].
            med_cnt += len(sorted_predict)

        records.append(y_pred_label)

        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = sequence_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob),
            np.array(y_pred_label))
        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)

        llprint('\rEval--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(records)

    llprint(
        '\tDDI Rate: %.4f, Jaccard: %.4f, PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f, AVG_Med: %.4f\n'
        % (ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p),
           np.mean(avg_r), np.mean(avg_f1), med_cnt / visit_cnt))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(
        avg_r), np.mean(avg_f1)
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a history-conditioned model with a 0.4 decision threshold.

    Patients with fewer than two visits are skipped. For the others, visit
    ``i`` (``i >= 1``) is predicted from ``patient[:i]``.

    Args:
        model: called as ``model(visits)``; row 0 of the output is used as
            the logits.
        data_eval: sequence of patients, each a sequence of visits with the
            label indices at ``visit[2]``.
        voc_size: sequence whose element ``[2]`` is the label-vector size.
        epoch: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med)``.
    """
    model.eval()

    label_dim = voc_size[2]
    predicted_sets = []
    scores = {'ja': [], 'prauc': [], 'avg_p': [], 'avg_r': [], 'avg_f1': []}
    total_meds = 0
    total_visits = 0

    for step, patient in enumerate(data_eval):
        truth_rows, pred_rows, prob_rows, label_rows = [], [], [], []

        if len(patient) < 2:
            continue

        for visit_no in range(1, len(patient)):
            logits = model(patient[:visit_no])

            truth = np.zeros(label_dim)
            truth[patient[visit_no][2]] = 1
            truth_rows.append(truth)

            # prediction prob
            probs = F.sigmoid(logits).detach().cpu().numpy()[0]
            prob_rows.append(probs)

            # prediction med set
            binary = probs.copy()
            binary[binary >= 0.4] = 1
            binary[binary < 0.4] = 0
            pred_rows.append(binary)

            # prediction label
            chosen = np.where(binary == 1)[0]
            label_rows.append(chosen)
            total_meds += len(chosen)
            total_visits += 1

        predicted_sets.append(label_rows)

        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(truth_rows), np.array(pred_rows), np.array(prob_rows))
        scores['ja'].append(adm_ja)
        scores['prauc'].append(adm_prauc)
        scores['avg_p'].append(adm_avg_p)
        scores['avg_r'].append(adm_avg_r)
        scores['avg_f1'].append(adm_avg_f1)

        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(predicted_sets, path='../data/output/ddi_A_final.pkl')

    summary = (
        ddi_rate,
        np.mean(scores['ja']),
        np.mean(scores['prauc']),
        np.mean(scores['avg_p']),
        np.mean(scores['avg_r']),
        np.mean(scores['avg_f1']),
        total_meds / total_visits,
    )
    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'
        .format(*summary))
    return summary
def eval(model, data_eval, voc_size, epoch, val=0, threshold1=0.3, threshold2=0.3):
    """Evaluate a model that predicts additions to and removals from a label set.

    The first visit's ground-truth labels seed a running set ``y_old``. For
    every later visit the model returns an "add" head and a "delete" head:
    labels whose add probability is >= ``threshold2`` are switched on, then
    labels whose delete probability is >= ``threshold1`` are switched off.
    Patients with fewer than two visits are skipped.

    Args:
        model: called as ``model(patient[:idx + 1])`` and returning
            ``(add_logits, delete_logits)``; row 0 of each is used.
        data_eval: sequence of patients, each a sequence of visits with the
            label indices at ``visit[2]``.
        voc_size: sequence whose element ``[2]`` is the label-vector size.
        epoch: accepted for interface compatibility; not used here.
        val: ``0`` to return metrics, anything else to return the raw
            labels and probabilities.
        threshold1: cut-off applied to the delete probabilities (a scalar or
            anything that broadcasts against the probability vector).
        threshold2: cut-off applied to the add probabilities.

    Returns:
        If ``val == 0``: ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1,
        add_distance, delete_distance, avg_med)``.
        Otherwise: ``(labels, add_probabilities, delete_probabilities)`` as
        arrays with one row per evaluated visit.
    """
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0
    label_list, prob_add, prob_delete = [], [], []
    add_list, delete_list = [], []

    for step, patient in enumerate(data_eval):
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        add_temp_list, delete_temp_list = [], []

        if len(patient) < 2:
            continue

        for adm_idx, adm in enumerate(patient):
            if adm_idx == 0:
                # Seed the running set with the first visit's true labels.
                y_old = np.zeros(voc_size[2])
                y_old[adm[2]] = 1
                continue

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)
            label_list.append(y_gt_tmp)

            add_result, delete_result = model(patient[:adm_idx + 1])

            # prediction prob
            y_pred_tmp_add = F.sigmoid(add_result).detach().cpu().numpy()[0]
            y_pred_tmp_delete = F.sigmoid(delete_result).detach().cpu().numpy()[0]
            y_pred_prob.append(y_pred_tmp_add)
            prob_add.append(y_pred_tmp_add)
            prob_delete.append(y_pred_tmp_delete)

            previous_set = set(np.where(y_old == 1)[0])

            # prediction med set (deletions are applied after additions)
            y_old[y_pred_tmp_add >= threshold2] = 1
            y_old[y_pred_tmp_delete >= threshold1] = 0
            # y_old keeps being mutated in place on later visits, so store a
            # snapshot. Appending y_old itself made every row of y_pred alias
            # the final state.
            y_pred.append(y_old.copy())

            # prediction label
            y_pred_label_tmp = np.where(y_old == 1)[0]
            y_pred_label.append(sorted(y_pred_label_tmp))
            visit_cnt += 1
            med_cnt += len(y_pred_label_tmp)

            # add / delete distance: size of the symmetric difference between
            # the true change and the predicted change relative to previous_set
            gt_set = set(np.where(y_gt_tmp == 1)[0])
            new_set = set(y_pred_label_tmp)
            add_gt = gt_set - previous_set
            delete_gt = previous_set - gt_set
            add_pre = new_set - previous_set
            delete_pre = previous_set - new_set
            add_temp_list.append(len(add_pre ^ add_gt))
            delete_temp_list.append(len(delete_pre ^ delete_gt))

        if add_temp_list:
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

        smm_record.append(y_pred_label)

        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))
        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)

        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate; float() replaces np.float, which was removed in NumPy 1.24
    ddi_rate = float(
        ddi_rate_score(smm_record, path='../data/output/ddi_A_final.pkl'))

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4}, AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(avg_f1), np.mean(add_list),
                np.mean(delete_list), med_cnt / visit_cnt))

    if val == 0:
        return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(
            avg_p), np.mean(avg_r), np.mean(avg_f1), np.mean(
                add_list), np.mean(delete_list), med_cnt / visit_cnt
    return np.array(label_list), np.array(prob_add), np.array(prob_delete)
def eval(model, data_eval, voc_size, epoch, val=0, threshold1=0.8, threshold2=0.2):
    """Evaluate a model that assigns each label to one of several actions.

    The first visit's ground-truth labels seed a running set ``y_old``. For
    every later visit the arg-max over axis 1 of the model output decides the
    action per label: class ``1`` switches the label on and class ``2``
    switches it off. Any other class leaves it unchanged. Patients with fewer
    than two visits are skipped.

    Args:
        model: called as ``model(patient[:idx + 1])`` and returning a 2-D
            tensor indexed ``[label, class]``.
        data_eval: sequence of patients, each a sequence of visits with the
            label indices at ``visit[2]``.
        voc_size: sequence whose element ``[2]`` is the label-vector size.
        epoch: accepted for interface compatibility; not used here.
        val, threshold1, threshold2: accepted for interface compatibility;
            not used by this implementation.

    Returns:
        Tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add_distance,
        delete_distance, avg_med)``.
    """
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0
    add_list, delete_list = [], []

    for step, patient in enumerate(data_eval):
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        add_temp_list, delete_temp_list = [], []

        if len(patient) < 2:
            continue

        for adm_idx, adm in enumerate(patient):
            if adm_idx == 0:
                # Seed the running set with the first visit's true labels.
                y_old = np.zeros(voc_size[2])
                y_old[adm[2]] = 1
                continue

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)

            result_out = model(patient[:adm_idx + 1])

            # prediction prob: sigmoid of the class-0 column
            y_pred_tmp = F.sigmoid(
                result_out[:, 0]).detach().cpu().numpy().tolist()
            y_pred_prob.append(y_pred_tmp)

            previous_set = set(np.where(y_old == 1)[0])

            # prediction med set
            assignment = torch.max(result_out, axis=1)[1].cpu().numpy()
            y_old[assignment == 1] = 1
            y_old[assignment == 2] = 0
            # y_old keeps being mutated in place on later visits, so store a
            # snapshot. Appending y_old itself made every row of y_pred alias
            # the final state.
            y_pred.append(y_old.copy())

            # prediction label
            y_pred_label_tmp = np.where(y_old == 1)[0]
            y_pred_label.append(sorted(y_pred_label_tmp))
            visit_cnt += 1
            med_cnt += len(y_pred_label_tmp)

            # add / delete distance: size of the symmetric difference between
            # the true change and the predicted change relative to previous_set
            gt_set = set(np.where(y_gt_tmp == 1)[0])
            new_set = set(y_pred_label_tmp)
            add_gt = gt_set - previous_set
            delete_gt = previous_set - gt_set
            add_pre = new_set - previous_set
            delete_pre = previous_set - new_set
            add_temp_list.append(len(add_pre ^ add_gt))
            delete_temp_list.append(len(delete_pre ^ delete_gt))

        if add_temp_list:
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

        smm_record.append(y_pred_label)

        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))
        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)

        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record, path='../data/output/ddi_A_final.pkl')

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4}, AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(avg_f1), np.mean(add_list),
                np.mean(delete_list), med_cnt / visit_cnt))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(
        avg_r), np.mean(avg_f1), np.mean(add_list), np.mean(
            delete_list), med_cnt / visit_cnt
def main():
    """Train GAMENet, or evaluate a saved checkpoint when ``args.eval`` is set.

    Data is split in file order: the first 2/3 is used for training and the
    remainder is halved into a test part and a per-epoch evaluation part.
    With ``args.ddi`` enabled the loss for each visit is either the
    prediction loss or the model's DDI loss, chosen stochastically with a
    temperature that is multiplied by ``decay_weight`` after every epoch.

    Side effects:
        Creates ``saved/<model_name>/`` and writes one checkpoint per epoch,
        ``history.pkl`` and ``final.model`` into it.
    """
    save_dir = os.path.join("saved", model_name)
    os.makedirs(save_dir, exist_ok=True)

    def _load_pickle(path):
        # Context manager so the handle is closed as soon as loading finishes.
        with open(path, 'rb') as handle:
            return dill.load(handle)

    def _save_state(file_name):
        with open(os.path.join('saved', model_name, file_name), 'wb') as handle:
            torch.save(model.state_dict(), handle)

    data_path = '../data/records_final.pkl'
    voc_path = '../data/voc_final.pkl'
    ehr_adj_path = '../data/ehr_adj_final.pkl'
    ddi_adj_path = '../data/ddi_A_final.pkl'
    device = torch.device('cuda:0')

    ehr_adj = _load_pickle(ehr_adj_path)
    ddi_adj = _load_pickle(ddi_adj_path)
    data = _load_pickle(data_path)
    voc = _load_pickle(voc_path)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    EPOCH = 40
    LR = 0.0002
    TEST = args.eval
    Neg_Loss = args.ddi
    DDI_IN_MEM = args.ddi
    TARGET_DDI = 0.05
    T = 0.5
    decay_weight = 0.85

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = GAMENet(voc_size, ehr_adj, ddi_adj, emb_dim=64, device=device,
                    ddi_in_memory=DDI_IN_MEM)
    if TEST:
        with open(resume_name, 'rb') as handle:
            model.load_state_dict(torch.load(handle))
    model.to(device=device)

    print('parameters', get_n_params(model))
    optimizer = Adam(list(model.parameters()), lr=LR)

    if TEST:
        eval(model, data_test, voc_size, 0)
        return

    history = defaultdict(list)
    best_epoch = 0
    best_ja = 0
    for epoch in range(EPOCH):
        loss_record1 = []
        start_time = time.time()
        model.train()
        prediction_loss_cnt = 0
        neg_loss_cnt = 0
        for step, patient in enumerate(data_train):
            for idx, adm in enumerate(patient):
                seq_input = patient[:idx + 1]

                # Multi-hot target for the BCE loss.
                loss1_target = np.zeros((1, voc_size[2]))
                loss1_target[:, adm[2]] = 1

                # multilabel_margin_loss expects the target indices packed at
                # the front and padded with -1.
                loss3_target = np.full((1, voc_size[2]), -1)
                for pos, item in enumerate(adm[2]):
                    loss3_target[0][pos] = item

                target_output1, batch_neg_loss = model(seq_input)

                loss1 = F.binary_cross_entropy_with_logits(
                    target_output1, torch.FloatTensor(loss1_target).to(device))
                loss3 = F.multilabel_margin_loss(
                    F.sigmoid(target_output1),
                    torch.LongTensor(loss3_target).to(device))

                if Neg_Loss:
                    predicted = F.sigmoid(
                        target_output1).detach().cpu().numpy()[0]
                    predicted[predicted >= 0.5] = 1
                    predicted[predicted < 0.5] = 0
                    y_label = np.where(predicted == 1)[0]
                    current_ddi_rate = ddi_rate_score([[y_label]])
                    if current_ddi_rate <= TARGET_DDI:
                        loss = 0.9 * loss1 + 0.01 * loss3
                        prediction_loss_cnt += 1
                    else:
                        # The further above the target, the smaller rnd gets.
                        rnd = np.exp((TARGET_DDI - current_ddi_rate) / T)
                        if np.random.rand(1) < rnd:
                            loss = batch_neg_loss
                            neg_loss_cnt += 1
                        else:
                            loss = 0.9 * loss1 + 0.01 * loss3
                            prediction_loss_cnt += 1
                else:
                    loss = 0.9 * loss1 + 0.01 * loss3

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

                loss_record1.append(loss.item())

            llprint(
                '\rTrain--Epoch: %d, Step: %d/%d, L_p cnt: %d, L_neg cnt: %d'
                % (epoch, step, len(data_train), prediction_loss_cnt,
                   neg_loss_cnt))

        # annealing
        T *= decay_weight

        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(
            model, data_eval, voc_size, epoch)

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)

        end_time = time.time()
        elapsed_time = (end_time - start_time) / 60
        llprint(
            '\tEpoch: %d, Loss: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n'
            % (epoch, np.mean(loss_record1), elapsed_time,
               elapsed_time * (EPOCH - epoch - 1) / 60))

        _save_state('Epoch_%d_JA_%.4f_DDI_%.4f.model' % (epoch, ja, ddi_rate))
        print('')

        # NOTE(review): epoch 0 is never recorded as the best epoch.
        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

    with open(os.path.join('saved', model_name, 'history.pkl'), 'wb') as handle:
        dill.dump(history, handle)

    # test
    _save_state('final.model')

    print('best_epoch:', best_epoch)
def main():
    """Train SafeDrugModel, or bootstrap-test a checkpoint when ``args.Test`` is set.

    Data is split in file order: the first 2/3 is used for training and the
    remainder is halved into a test part and a per-epoch evaluation part.

    Test mode loads ``args.resume_path`` and evaluates 10 bootstrap samples,
    each 80% of the test set drawn with replacement. It prints mean and
    standard deviation as LaTeX table cells.

    Training mode runs 50 epochs with one optimizer step per visit. While the
    DDI rate of the current prediction exceeds ``args.target_ddi``, the
    prediction loss and the model's DDI loss are mixed with weight ``beta``.

    Side effects:
        Writes one checkpoint per epoch and ``history_<model_name>.pkl`` into
        ``saved/<args.model_name>/``.
    """
    def _load_pickle(path):
        # Context manager so the handle is closed as soon as loading finishes.
        with open(path, 'rb') as handle:
            return dill.load(handle)

    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    ddi_mask_path = '../data/output/ddi_mask_H.pkl'
    molecule_path = '../data/output/atc3toSMILES.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ddi_adj = _load_pickle(ddi_adj_path)
    ddi_mask_H = _load_pickle(ddi_mask_path)
    data = _load_pickle(data_path)
    molecule = _load_pickle(molecule_path)

    voc = _load_pickle(voc_path)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    MPNNSet, N_fingerprint, average_projection = buildMPNN(
        molecule, med_voc.idx2word, 2, device)
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = SafeDrugModel(voc_size, ddi_adj, ddi_mask_H, MPNNSet,
                          N_fingerprint, average_projection,
                          emb_dim=args.dim, device=device)

    if args.Test:
        with open(args.resume_path, 'rb') as handle:
            model.load_state_dict(torch.load(handle))
        model.to(device=device)
        tic = time.time()

        result = []
        sample_size = round(len(data_test) * 0.8)
        for _ in range(10):
            # Sample indices rather than the patients themselves: the patient
            # list is ragged and np.random.choice needs a 1-D array or an int.
            picked = np.random.choice(len(data_test), sample_size, replace=True)
            test_sample = [data_test[i] for i in picked]
            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
                model, test_sample, voc_size, 0)
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med])

        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        # "\\pm" emits the same text as the former "\p" without relying on an
        # invalid escape sequence.
        outstring = "".join(
            "{:.4f} $\\pm$ {:.4f} & ".format(m, s) for m, s in zip(mean, std))
        print(outstring)

        print('test time: {}'.format(time.time() - tic))
        return

    model.to(device=device)
    optimizer = Adam(list(model.parameters()), lr=args.lr)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    EPOCH = 50
    for epoch in range(EPOCH):
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))

        model.train()
        for step, patient in enumerate(data_train):
            for idx, adm in enumerate(patient):
                seq_input = patient[:idx + 1]

                # Multi-hot target for the BCE loss.
                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                # multilabel_margin_loss expects the target indices packed at
                # the front and padded with -1.
                loss_multi_target = np.full((1, voc_size[2]), -1)
                for pos, item in enumerate(adm[2]):
                    loss_multi_target[0][pos] = item

                result, loss_ddi = model(seq_input)

                loss_bce = F.binary_cross_entropy_with_logits(
                    result, torch.FloatTensor(loss_bce_target).to(device))
                loss_multi = F.multilabel_margin_loss(
                    F.sigmoid(result),
                    torch.LongTensor(loss_multi_target).to(device))

                predicted = F.sigmoid(result).detach().cpu().numpy()[0]
                predicted[predicted >= 0.5] = 1
                predicted[predicted < 0.5] = 0
                y_label = np.where(predicted == 1)[0]
                current_ddi_rate = ddi_rate_score(
                    [[y_label]], path='../data/output/ddi_A_final.pkl')

                if current_ddi_rate <= args.target_ddi:
                    loss = 0.95 * loss_bce + 0.05 * loss_multi
                else:
                    # beta is the weight of the prediction loss and must stay
                    # in [0, 1]. The former min(0, ...) could never be
                    # positive, so the prediction loss was dropped or even
                    # given a negative weight.
                    beta = max(
                        0, 1 + (args.target_ddi - current_ddi_rate) / args.kp)
                    loss = beta * (0.95 * loss_bce + 0.05 * loss_multi) \
                        + (1 - beta) * loss_ddi

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        print()
        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic, time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)

        if epoch >= 5:
            # Running average over the five most recent epochs.
            print('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
                np.mean(history['ddi_rate'][-5:]),
                np.mean(history['med'][-5:]),
                np.mean(history['ja'][-5:]),
                np.mean(history['avg_f1'][-5:]),
                np.mean(history['prauc'][-5:])))

        checkpoint_path = os.path.join(
            'saved', args.model_name,
            'Epoch_{}_TARGET_{:.2}_JA_{:.4}_DDI_{:.4}.model'.format(
                epoch, args.target_ddi, ja, ddi_rate))
        with open(checkpoint_path, 'wb') as handle:
            torch.save(model.state_dict(), handle)

        # NOTE(review): epoch 0 is never recorded as the best epoch.
        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    history_path = os.path.join(
        'saved', args.model_name, 'history_{}.pkl'.format(args.model_name))
    with open(history_path, 'wb') as handle:
        dill.dump(history, handle)
def main():
    """Train GAMENet, or bootstrap-test a checkpoint when ``args.Test`` is set.

    Data is split in file order: the first 2/3 is used for training and the
    remainder is halved into a test part and a per-epoch evaluation part.

    Test mode loads ``args.resume_path`` and evaluates 10 bootstrap samples,
    each 80% of the test set drawn with replacement. It prints mean and
    standard deviation as LaTeX table cells.

    Training mode runs 50 epochs with one optimizer step per visit. With
    ``args.ddi`` enabled the loss for a visit is either the prediction loss
    or the model's DDI loss, chosen stochastically with temperature
    ``args.T``. ``args.T`` is multiplied by ``args.decay_weight`` after every
    epoch, which mutates ``args``.

    Side effects:
        Writes one checkpoint per epoch and ``history_<model_name>.pkl`` into
        ``saved/<args.model_name>/``.
    """
    def _load_pickle(path):
        # Context manager so the handle is closed as soon as loading finishes.
        with open(path, 'rb') as handle:
            return dill.load(handle)

    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ehr_adj_path = '../data/output/ehr_adj_final.pkl'
    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ehr_adj = _load_pickle(ehr_adj_path)
    ddi_adj = _load_pickle(ddi_adj_path)
    data = _load_pickle(data_path)

    voc = _load_pickle(voc_path)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = GAMENet(voc_size, ehr_adj, ddi_adj, emb_dim=args.dim,
                    device=device, ddi_in_memory=args.ddi)

    if args.Test:
        with open(args.resume_path, 'rb') as handle:
            model.load_state_dict(torch.load(handle))
        model.to(device=device)
        tic = time.time()

        result = []
        sample_size = round(len(data_test) * 0.8)
        for _ in range(10):
            # Sample indices rather than the patients themselves: the patient
            # list is ragged and np.random.choice needs a 1-D array or an int.
            picked = np.random.choice(len(data_test), sample_size, replace=True)
            test_sample = [data_test[i] for i in picked]
            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
                model, test_sample, voc_size, 0)
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med])

        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        # "\\pm" emits the same text as the former "\p" without relying on an
        # invalid escape sequence.
        outstring = "".join(
            "{:.4f} $\\pm$ {:.4f} & ".format(m, s) for m, s in zip(mean, std))
        print(outstring)

        print('test time: {}'.format(time.time() - tic))
        return

    model.to(device=device)
    print('parameters', get_n_params(model))
    optimizer = Adam(list(model.parameters()), lr=args.lr)

    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    EPOCH = 50
    for epoch in range(EPOCH):
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))
        prediction_loss_cnt, neg_loss_cnt = 0, 0
        model.train()
        for step, patient in enumerate(data_train):
            for idx, adm in enumerate(patient):
                seq_input = patient[:idx + 1]

                # Multi-hot target for the BCE loss.
                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                # multilabel_margin_loss expects the target indices packed at
                # the front and padded with -1.
                loss_multi_target = np.full((1, voc_size[2]), -1)
                for pos, item in enumerate(adm[2]):
                    loss_multi_target[0][pos] = item

                target_output1, loss_ddi = model(seq_input)

                loss_bce = F.binary_cross_entropy_with_logits(
                    target_output1,
                    torch.FloatTensor(loss_bce_target).to(device))
                loss_multi = F.multilabel_margin_loss(
                    F.sigmoid(target_output1),
                    torch.LongTensor(loss_multi_target).to(device))

                if args.ddi:
                    predicted = F.sigmoid(
                        target_output1).detach().cpu().numpy()[0]
                    predicted[predicted >= 0.5] = 1
                    predicted[predicted < 0.5] = 0
                    y_label = np.where(predicted == 1)[0]
                    current_ddi_rate = ddi_rate_score(
                        [[y_label]], path='../data/output/ddi_A_final.pkl')
                    if current_ddi_rate <= args.target_ddi:
                        loss = 0.9 * loss_bce + 0.1 * loss_multi
                        prediction_loss_cnt += 1
                    else:
                        # The further above the target, the smaller rnd gets.
                        rnd = np.exp(
                            (args.target_ddi - current_ddi_rate) / args.T)
                        if np.random.rand(1) < rnd:
                            loss = loss_ddi
                            neg_loss_cnt += 1
                        else:
                            loss = 0.9 * loss_bce + 0.1 * loss_multi
                            prediction_loss_cnt += 1
                else:
                    loss = 0.9 * loss_bce + 0.1 * loss_multi

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        # annealing
        args.T *= args.decay_weight

        print()
        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic, time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)

        if epoch >= 5:
            # Running average over the five most recent epochs.
            print('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
                np.mean(history['ddi_rate'][-5:]),
                np.mean(history['med'][-5:]),
                np.mean(history['ja'][-5:]),
                np.mean(history['avg_f1'][-5:]),
                np.mean(history['prauc'][-5:])))

        checkpoint_path = os.path.join(
            'saved', args.model_name,
            'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate))
        with open(checkpoint_path, 'wb') as handle:
            torch.save(model.state_dict(), handle)

        # NOTE(review): epoch 0 is never recorded as the best epoch.
        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    history_path = os.path.join(
        'saved', args.model_name, 'history_{}.pkl'.format(args.model_name))
    with open(history_path, 'wb') as handle:
        dill.dump(history, handle)
def main():
    """Train MICRON, or calibrate thresholds and test when ``args.Test`` is set.

    The data is shuffled with a fixed seed (1203). The first 3/5 is used for
    training and the remainder is halved into a test part and an evaluation
    part.

    Test mode loads ``args.resume_path``, derives two thresholds from
    per-label ROC curves on the evaluation split, then evaluates the test
    split with the mean thresholds.

    Training mode runs 40 epochs with one optimizer step per patient. The
    loss is a weighted sum of BCE, multilabel-margin and reconstruction
    losses, plus the DDI loss while the DDI rate of the current prediction
    exceeds 0.08.

    Side effects:
        Writes one checkpoint per epoch and ``history_<model_name>.pkl`` into
        ``saved/<args.model_name>/``.
    """
    def _load_pickle(path):
        # Context manager so the handle is closed as soon as loading finishes.
        with open(path, 'rb') as handle:
            return dill.load(handle)

    def _margin_target(labels):
        # multilabel_margin_loss expects the target indices packed at the
        # front and padded with -1.
        target = np.full((1, voc_size[2]), -1)
        for pos, item in enumerate(labels):
            target[0][pos] = item
        return target

    def _multi_hot(labels):
        target = np.zeros((1, voc_size[2]))
        target[:, labels] = 1
        return target

    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ddi_adj = _load_pickle(ddi_adj_path)
    data = _load_pickle(data_path)

    voc = _load_pickle(voc_path)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    np.random.seed(1203)
    np.random.shuffle(data)

    split_point = int(len(data) * 3 / 5)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = MICRON(voc_size, ddi_adj, emb_dim=args.dim, device=device)

    if args.Test:
        with open(args.resume_path, 'rb') as handle:
            model.load_state_dict(torch.load(handle))
        model.to(device=device)
        tic = time.time()

        label_list, prob_list = eval(model, data_eval, voc_size, 0, 1)

        threshold1, threshold2 = [], []
        for i in range(label_list.shape[1]):
            _, _, boundary = roc_curve(label_list[:, i], prob_list[:, i],
                                       pos_label=1)
            # boundary1 should be in [0.5, 0.9], boundary2 should be in [0.1, 0.5]
            threshold1.append(
                min(0.9,
                    max(0.5,
                        boundary[max(0, round(len(boundary) * 0.05) - 1)])))
            threshold2.append(
                max(0.1,
                    min(0.5,
                        boundary[min(round(len(boundary) * 0.95),
                                     len(boundary) - 1)])))
        print(np.mean(threshold1), np.mean(threshold2))
        threshold1 = np.ones(voc_size[2]) * np.mean(threshold1)
        threshold2 = np.ones(voc_size[2]) * np.mean(threshold2)
        eval(model, data_test, voc_size, 0, 0, threshold1, threshold2)
        print('test time: {}'.format(time.time() - tic))
        return

    model.to(device=device)
    print('parameters', get_n_params(model))
    optimizer = RMSprop(list(model.parameters()), lr=args.lr,
                        weight_decay=args.weight_decay)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    weight_list = [[0.25, 0.25, 0.25, 0.25]]

    EPOCH = 40
    for epoch in range(EPOCH):
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))

        # NOTE(review): sample_counter is never incremented, so the adaptive
        # re-weighting branch below is currently unreachable and the four
        # loss weights stay at their initial 0.25. Confirm whether that is
        # intended before re-enabling the counter.
        sample_counter = 0
        mean_loss = np.array([0, 0, 0, 0])

        model.train()
        for step, patient in enumerate(data_train):
            loss = 0
            if len(patient) < 2:
                continue
            for adm_idx, adm in enumerate(patient):
                if adm_idx == 0:
                    continue

                seq_input = patient[:adm_idx + 1]
                previous_labels = patient[adm_idx - 1][2]

                loss_bce_target = _multi_hot(adm[2])
                loss_bce_target_last = _multi_hot(previous_labels)
                loss_multi_target = _margin_target(adm[2])
                loss_multi_target_last = _margin_target(previous_labels)

                result, result_last, _, loss_ddi, loss_rec = model(seq_input)

                loss_bce = 0.75 * F.binary_cross_entropy_with_logits(
                    result, torch.FloatTensor(loss_bce_target).to(device)) + \
                    (1 - 0.75) * F.binary_cross_entropy_with_logits(
                        result_last,
                        torch.FloatTensor(loss_bce_target_last).to(device))
                loss_multi = 5e-2 * (
                    0.75 * F.multilabel_margin_loss(
                        F.sigmoid(result),
                        torch.LongTensor(loss_multi_target).to(device)) +
                    (1 - 0.75) * F.multilabel_margin_loss(
                        F.sigmoid(result_last),
                        torch.LongTensor(loss_multi_target_last).to(device)))

                y_pred_tmp = F.sigmoid(result).detach().cpu().numpy()[0]
                y_pred_tmp[y_pred_tmp >= 0.5] = 1
                y_pred_tmp[y_pred_tmp < 0.5] = 0
                y_label = np.where(y_pred_tmp == 1)[0]
                current_ddi_rate = ddi_rate_score(
                    [[y_label]], path='../data/output/ddi_A_final.pkl')

                if sample_counter == 0:
                    lambda1, lambda2, lambda3, lambda4 = weight_list[-1]
                else:
                    current_loss = np.array([
                        loss_bce.detach().cpu().numpy(),
                        loss_multi.detach().cpu().numpy(),
                        loss_ddi.detach().cpu().numpy(),
                        loss_rec.detach().cpu().numpy()
                    ])
                    current_ratio = (current_loss - np.array(mean_loss)) \
                        / np.array(mean_loss)
                    instant_weight = np.exp(current_ratio) / sum(
                        np.exp(current_ratio))
                    lambda1, lambda2, lambda3, lambda4 = \
                        instant_weight * 0.75 + np.array(weight_list[-1]) * 0.25
                    # update weight_list
                    weight_list.append([lambda1, lambda2, lambda3, lambda4])
                    # Update the running mean only when the counter is
                    # positive. The former unconditional update divided by
                    # sample_counter == 0 on every visit.
                    mean_loss = (mean_loss * (sample_counter - 1)
                                 + current_loss) / sample_counter

                if current_ddi_rate > 0.08:
                    loss += lambda1 * loss_bce + lambda2 * loss_multi + \
                        lambda3 * loss_ddi + lambda4 * loss_rec
                else:
                    loss += lambda1 * loss_bce + lambda2 * loss_multi + \
                        lambda4 * loss_rec

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic, time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['add'].append(add)
        history['delete'].append(delete)
        history['med'].append(avg_med)

        if epoch >= 5:
            # Running average over the five most recent epochs.
            print(
                'ddi: {}, Med: {}, Ja: {}, F1: {}, Add: {}, Delete: {}'.format(
                    np.mean(history['ddi_rate'][-5:]),
                    np.mean(history['med'][-5:]),
                    np.mean(history['ja'][-5:]),
                    np.mean(history['avg_f1'][-5:]),
                    np.mean(history['add'][-5:]),
                    np.mean(history['delete'][-5:])))

        checkpoint_path = os.path.join(
            'saved', args.model_name,
            'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate))
        with open(checkpoint_path, 'wb') as handle:
            torch.save(model.state_dict(), handle)

        # NOTE(review): epoch 0 is never recorded as the best epoch.
        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    history_path = os.path.join(
        'saved', args.model_name, 'history_{}.pkl'.format(args.model_name))
    with open(history_path, 'wb') as handle:
        dill.dump(history, handle)
def eval(model, data_eval, voc_size, epoch, threshold=0.3):
    """Evaluate ``model`` on ``data_eval`` and report recommendation metrics.

    Patients with fewer than two visits are skipped.  For every visit
    ``i >= 1`` of a remaining patient the model is called on the preceding
    visits ``patient[:i]``; the sigmoid of its output is thresholded to get
    the predicted medication set, which is scored against the indices stored
    at position 2 of visit ``i``.

    Besides the per-patient scores returned by ``multi_label_metric``, the
    function tracks an "add" and a "delete" distance: the size of the
    symmetric difference between the predicted and the ground-truth change
    relative to the previous medication set.  The previous set starts as the
    first visit's recorded medications and is afterwards replaced by the
    model's own prediction.

    Args:
        model: network to evaluate; ``model.eval()`` is called on it.
        data_eval: sized iterable of patients, each a sequence of visits
            where ``visit[2]`` holds medication indices.
        voc_size: sequence whose third entry is the medication vocabulary
            size.
        epoch: not used by this function; kept for caller compatibility.
        threshold: probability at or above which a medication counts as
            predicted.  Defaults to the previously hard-coded ``0.3``.

    Returns:
        A 9-tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete,
        avg_med)``.  All but ``ddi_rate`` are means over evaluated patients
        (``avg_med`` is the mean number of predicted medications per visit).
        When no patient has at least two visits, ``avg_med`` is NaN instead
        of raising ``ZeroDivisionError``.
    """
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0
    add_list, delete_list = [], []

    # Inference only: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            if len(patient) < 2:
                continue

            y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
            add_temp_list, delete_temp_list = [], []
            previous_set = set(patient[0][2])

            for i in range(1, len(patient)):
                target_output = model(patient[:i])

                y_gt_tmp = np.zeros(voc_size[2])
                y_gt_tmp[patient[i][2]] = 1
                y_gt.append(y_gt_tmp)

                # prediction prob
                target_output = torch.sigmoid(target_output).detach().cpu().numpy()[0]
                y_pred_prob.append(target_output)

                # prediction med set (0/1 vector, same dtype as the probabilities)
                y_pred_tmp = (target_output >= threshold).astype(target_output.dtype)
                y_pred.append(y_pred_tmp)

                # prediction label
                y_pred_label_tmp = np.where(y_pred_tmp == 1)[0]
                y_pred_label.append(y_pred_label_tmp)
                med_cnt += len(y_pred_label_tmp)
                visit_cnt += 1

                # add / delete distance relative to the previous medication set;
                # len(A - B) + len(B - A) is the size of the symmetric difference.
                gt_set = set(np.where(y_gt_tmp == 1)[0])
                pred_set = set(y_pred_label_tmp)
                add_gt, delete_gt = gt_set - previous_set, previous_set - gt_set
                add_pre, delete_pre = pred_set - previous_set, previous_set - pred_set
                add_temp_list.append(len(add_pre ^ add_gt))
                delete_temp_list.append(len(delete_pre ^ delete_gt))

                # The next visit is compared against what the model just predicted.
                previous_set = pred_set

            # A patient that reaches this point has at least one scored visit.
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

            smm_record.append(y_pred_label)
            adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
                np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

            ja.append(adm_ja)
            prauc.append(adm_prauc)
            avg_p.append(adm_avg_p)
            avg_r.append(adm_avg_r)
            avg_f1.append(adm_avg_f1)
            llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate (builtin float: the ``np.float`` alias no longer exists in NumPy >= 1.24)
    ddi_rate = float(
        ddi_rate_score(smm_record, path='../data/output/ddi_A_final.pkl'))
    avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4}, AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(avg_f1), np.mean(add_list),
                np.mean(delete_list), avg_med))

    return (ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p),
            np.mean(avg_r), np.mean(avg_f1), np.mean(add_list),
            np.mean(delete_list), avg_med)
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a sequence-output ``model`` on ``data_eval``.

    Patients with fewer than two visits are skipped.  The first visit of a
    patient only seeds the "previous" medication set; every later visit is
    passed to the model on its own, the raw output is decoded with
    ``sequence_output_process`` and scored against ``adm[2]``.

    The "add" and "delete" distances are the sizes of the symmetric
    difference between the predicted and the ground-truth change relative to
    the previous medication set.  After each visit the previous set becomes
    the model's prediction (the original code stored it in an unused
    variable, so the reference never moved past the first visit).

    Args:
        model: network to evaluate; ``model.eval()`` is called on it.
        data_eval: sized iterable of patients, each a sequence of visits
            where ``visit[2]`` holds medication indices.
        voc_size: sequence whose third entry is the medication vocabulary
            size; ``voc_size[2]`` and ``voc_size[2] + 1`` are handed to
            ``sequence_output_process`` (presumably special tokens -- confirm
            against that helper).
        epoch: not used by this function; kept for caller compatibility.

    Returns:
        A 9-tuple ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete,
        avg_med)``.  All but ``ddi_rate`` are means over evaluated patients
        (``avg_med`` is the mean number of predicted medications per visit).
        When no patient has at least two visits, ``avg_med`` is NaN instead
        of raising ``ZeroDivisionError``.
    """
    model.eval()

    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    smm_record = []
    med_cnt, visit_cnt = 0, 0
    add_list, delete_list = [], []

    # Inference only: no autograd graph is needed for any forward pass here.
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            if len(patient) < 2:
                continue

            y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
            add_temp_list, delete_temp_list = [], []
            previous_set = set(patient[0][2])

            for adm in patient[1:]:
                output_logits = model(adm)

                y_gt_tmp = np.zeros(voc_size[2])
                y_gt_tmp[adm[2]] = 1
                y_gt.append(y_gt_tmp)

                # prediction prod
                output_logits = output_logits.detach().cpu().numpy()

                # prediction med set
                out_list, sorted_predict = sequence_output_process(
                    output_logits, [voc_size[2], voc_size[2] + 1])
                y_pred_label.append(sorted(sorted_predict))
                # Drops the last two output columns before averaging over rows.
                y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))

                # prediction label
                y_pred_tmp = np.zeros(voc_size[2])
                y_pred_tmp[out_list] = 1
                y_pred.append(y_pred_tmp)
                visit_cnt += 1
                med_cnt += len(sorted_predict)

                # add / delete distance relative to the previous medication set;
                # len(A - B) + len(B - A) is the size of the symmetric difference.
                gt_set = set(np.where(y_gt_tmp == 1)[0])
                pred_set = set(np.where(y_pred_tmp == 1)[0])
                add_gt, delete_gt = gt_set - previous_set, previous_set - gt_set
                add_pre, delete_pre = pred_set - previous_set, previous_set - pred_set
                add_temp_list.append(len(add_pre ^ add_gt))
                delete_temp_list.append(len(delete_pre ^ delete_gt))

                # The next visit is compared against what the model just predicted.
                previous_set = pred_set

            # A patient that reaches this point has at least one scored visit.
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

            smm_record.append(y_pred_label)
            # Label lists differ in length between visits; NumPy >= 1.24 refuses
            # to build such a ragged array unless dtype=object is given.
            adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = sequence_metric(
                np.array(y_gt), np.array(y_pred), np.array(y_pred_prob),
                np.array(y_pred_label, dtype=object))

            ja.append(adm_ja)
            prauc.append(adm_prauc)
            avg_p.append(adm_avg_p)
            avg_r.append(adm_avg_r)
            avg_f1.append(adm_avg_f1)
            llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate (builtin float: the ``np.float`` alias no longer exists in NumPy >= 1.24)
    ddi_rate = float(
        ddi_rate_score(smm_record, path='../data/output/ddi_A_final.pkl'))
    avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')

    llprint('\nDDI Rate: {:.4}, Jaccard: {:.4}, AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'.format(
        ddi_rate, np.mean(ja), np.mean(avg_f1), np.mean(add_list),
        np.mean(delete_list), avg_med))

    return (ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p),
            np.mean(avg_r), np.mean(avg_f1), np.mean(add_list),
            np.mean(delete_list), avg_med)