Example #1
0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a sequence-decoding medication model on ``data_eval``.

    For every visit of every patient, the output of ``model(adm)`` is decoded
    with ``sequence_output_process`` into a predicted medication list. That
    list is compared with the ground-truth indices in ``adm[2]``. Per-patient
    scores from ``sequence_metric`` are averaged over patients. The DDI rate of
    all predictions is computed by ``ddi_rate_score`` using the matrix stored at
    ``../data/output/ddi_A_final.pkl``.

    Args:
        model: callable model, invoked once per visit. Its output is expected to
            be a 2-D tensor whose last two columns are excluded from the
            probability average (presumably the two special-token scores --
            TODO confirm).
        data_eval: sized sequence of patients, each an iterable of visits;
            ``visit[2]`` holds the ground-truth medication indices.
        voc_size: tuple whose third entry is the medication vocabulary size.
            ``voc_size[2]`` and ``voc_size[2] + 1`` are passed to
            ``sequence_output_process`` as the filter tokens.
        epoch: unused; kept for call-site compatibility.

    Returns:
        Tuple ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1,
        avg_med)``. ``avg_med`` is the mean number of predicted medications per
        visit, or 0.0 when there were no visits.
    """
    import torch  # local import keeps this block self-contained

    model.eval()

    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    smm_record = []
    med_cnt, visit_cnt = 0, 0
    n_med = voc_size[2]
    filter_tokens = [n_med, n_med + 1]

    # Pure inference: no autograd graph is needed (outputs are detached anyway).
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []

            for adm in patient:
                output_logits = model(adm).detach().cpu().numpy()

                y_gt_tmp = np.zeros(n_med)
                y_gt_tmp[adm[2]] = 1
                y_gt.append(y_gt_tmp)

                # prediction med set
                out_list, sorted_predict = sequence_output_process(
                    output_logits, filter_tokens)
                y_pred_label.append(sorted(sorted_predict))
                y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))

                # prediction label (multi-hot)
                y_pred_tmp = np.zeros(n_med)
                y_pred_tmp[out_list] = 1
                y_pred.append(y_pred_tmp)
                visit_cnt += 1
                med_cnt += len(sorted_predict)

            smm_record.append(y_pred_label)

            # The label lists have different lengths per visit; NumPy >= 1.24
            # refuses to build such a ragged array unless dtype=object.
            adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = sequence_metric(
                np.array(y_gt), np.array(y_pred), np.array(y_pred_prob),
                np.array(y_pred_label, dtype=object))
            ja.append(adm_ja)
            prauc.append(adm_prauc)
            avg_p.append(adm_avg_p)
            avg_r.append(adm_avg_r)
            avg_f1.append(adm_avg_f1)
            llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record,
                              path='../data/output/ddi_A_final.pkl')

    means = (np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r),
             np.mean(avg_f1))
    # Guard against an empty evaluation set.
    avg_med = med_cnt / visit_cnt if visit_cnt else 0.0

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4},  PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, *means, avg_med))

    return (ddi_rate, *means, avg_med)
Example #2
0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a sequence-decoding medication model and print its metrics.

    Each visit's output from ``model(adm)`` is decoded with
    ``sequence_output_process`` and scored against the ground-truth indices in
    ``adm[2]``. Per-patient scores from ``sequence_metric`` are averaged over
    patients. The DDI rate of all predictions comes from ``ddi_rate_score``
    (called with its default arguments).

    Args:
        model: callable model, invoked once per visit. Its output is expected to
            be a 2-D tensor whose last two columns are excluded from the
            probability average (presumably special-token scores -- TODO
            confirm).
        data_eval: sized sequence of patients, each an iterable of visits;
            ``visit[2]`` holds the ground-truth medication indices.
        voc_size: tuple whose third entry is the medication vocabulary size.
        epoch: used only in the progress message.

    Returns:
        Tuple ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1)``.
    """
    import torch  # local import keeps this block self-contained

    # evaluate
    print('')
    model.eval()

    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    records = []
    med_cnt = 0
    visit_cnt = 0
    n_med = voc_size[2]
    filter_tokens = [n_med, n_med + 1]

    # Pure inference: no autograd graph is needed (outputs are detached anyway).
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
            for adm in patient:
                y_gt_tmp = np.zeros(n_med)
                y_gt_tmp[adm[2]] = 1
                y_gt.append(y_gt_tmp)

                output_logits = model(adm).detach().cpu().numpy()
                out_list, sorted_predict = sequence_output_process(
                    output_logits, filter_tokens)

                y_pred_label.append(sorted(sorted_predict))
                y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))

                y_pred_tmp = np.zeros(n_med)
                y_pred_tmp[out_list] = 1
                y_pred.append(y_pred_tmp)
                visit_cnt += 1
                med_cnt += len(sorted_predict)
            records.append(y_pred_label)

            # Ragged label lists need dtype=object on NumPy >= 1.24.
            adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = sequence_metric(
                np.array(y_gt), np.array(y_pred), np.array(y_pred_prob),
                np.array(y_pred_label, dtype=object))
            ja.append(adm_ja)
            prauc.append(adm_prauc)
            avg_p.append(adm_avg_p)
            avg_r.append(adm_avg_r)
            avg_f1.append(adm_avg_f1)
            llprint('\rEval--Epoch: %d, Step: %d/%d' %
                    (epoch, step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(records)
    means = (np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r),
             np.mean(avg_f1))
    llprint(
        '\tDDI Rate: %.4f, Jaccard: %.4f,  PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n'
        % (ddi_rate, *means))
    # Guard against an empty evaluation set.
    print('avg med', med_cnt / visit_cnt if visit_cnt else 0.0)
    return (ddi_rate, *means)
Example #3
0
def fine_tune(fine_tune_name=''):
    """Fine-tune a pretrained Leap model with a DDI-penalising loss.

    Weights are restored from ``saved/<args.model_name>/<fine_tune_name>``.
    Training runs for ``EPOCH`` epochs; each epoch uses a with-replacement
    resample of the first 2/3 of the records. For every visit, the loss is
    ``-jaccard * K * mean(log_softmax(logits))``, where ``K`` is 1 if the
    decoded prediction contains a pair flagged in ``ddi_A`` and 0 otherwise.
    After an epoch in which any such pair was predicted, the model is evaluated
    on the test split. The final weights are saved to
    ``saved/<args.model_name>/final.model``.

    Args:
        fine_tune_name: file name of the checkpoint to start from.
    """
    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'
    device = torch.device('cpu:0')

    with open(data_path, 'rb') as f:
        data = dill.load(f)
    with open(voc_path, 'rb') as f:
        voc = dill.load(f)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']
    with open('../data/output/ddi_A_final.pkl', 'rb') as f:
        ddi_A = dill.load(f)

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))

    model = Leap(voc_size, device=device)
    with open(os.path.join("saved", args.model_name, fine_tune_name), 'rb') as f:
        model.load_state_dict(torch.load(f))
    model.to(device)

    optimizer = Adam(model.parameters(), lr=args.lr)
    filter_tokens = [voc_size[2], voc_size[2] + 1]

    EPOCH = 100
    for epoch in range(EPOCH):
        # Tracks whether ANY visit in this epoch predicted an interacting pair.
        # (Previously this was reset per patient, so the epoch-level check below
        # only reflected the last patient, and it was undefined for an empty
        # training set.)
        K_flag = False
        random_train_set = [
            random.choice(data_train) for _ in range(len(data_train))
        ]
        for step, patient in enumerate(random_train_set):
            model.train()
            for adm in patient:
                target = adm[2]
                output_logits = model(adm)
                out_list, _ = sequence_output_process(
                    output_logits.detach().cpu().numpy(), filter_tokens)

                pred_set, target_set = set(out_list), set(target)
                union = pred_set | target_set
                # Test the set's emptiness; comparing a set with 0 is always
                # False, so an empty union used to divide by zero.
                jaccard = len(pred_set & target_set) / len(union) if union else 0

                # K = 1 iff any predicted pair (i, j) is flagged in ddi_A.
                K = int(any(ddi_A[i][j] == 1 for i in out_list for j in out_list))
                if K:
                    K_flag = True

                loss = -jaccard * K * torch.mean(
                    F.log_softmax(output_logits, dim=-1))
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step,
                                                      len(random_train_set)))

        if K_flag:
            print()
            eval(model, data_test, voc_size, epoch)

    # test
    with open(os.path.join('saved', args.model_name, 'final.model'), 'wb') as f:
        torch.save(model.state_dict(), f)
Example #4
0
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    """Train an encoder/decoder pair for 30 epochs, evaluating after each epoch.

    ``n_iters`` training pairs are sampled once, with replacement, from the
    module-level ``train_pairs`` and are reused in every epoch. After each epoch
    the model is run on ``eval_pairs``. The ``sequence_metric`` scores, the DDI
    rate of the predicted medication sets and the mean training loss are
    appended to a history dict. That dict is pickled, together with encoder and
    decoder checkpoints, under ``saved/<model_name>/``.

    Output token ids appear to be medication indices shifted by 2 (see the
    ``- 2`` offsets and ``output_logits[:, 2:]``). Presumably ids 0/1 are
    SOS/EOS -- TODO confirm.

    Args:
        encoder, decoder: modules passed to ``train``/``evaluate`` and
            optimised with SGD.
        n_iters: number of training pairs per epoch.
        print_every, plot_every: unused; kept for interface compatibility.
        learning_rate: SGD learning rate for both optimisers.
    """
    from itertools import combinations

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    training_pairs = [tensorsFromPair(random.choice(train_pairs))
                      for _ in range(n_iters)]
    criterion = nn.CrossEntropyLoss()
    history = defaultdict(list)

    # The DDI matrix is constant across epochs: load it once, not per epoch.
    with open('../data/ddi_A_final.pkl', 'rb') as f:
        ddi_A = dill.load(f)
    n_med = len(med_voc.idx2word)
    save_dir = os.path.join('saved', model_name)

    for epoch in range(30):
        loss_total = 0
        for step, (input_tensor, target_tensor) in enumerate(training_pairs, start=1):
            loss = train(input_tensor, target_tensor, encoder,
                         decoder, encoder_optimizer, decoder_optimizer, criterion)
            loss_total += loss
            llprint('\rTrain--Epoch: %d, Step: %d/%d' % (epoch, step, n_iters))

        # Previously computed and discarded; now kept in the saved history.
        history['loss'].append(loss_total / n_iters if n_iters else 0.0)

        # eval
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        for pair in eval_pairs:
            y_gt_tmp = np.zeros(n_med)
            # pair[1] ends with EOS_token; drop it and undo the +2 id shift.
            y_gt_tmp[np.array(pair[1])[:-1] - 2] = 1
            y_gt.append(y_gt_tmp)

            input_tensor, _ = tensorsFromPair(pair)
            output_logits = evaluate(encoder, decoder, input_tensor)
            # Explicit dim: identical to the implicit choice for 1-D/2-D input,
            # without the deprecation warning.
            output_logits = F.softmax(output_logits, dim=-1).detach().cpu().numpy()
            out_list, sorted_predict = sequence_output_process(output_logits, [SOS_token, EOS_token])

            y_pred_label.append(np.array(sorted_predict) - 2)
            y_pred_prob.append(np.mean(output_logits[:, 2:], axis=0))

            y_pred_tmp = np.zeros(n_med)
            # np.array([]) is float-typed and cannot index, hence the guard.
            if out_list:
                y_pred_tmp[np.array(out_list) - 2] = 1
            y_pred.append(y_pred_tmp)

        # Ragged label arrays need dtype=object on NumPy >= 1.24.
        ja, prauc, avg_p, avg_r, avg_f1 = sequence_metric(np.array(y_gt), np.array(y_pred),
                                                          np.array(y_pred_prob),
                                                          np.array(y_pred_label, dtype=object))
        # ddi rate: fraction of unordered predicted pairs flagged in either direction
        all_cnt = 0
        dd_cnt = 0
        for med_code_set in y_pred_label:
            for med_i, med_j in combinations(med_code_set, 2):
                all_cnt += 1
                if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                    dd_cnt += 1
        # No prediction with >= 2 meds means no pairs: avoid dividing by zero.
        ddi_rate = dd_cnt / all_cnt if all_cnt else 0.0

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        llprint('\n\tDDI Rate: %.4f, Jaccard: %.4f,  PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n' % (
            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1
        ))

        with open(os.path.join(save_dir, 'history.pkl'), 'wb') as f:
            dill.dump(history, f)

        suffix = 'Epoch_%d_JA_%.4f_DDI_%.4f.model' % (epoch, ja, ddi_rate)
        with open(os.path.join(save_dir, 'encoder_' + suffix), 'wb') as f:
            torch.save(encoder.state_dict(), f)
        with open(os.path.join(save_dir, 'decoder_' + suffix), 'wb') as f:
            torch.save(decoder.state_dict(), f)
Example #5
0
                    MAX_LEN = len(input_seq)
                output_seq = list(np.array(o) + 2)
                output_seq.append(EOS_token)

                test_pairs.append((input_seq, output_seq))

    for pair in test_pairs:
        y_gt_tmp = np.zeros(len(med_voc.idx2word))
        y_gt_tmp[np.array(pair[1])[:-1] - 2] = 1
        y_gt.append(y_gt_tmp)

        input_tensor, output_tensor = tensorsFromPair(pair)
        output_logits = evaluate(encoder1, decoder1, input_tensor)
        output_logits = F.softmax(output_logits)
        output_logits = output_logits.detach().cpu().numpy()
        out_list, sorted_predict = sequence_output_process(output_logits, [SOS_token, EOS_token])

        y_pred_label.append(np.array(sorted_predict) - 2)
        y_pred_prob.append(np.mean(output_logits[:, 2:], axis=0))

        y_pred_tmp = np.zeros(len(med_voc.idx2word))
        if len(out_list) != 0:
            y_pred_tmp[np.array(out_list) - 2] = 1
        y_pred.append(y_pred_tmp)

    ja, prauc, avg_p, avg_r, avg_f1 = sequence_metric(np.array(y_gt), np.array(y_pred),
                                                    np.array(y_pred_prob),
                                                    np.array(y_pred_label))
    # ddi rate
    ddi_A = dill.load(open('../data/ddi_A_final.pkl', 'rb'))
    all_cnt = 0
Example #6
0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a sequence medication model, including medication-change metrics.

    Patients with fewer than two visits are skipped. The first visit of each
    remaining patient is not predicted; its ground-truth medications
    (``adm[2]``) are the reference set for the "add"/"delete" metrics of every
    later visit. For each later visit, the output of ``model(adm)`` is decoded
    with ``sequence_output_process`` and scored against ``adm[2]``.

    Args:
        model: callable model, invoked once per predicted visit. Its output is
            expected to be a 2-D tensor whose last two columns are excluded
            from the probability average (presumably special-token scores --
            TODO confirm).
        data_eval: sized sequence of patients, each a sized iterable of visits;
            ``visit[2]`` holds the ground-truth medication indices.
        voc_size: tuple whose third entry is the medication vocabulary size.
        epoch: unused; kept for call-site compatibility.

    Returns:
        Tuple ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1,
        add, delete, avg_med)``. ``add``/``delete`` are the mean sizes of the
        symmetric difference between predicted and true added/removed
        medications. ``avg_med`` is 0.0 when no visit was predicted.
    """
    import torch  # local import keeps this block self-contained

    model.eval()

    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    smm_record = []
    med_cnt, visit_cnt = 0, 0
    add_list, delete_list = [], []
    n_med = voc_size[2]
    filter_tokens = [n_med, n_med + 1]

    # Pure inference: no autograd graph is needed (outputs are detached anyway).
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            if len(patient) < 2:
                continue
            y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
            add_temp_list, delete_temp_list = [], []

            visits = iter(patient)
            # NOTE(review): this reference set is never updated, so every later
            # visit is compared with the FIRST visit's medications rather than
            # the preceding visit's. Confirm this is intended.
            previous_set = set(next(visits)[2])

            for adm in visits:
                output_logits = model(adm).detach().cpu().numpy()

                y_gt_tmp = np.zeros(n_med)
                y_gt_tmp[adm[2]] = 1
                y_gt.append(y_gt_tmp)

                # prediction med set
                out_list, sorted_predict = sequence_output_process(output_logits, filter_tokens)
                y_pred_label.append(sorted(sorted_predict))
                y_pred_prob.append(np.mean(output_logits[:, :-2], axis=0))

                # prediction label (multi-hot)
                y_pred_tmp = np.zeros(n_med)
                y_pred_tmp[out_list] = 1
                y_pred.append(y_pred_tmp)
                visit_cnt += 1
                med_cnt += len(sorted_predict)

                # add / delete: |A - B| + |B - A| == |A ^ B|
                gt_set = set(np.where(y_gt_tmp == 1)[0])
                pred_set = set(np.where(y_pred_tmp == 1)[0])
                add_distance = len((pred_set - previous_set) ^ (gt_set - previous_set))
                delete_distance = len((previous_set - pred_set) ^ (previous_set - gt_set))

                add_temp_list.append(add_distance)
                delete_temp_list.append(delete_distance)

            # At least one visit was predicted (len(patient) >= 2), so the
            # mean is always defined; for a single value it equals that value.
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

            smm_record.append(y_pred_label)

            # Ragged label lists need dtype=object on NumPy >= 1.24.
            adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = sequence_metric(
                np.array(y_gt), np.array(y_pred), np.array(y_pred_prob),
                np.array(y_pred_label, dtype=object))
            ja.append(adm_ja)
            prauc.append(adm_prauc)
            avg_p.append(adm_avg_p)
            avg_r.append(adm_avg_r)
            avg_f1.append(adm_avg_f1)
            llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate; builtin float() replaces np.float, which NumPy 1.24 removed.
    ddi_rate = float(ddi_rate_score(smm_record, path='../data/output/ddi_A_final.pkl'))
    avg_med = med_cnt / visit_cnt if visit_cnt else 0.0
    mean_ja, mean_f1 = np.mean(ja), np.mean(avg_f1)
    mean_add, mean_delete = np.mean(add_list), np.mean(delete_list)

    llprint('\nDDI Rate: {:.4}, Jaccard: {:.4},  AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'.format(
        ddi_rate, mean_ja, mean_f1, mean_add, mean_delete, avg_med
    ))

    return ddi_rate, mean_ja, np.mean(prauc), np.mean(avg_p), np.mean(avg_r), mean_f1, mean_add, mean_delete, avg_med
Example #7
0
def fine_tune(fine_tune_name=''):
    """Fine-tune a pretrained Leap model with a DDI-penalising loss.

    Weights are restored from ``saved/<model_name>/<fine_tune_name>``.
    Training runs ``EPOCH`` passes over a with-replacement resample of the
    first 2/3 of the records. For every visit, the loss is
    ``-jaccard * K * mean(log_softmax(logits))``, where ``K`` is 1 if the
    decoded prediction contains a pair flagged in ``ddi_A`` and 0 otherwise.
    After each patient whose visits produced such a pair, the model is
    evaluated on the test split and a checkpoint is written. The final weights
    are saved to ``saved/<model_name>/final.model``.

    Args:
        fine_tune_name: file name of the checkpoint to start from.
    """
    data_path = '../../data/records_final.pkl'
    voc_path = '../../data/voc_final.pkl'
    device = torch.device('cuda:0')

    with open(data_path, 'rb') as f:
        data = dill.load(f)
    with open(voc_path, 'rb') as f:
        voc = dill.load(f)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']
    with open('../../data/ddi_A_final.pkl', 'rb') as f:
        ddi_A = dill.load(f)

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))

    model = Leap(voc_size, device=device)
    with open(os.path.join("saved", model_name, fine_tune_name), 'rb') as f:
        model.load_state_dict(torch.load(f))
    model.to(device)

    # A single fine-tuning epoch is run. EPOCH now drives both the loop and the
    # remaining-time estimate; previously the estimate assumed 30 epochs while
    # the loop ran only one.
    EPOCH = 1
    LR = 0.0001
    filter_tokens = [voc_size[2], voc_size[2] + 1]

    optimizer = Adam(model.parameters(), lr=LR)
    for epoch in range(EPOCH):
        loss_record = []
        start_time = time.time()
        random_train_set = [
            random.choice(data_train) for _ in range(len(data_train))
        ]
        for step, patient in enumerate(random_train_set):
            model.train()
            K_flag = False
            for adm in patient:
                target = adm[2]
                output_logits = model(adm)
                out_list, _ = sequence_output_process(
                    output_logits.detach().cpu().numpy(), filter_tokens)

                pred_set, target_set = set(out_list), set(target)
                union = pred_set | target_set
                # Test the set's emptiness; comparing a set with 0 is always
                # False, so an empty union used to divide by zero.
                jaccard = len(pred_set & target_set) / len(union) if union else 0

                # K = 1 iff any predicted pair (i, j) is flagged in ddi_A. The
                # old loop missed setting K_flag when the pair was found for
                # the last i in out_list.
                K = int(any(ddi_A[i][j] == 1 for i in out_list for j in out_list))
                if K:
                    K_flag = True

                loss = -jaccard * K * torch.mean(
                    F.log_softmax(output_logits, dim=-1))

                loss_record.append(loss.item())

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rTrain--Epoch: %d, Step: %d/%d' %
                    (epoch, step, len(data_train)))

            if K_flag:
                ddi_rate, ja, _, _, _, _ = eval(model, data_test, voc_size, epoch)

                elapsed_time = (time.time() - start_time) / 60
                llprint(
                    '\tEpoch: %d, Loss1: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n'
                    %
                    (epoch, np.mean(loss_record), elapsed_time, elapsed_time *
                     (EPOCH - epoch - 1) / 60))

                checkpoint = os.path.join(
                    'saved', model_name,
                    'fine_Epoch_%d_JA_%.4f_DDI_%.4f.model' % (epoch, ja, ddi_rate))
                with open(checkpoint, 'wb') as f:
                    torch.save(model.state_dict(), f)
                print('')

    # test
    with open(os.path.join('saved', model_name, 'final.model'), 'wb') as f:
        torch.save(model.state_dict(), f)