Ejemplo n.º 1
0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate ``model`` on ``data_eval`` and return medication metrics.

    For every patient with at least two visits, the model receives the visit
    history ``input[:i]`` and predicts the medication set of visit ``i``.
    Sigmoid probabilities >= 0.3 count as predicted medications.

    Args:
        model: module called as ``model(visits)``; only row ``[0]`` of its
            sigmoid output is used.
        data_eval: sequence of patients; each patient is a list of visits and
            ``visit[2]`` holds the ground-truth medication indices.
        voc_size: sequence whose element ``[2]`` is the medication vocabulary size.
        epoch: epoch number, used only in the progress message.

    Returns:
        ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1)``;
        every metric except ``ddi_rate`` is averaged over evaluated patients.

    Side effects:
        Writes ``saved/<model_name>/case_study.pkl``: a dict keyed by each
        patient's Jaccard score (patients with equal scores overwrite each other).
    """
    print('')
    model.eval()
    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    case_study = defaultdict(dict)
    med_cnt = 0
    visit_cnt = 0
    for step, input in enumerate(data_eval):
        # One history visit plus one target visit: at least 2 visits needed.
        if len(input) < 2:
            continue
        y_gt = []
        y_pred = []
        y_pred_prob = []
        y_pred_label = []
        for i in range(1, len(input)):
            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[input[i][2]] = 1
            y_gt.append(y_gt_tmp)

            # Predict visit i from visits 0..i-1.
            target_output1 = model(input[:i])
            target_output1 = F.sigmoid(target_output1).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output1)

            y_pred_tmp = target_output1.copy()
            y_pred_tmp[y_pred_tmp >= 0.3] = 1
            y_pred_tmp[y_pred_tmp < 0.3] = 0
            y_pred.append(y_pred_tmp)

            # Indices of predicted medications, as plain Python ints.
            y_pred_label_tmp = np.flatnonzero(y_pred_tmp == 1).tolist()
            y_pred_label.append(y_pred_label_tmp)
            med_cnt += len(y_pred_label_tmp)
            visit_cnt += 1

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))
        case_study[adm_ja] = {'ja': adm_ja, 'patient': input, 'y_label': y_pred_label}
        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rEval--Epoch: %d, Step: %d/%d' % (epoch, step, len(data_eval)))

    # Close the output file deterministically instead of leaking the handle.
    with open(os.path.join('saved', model_name, 'case_study.pkl'), 'wb') as fout:
        dill.dump(case_study, fout)
    ddi_rate = ddi_rate_score(smm_record)

    llprint('\tDDI Rate: %.4f, Jaccard: %.4f,  PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n' % (
        ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1)
    ))
    # Avoid ZeroDivisionError when no patient had >= 2 visits.
    print('avg med', med_cnt / visit_cnt if visit_cnt else float('nan'))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(avg_r), np.mean(avg_f1)
Ejemplo n.º 2
0
def eval(model, data_eval, voc_size, epoch):
    """Run visit-by-visit evaluation of a stateful ``model`` on ``data_eval``.

    Every admission of every patient is predicted. The model is called as
    ``model(adm, prev_meds, [h1, h2, h3])``:
    - ``prev_meds`` is the previous admission's ground-truth medication
      indices (``None`` for the first admission).
    - The hidden states are those returned by the previous call.

    Sigmoid probabilities >= 0.5 are counted as prescribed.

    Returns:
        ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1)``;
        all metrics except ``ddi_rate`` are means over patients.
    """
    print('')
    model.eval()
    smm_record = []
    # Per-patient scores: jaccard, prauc, precision, recall, f1.
    per_patient = ([], [], [], [], [])
    for step, patient in enumerate(data_eval):
        truth_rows, binary_rows, prob_rows, label_sets = [], [], [], []
        hid_a = hid_b = hid_t = None
        prev_meds = None
        for adm in patient:
            logits, (hid_a, hid_b, hid_t) = model(adm, prev_meds, [hid_a, hid_b, hid_t])
            prev_meds = adm[2]

            truth = np.zeros(voc_size[2])
            truth[adm[2]] = 1
            truth_rows.append(truth)

            probs = F.sigmoid(logits).detach().cpu().numpy()[0]
            prob_rows.append(probs)
            binary = probs.copy()
            binary[binary >= 0.5] = 1
            binary[binary < 0.5] = 0
            binary_rows.append(binary)
            label_sets.append([idx for idx, flag in enumerate(binary) if flag == 1])

        smm_record.append(label_sets)
        scores = multi_label_metric(
            np.array(truth_rows), np.array(binary_rows), np.array(prob_rows))
        jac, pr_auc, prec, rec, f1 = scores
        for bucket, value in zip(per_patient, (jac, pr_auc, prec, rec, f1)):
            bucket.append(value)
        llprint('\rEval--Epoch: %d, Step: %d/%d' %
                (epoch, step, len(data_eval)))

    ddi_rate = ddi_rate_score(smm_record)
    mean_ja, mean_prauc, mean_p, mean_r, mean_f1 = [np.mean(bucket) for bucket in per_patient]

    llprint(
        '\tDDI Rate: %.4f, Jaccard: %.4f,  PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n'
        % (ddi_rate, mean_ja, mean_prauc, mean_p, mean_r, mean_f1))

    return ddi_rate, mean_ja, mean_prauc, mean_p, mean_r, mean_f1
Ejemplo n.º 3
0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate next-visit medication prediction on ``data_eval``.

    For each patient with at least two visits, visits ``0..i-1`` are fed to
    ``model`` to predict visit ``i``. Sigmoid probabilities >= 0.4 count as
    prescribed.

    Args:
        model: module called as ``model(visits)``; row ``[0]`` of its sigmoid
            output is used as the medication probability vector.
        data_eval: sequence of patients (lists of visits, ``visit[2]`` being the
            true medication indices).
        voc_size: sequence whose element ``[2]`` is the medication vocabulary size.
        epoch: unused; kept for interface compatibility.

    Returns:
        ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1, avg_med)``.
        ``avg_med`` is the mean number of predicted medications per visit, or
        NaN if no visit was evaluated.
    """
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0

    for step, input in enumerate(data_eval):
        if len(input) < 2:
            continue
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []

        for i in range(1, len(input)):
            target_output = model(input[:i])

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[input[i][2]] = 1
            y_gt.append(y_gt_tmp)

            # prediction prob
            target_output = F.sigmoid(target_output).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output)

            # prediction med set
            y_pred_tmp = target_output.copy()
            y_pred_tmp[y_pred_tmp >= 0.4] = 1
            y_pred_tmp[y_pred_tmp < 0.4] = 0
            y_pred.append(y_pred_tmp)

            # prediction label
            y_pred_label_tmp = np.where(y_pred_tmp == 1)[0]
            y_pred_label.append(y_pred_label_tmp)
            med_cnt += len(y_pred_label_tmp)
            visit_cnt += 1

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = \
            multi_label_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record,
                              path='../data/output/ddi_A_final.pkl')
    # Avoid ZeroDivisionError when no patient had >= 2 visits.
    avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4},  PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p),
                np.mean(avg_r), np.mean(avg_f1), avg_med))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(
        avg_r), np.mean(avg_f1), avg_med
Ejemplo n.º 4
0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate a recurrent ``model`` and dump its predicted medication sets.

    Each patient is replayed visit by visit. The model receives the current
    visit and the three hidden states it returned on the previous step.
    Sigmoid probabilities >= 0.5 count as predicted medications.

    Args:
        model: module called as ``model(adm, [h1, h2, h3])`` returning
            ``(logits, output_logits, output_labels, [h1, h2, h3])``.
        data_eval: sequence of patients; ``visit[2]`` holds true medication indices.
        voc_size: sequence whose element ``[2]`` is the medication vocabulary size.
        epoch: epoch number, used only in the progress message.

    Returns:
        None. The function prints the six per-patient-averaged metrics
        returned by ``multi_label_metric`` (unpacked here as auc, p_1, p_3,
        p_5, f1, prauc). It also writes the predicted label lists to
        ``../data/smm_records.pkl``.
    """
    print('')
    model.eval()
    smm_record = []
    auc, p_1, p_3, p_5, f1, prauc = [[] for _ in range(6)]
    for step, input in enumerate(data_eval):
        y_gt = []
        y_pred = []
        y_pred_prob = []
        y_pred_label = []
        input1_hidden, input2_hidden, target_hidden = None, None, None
        for adm in input:
            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)

            # The extra logits/labels outputs are not used by these metrics.
            target_output1, _output_logits, _output_labels, [
                input1_hidden, input2_hidden, target_hidden
            ] = model(adm, [input1_hidden, input2_hidden, target_hidden])

            target_output1 = F.sigmoid(
                target_output1).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output1)
            y_pred_tmp = target_output1.copy()
            y_pred_tmp[y_pred_tmp >= 0.5] = 1
            y_pred_tmp[y_pred_tmp < 0.5] = 0
            y_pred.append(y_pred_tmp)
            # Indices of predicted medications, as plain Python ints.
            y_pred_label.append(np.flatnonzero(y_pred_tmp == 1).tolist())
        smm_record.append(y_pred_label)
        adm_auc, adm_p_1, adm_p_3, adm_p_5, adm_f1, adm_prauc = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))
        auc.append(adm_auc)
        p_1.append(adm_p_1)
        p_3.append(adm_p_3)
        p_5.append(adm_p_5)
        f1.append(adm_f1)
        prauc.append(adm_prauc)
        llprint('\rEval--Epoch: %d, Step: %d/%d' %
                (epoch, step, len(data_eval)))

    llprint(
        '\tAUC: %.4f, P1: %.4f, P3: %.4f, P5: %.4f, F1: %.4f, PRAUC: %.4f\n' %
        (np.mean(auc), np.mean(p_1), np.mean(p_3), np.mean(p_5), np.mean(f1),
         np.mean(prauc)))
    # Close the output file instead of leaking the handle.
    with open('../data/smm_records.pkl', 'wb') as fout:
        dill.dump(smm_record, fout)
Ejemplo n.º 5
0
def main():
    """Score the 'repeat the previous visit' baseline on ``data_test``.

    For each patient with more than one visit, the medications of visit k are
    used as the prediction for visit k+1. The function reads the module-level
    ``data_test``, ``med_voc`` and ``ddi_adj_path``, and prints DDI rate,
    Jaccard, PRAUC, precision, recall, F1 and the average number of
    predicted medications.
    """
    from itertools import combinations

    gt = []
    pred = []
    for patient in data_test:
        # Consecutive (current, next) visit pairs; single-visit patients yield none.
        for cur_adm, next_adm in zip(patient, patient[1:]):
            gt.append(next_adm[2])
            pred.append(cur_adm[2])

    med_voc_size = len(med_voc.idx2word)
    y_gt = np.zeros((len(gt), med_voc_size))
    y_pred = np.zeros((len(gt), med_voc_size))
    for row, (gt_meds, pred_meds) in enumerate(zip(gt, pred)):
        y_gt[row, gt_meds] = 1
        y_pred[row, pred_meds] = 1

    # The binary prediction doubles as the "probability" input.
    ja, prauc, avg_p, avg_r, avg_f1 = multi_label_metric(y_gt, y_pred, y_pred)

    # ddi rate: share of predicted medication pairs flagged in ddi_A.
    with open(ddi_adj_path, 'rb') as f:
        ddi_A = dill.load(f)
    all_cnt = 0
    dd_cnt = 0
    med_cnt = 0
    visit_cnt = len(y_pred)
    for adm in y_pred:
        med_code_set = np.where(adm == 1)[0]
        med_cnt += len(med_code_set)
        # Each unordered pair once, matching the original j > i loop.
        for med_i, med_j in combinations(med_code_set, 2):
            all_cnt += 1
            if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                dd_cnt += 1
    # No co-prescribed pairs means no interacting pairs.
    ddi_rate = dd_cnt / all_cnt if all_cnt else 0.0
    print(
        '\tDDI Rate: %.4f, Jaccard: %.4f,  PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n'
        % (ddi_rate, ja, prauc, avg_p, avg_r, avg_f1))
    print('avg med', med_cnt / visit_cnt if visit_cnt else float('nan'))
Ejemplo n.º 6
0
def eval(model,
         data_eval,
         voc_size,
         epoch,
         val=0,
         threshold1=0.3,
         threshold2=0.3):
    """Evaluate an add/delete ``model`` that edits the previous medication set.

    For each patient with at least two visits, the first visit's medications
    seed a running prediction ``y_old``. For every later visit:
    - The model returns "add" and "delete" logits.
    - Medications with add-probability >= ``threshold2`` are added.
    - Medications with delete-probability >= ``threshold1`` are then removed.

    The updated set carries over to the next visit.

    Args:
        model: module called as ``model(visits)`` returning ``(add, delete)``
            logits; row ``[0]`` of each sigmoid output is used.
        data_eval: sequence of patients; ``visit[2]`` holds true medication indices.
        voc_size: sequence whose element ``[2]`` is the medication vocabulary size.
        epoch: unused; kept for interface compatibility.
        val: if 0 return metrics, otherwise return raw labels/probabilities.
        threshold1: delete-probability threshold.
        threshold2: add-probability threshold.

    Returns:
        If ``val == 0``, returns ``(ddi_rate, jaccard, prauc, avg_precision,
        avg_recall, avg_f1, add_distance, delete_distance, avg_med)``.
        Otherwise returns ``(label_list, prob_add, prob_delete)`` as numpy
        arrays.
    """
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0
    label_list, prob_add, prob_delete = [], [], []
    add_list, delete_list = [], []

    for step, input in enumerate(data_eval):
        if len(input) < 2:
            continue
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        add_temp_list, delete_temp_list = [], []

        # Seed the running prediction with the first visit's medications.
        y_old = np.zeros(voc_size[2])
        y_old[input[0][2]] = 1

        for adm_idx in range(1, len(input)):
            adm = input[adm_idx]

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)
            label_list.append(y_gt_tmp)

            add_result, delete_result = model(input[:adm_idx + 1])
            # prediction prob
            y_pred_tmp_add = F.sigmoid(add_result).detach().cpu().numpy()[0]
            y_pred_tmp_delete = F.sigmoid(
                delete_result).detach().cpu().numpy()[0]
            y_pred_prob.append(y_pred_tmp_add)
            prob_add.append(y_pred_tmp_add)
            prob_delete.append(y_pred_tmp_delete)

            previous_set = set(np.flatnonzero(y_old == 1))

            # Update the running set: additions first, so deletions win ties.
            y_old[y_pred_tmp_add >= threshold2] = 1
            y_old[y_pred_tmp_delete >= threshold1] = 0
            # Store a snapshot: y_old is mutated again on later visits, and
            # appending it directly made every row alias the final state.
            y_pred.append(y_old.copy())

            # prediction label
            y_pred_label_tmp = np.flatnonzero(y_old == 1)
            y_pred_label.append(sorted(y_pred_label_tmp))
            visit_cnt += 1
            med_cnt += len(y_pred_label_tmp)

            # Edit distance between predicted and true add/delete operations
            # (size of the symmetric difference of each pair of sets).
            gt_set = set(np.flatnonzero(y_gt_tmp == 1))
            pred_set = set(y_pred_label_tmp)
            add_temp_list.append(
                len((pred_set - previous_set) ^ (gt_set - previous_set)))
            delete_temp_list.append(
                len((previous_set - pred_set) ^ (previous_set - gt_set)))

        if add_temp_list:
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate; builtin float replaces np.float (removed in NumPy 1.24).
    ddi_rate = float(ddi_rate_score(smm_record,
                                    path='../data/output/ddi_A_final.pkl'))
    avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4},  AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(avg_f1),
                np.mean(add_list), np.mean(delete_list), avg_med))
    if val == 0:
        return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(
            avg_p), np.mean(avg_r), np.mean(avg_f1), np.mean(
                add_list), np.mean(delete_list), avg_med
    return np.array(label_list), np.array(prob_add), np.array(prob_delete)
Ejemplo n.º 7
0
def eval(model,
         data_eval,
         voc_size,
         epoch,
         val=0,
         threshold1=0.8,
         threshold2=0.2):
    """Evaluate a per-medication action classifier that edits the previous set.

    For each patient with at least two visits, the first visit's medications
    seed a running prediction ``y_old``. For each later visit the model output
    is arg-maxed per medication along axis 1: class 1 adds the medication and
    class 2 removes it. The updated set carries over to the next visit.

    Args:
        model: module called as ``model(visits)``; output indexed as
            ``result_out[:, 0]`` and arg-maxed along ``dim=1``.
        data_eval: sequence of patients; ``visit[2]`` holds true medication indices.
        voc_size: sequence whose element ``[2]`` is the medication vocabulary size.
        epoch, val, threshold1, threshold2: unused; kept for interface
            compatibility.

    Returns:
        ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1,
        add_distance, delete_distance, avg_med)``.
    """
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0
    add_list, delete_list = [], []

    for step, input in enumerate(data_eval):
        if len(input) < 2:
            continue
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        add_temp_list, delete_temp_list = [], []

        # Seed the running prediction with the first visit's medications.
        y_old = np.zeros(voc_size[2])
        y_old[input[0][2]] = 1

        for adm_idx in range(1, len(input)):
            adm = input[adm_idx]

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[adm[2]] = 1
            y_gt.append(y_gt_tmp)

            result_out = model(input[:adm_idx + 1])
            # NOTE(review): sigmoid of column 0 is used as the PRAUC score;
            # confirm this is the intended per-medication probability.
            y_pred_tmp = F.sigmoid(
                result_out[:, 0]).detach().cpu().numpy().tolist()
            y_pred_prob.append(y_pred_tmp)

            previous_set = set(np.flatnonzero(y_old == 1))

            # Per-medication action: 1 -> add, 2 -> remove, otherwise unchanged.
            assignment = torch.max(result_out, dim=1)[1].cpu().numpy()
            y_old[assignment == 1] = 1
            y_old[assignment == 2] = 0
            # Store a snapshot: y_old is mutated again on later visits, and
            # appending it directly made every row alias the final state.
            y_pred.append(y_old.copy())

            # prediction label
            y_pred_label_tmp = np.flatnonzero(y_old == 1)
            y_pred_label.append(sorted(y_pred_label_tmp))
            visit_cnt += 1
            med_cnt += len(y_pred_label_tmp)

            # Edit distance between predicted and true add/delete operations
            # (size of the symmetric difference of each pair of sets).
            gt_set = set(np.flatnonzero(y_gt_tmp == 1))
            pred_set = set(y_pred_label_tmp)
            add_temp_list.append(
                len((pred_set - previous_set) ^ (gt_set - previous_set)))
            delete_temp_list.append(
                len((previous_set - pred_set) ^ (previous_set - gt_set)))

        if add_temp_list:
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate
    ddi_rate = ddi_rate_score(smm_record,
                              path='../data/output/ddi_A_final.pkl')
    # Avoid ZeroDivisionError when no patient had >= 2 visits.
    avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4},  AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(avg_f1), np.mean(add_list),
                np.mean(delete_list), avg_med))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(avg_p), np.mean(
        avg_r), np.mean(avg_f1), np.mean(add_list), np.mean(
            delete_list), avg_med
Ejemplo n.º 8
0
def main():
    """Benchmark a one-vs-rest logistic-regression medication recommender.

    Steps:
    1. Load records and vocabularies.
    2. Shuffle and split the data 2/3 train, 1/6 test, 1/6 eval.
    3. Fit ``OneVsRestClassifier(LogisticRegression())``.
    4. Score on 10 bootstrap samples, each 80% of the test set drawn with
       replacement.

    Prints mean and std of ``[ddi_rate, jaccard, avg_f1, prauc, avg_med]``
    as a LaTeX table row. Saves fit/inference times and the last sample's
    metrics to ``saved/<model_name>/history.pkl``.
    """
    from itertools import combinations

    # grid_search = False
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'
    ddi_path = '../data/output/ddi_A_final.pkl'

    with open(data_path, 'rb') as f:
        data = dill.load(f)
    with open(voc_path, 'rb') as f:
        voc = dill.load(f)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']
    # The DDI matrix is the same for every bootstrap sample: load it once.
    with open(ddi_path, 'rb') as f:
        ddi_A = dill.load(f)

    # Created once so that multiple epochs accumulate instead of resetting.
    history = defaultdict(list)
    for epoch in range(1):

        np.random.seed(epoch)
        np.random.shuffle(data)
        split_point = int(len(data) * 2 / 3)
        data_train = data[:split_point]
        eval_len = int(len(data[split_point:]) / 2)
        data_eval = data[split_point + eval_len:]
        data_test = data[split_point:split_point + eval_len]

        train_X, train_y = create_dataset(data_train, diag_voc, pro_voc, med_voc)
        test_X, test_y = create_dataset(data_test, diag_voc, pro_voc, med_voc)
        eval_X, eval_y = create_dataset(data_eval, diag_voc, pro_voc, med_voc)
        classifier = OneVsRestClassifier(LogisticRegression())

        tic = time.time()
        classifier.fit(train_X, train_y)
        fittime = time.time() - tic
        print('fitting time: {}'.format(fittime))

        result = []
        for _ in range(10):
            index = np.random.choice(np.arange(len(test_X)), round(len(test_X) * 0.8), replace=True)
            test_sample = test_X[index]
            y_sample = test_y[index]

            # Time only the predict call; tic was previously taken before fit().
            tic = time.time()
            y_pred = classifier.predict(test_sample)
            pretime = time.time() - tic
            print('inference time: {}'.format(pretime))

            y_prob = classifier.predict_proba(test_sample)

            ja, prauc, avg_p, avg_r, avg_f1 = multi_label_metric(y_sample, y_pred, y_prob)

            # ddi rate: share of predicted medication pairs flagged in ddi_A.
            all_cnt = 0
            dd_cnt = 0
            med_cnt = 0
            visit_cnt = len(y_pred)
            for adm in y_pred:
                med_code_set = np.where(adm == 1)[0]
                med_cnt += len(med_code_set)
                for med_i, med_j in combinations(med_code_set, 2):
                    all_cnt += 1
                    if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                        dd_cnt += 1
            ddi_rate = dd_cnt / all_cnt if all_cnt else 0.0
            avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med])

        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        # Raw string: "\p" is not a valid escape sequence.
        outstring = "".join(
            r"{:.4f} $\pm$ {:.4f} & ".format(m, s) for m, s in zip(mean, std))
        print(outstring)

        # Reports the metrics of the last bootstrap sample.
        print('Epoch: {}, DDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'.format(
            epoch, ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med
            ))

        history['fittime'].append(fittime)
        history['pretime'].append(pretime)
        history['jaccard'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)

    with open(os.path.join('saved', model_name, 'history.pkl'), 'wb') as f:
        dill.dump(history, f)
    # Implicit concatenation: the old backslash continuation embedded 12 spaces.
    print('Avg_Fittime: {:.8}, Avg_Pretime: {:.8}, Avg_Jaccard: {:.4}, Avg_DDI: {:.4}, '
          'Avg_p: {:.4}, Avg_r: {:.4}, Avg_f1: {:.4}, AVG_PRC: {:.4}\n'.format(
              np.mean(history['fittime']),
              np.mean(history['pretime']),
              np.mean(history['jaccard']),
              np.mean(history['ddi_rate']),
              np.mean(history['avg_p']),
              np.mean(history['avg_r']),
              np.mean(history['avg_f1']),
              np.mean(history['prauc'])
          ))
Ejemplo n.º 9
0
def main():
    """Benchmark an ensemble of 10 logistic-regression classifier chains.

    Data are shuffled with seed ``epoch`` (100) and split 2/3 train, 1/6 test,
    1/6 eval. Medications that never appear in the training labels are
    dropped before fitting; ``augment(..., appear_idx)`` is applied to the
    chain outputs (presumably to re-expand them to the full medication
    vocabulary — confirm). The chain predictions and probabilities are
    averaged, and the mean prediction is thresholded at 0.5.

    Side effects:
        Sets the module-level ``chains``. Prints fit/inference times and
        DDI rate, Jaccard, PRAUC, precision, recall, F1 and average
        medication count.
    """
    from itertools import combinations

    # grid_search = False
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    with open(data_path, 'rb') as f:
        data = dill.load(f)
    with open(voc_path, 'rb') as f:
        voc = dill.load(f)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    # Used as the shuffle seed and reported in the summary line.
    epoch = 100

    np.random.seed(epoch)
    np.random.shuffle(data)
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_eval = data[split_point + eval_len:]
    data_test = data[split_point:split_point + eval_len]

    train_X, train_y = create_dataset(data_train, diag_voc, pro_voc, med_voc)
    test_X, test_y = create_dataset(data_test, diag_voc, pro_voc, med_voc)
    eval_X, eval_y = create_dataset(data_eval, diag_voc, pro_voc, med_voc)
    # Some drugs never appear in the training labels. Train ECC only on the
    # drugs that do appear (appear_idx); the others are never recommended
    # for test/eval.
    appear_idx = np.where(train_y.sum(axis=0) > 0)[0]
    train_y = train_y[:, appear_idx]

    base_dt = LogisticRegression()

    tic_total_fit = time.time()
    global chains
    chains = [
        ClassifierChain(base_dt, order='random', random_state=i)
        for i in range(10)
    ]
    for i, chain in enumerate(chains):
        tic = time.time()
        chain.fit(train_X, train_y)
        fittime = time.time() - tic
        print('id {}, fitting time: {}'.format(i, fittime))
    print('total fitting time: {}'.format(time.time() - tic_total_fit))

    tic = time.time()
    y_pred_chains = np.array(
        [augment(chain.predict(test_X), appear_idx) for chain in chains])
    y_prob_chains = np.array(
        [augment(chain.predict_proba(test_X), appear_idx) for chain in chains])
    pretime = time.time() - tic
    print('inference time: {}'.format(pretime))

    # Majority vote across chains; probabilities are simply averaged.
    y_pred = y_pred_chains.mean(axis=0)
    y_pred[y_pred >= 0.5] = 1
    y_pred[y_pred < 0.5] = 0
    y_prob = y_prob_chains.mean(axis=0)

    ja, prauc, avg_p, avg_r, avg_f1 = multi_label_metric(
        test_y, y_pred, y_prob)

    # ddi rate: share of predicted medication pairs flagged in ddi_A.
    with open('../data/output/ddi_A_final.pkl', 'rb') as f:
        ddi_A = dill.load(f)
    all_cnt = 0
    dd_cnt = 0
    med_cnt = 0
    visit_cnt = len(y_pred)
    for adm in y_pred:
        med_code_set = np.where(adm == 1)[0]
        med_cnt += len(med_code_set)
        for med_i, med_j in combinations(med_code_set, 2):
            all_cnt += 1
            if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                dd_cnt += 1
    ddi_rate = dd_cnt / all_cnt if all_cnt else 0.0
    avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')
    print(
        'Epoch: {}, DDI Rate: {:.4}, Jaccard: {:.4}, PRAUC: {:.4}, AVG_PRC: {:.4}, AVG_RECALL: {:.4}, AVG_F1: {:.4}, AVG_MED: {:.4}\n'
        .format(epoch, ddi_rate, ja, prauc, avg_p, avg_r, avg_f1,
                avg_med))
Ejemplo n.º 10
0
def main():
    """Train/evaluate a one-vs-rest logistic-regression recommender.

    Data are split (unshuffled) 2/3 train, 1/6 test, 1/6 eval. With
    ``grid_search = True`` the function only runs a GridSearchCV over ``C``
    and prints the best parameters. Otherwise it fits
    ``LogisticRegression(C=0.90909)`` one-vs-rest, prints test metrics, and
    writes a history dict to ``saved/<model_name>/history.pkl``.
    """
    from itertools import combinations

    grid_search = False
    data_path = '../../data/records_final.pkl'
    voc_path = '../../data/voc_final.pkl'

    with open(data_path, 'rb') as f:
        data = dill.load(f)
    with open(voc_path, 'rb') as f:
        voc = dill.load(f)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_eval = data[split_point + eval_len:]
    data_test = data[split_point:split_point + eval_len]

    train_X, train_y = create_dataset(data_train, diag_voc, pro_voc, med_voc)
    test_X, test_y = create_dataset(data_test, diag_voc, pro_voc, med_voc)
    eval_X, eval_y = create_dataset(data_eval, diag_voc, pro_voc, med_voc)

    if grid_search:
        params = {
            'estimator__penalty': ['l2'],
            'estimator__C': np.linspace(0.00002, 1, 100)
        }

        model = LogisticRegression()
        classifier = OneVsRestClassifier(model)
        lr_gs = GridSearchCV(classifier, params,
                             verbose=1).fit(train_X, train_y)

        print("Best Params", lr_gs.best_params_)
        print("Best Score", lr_gs.best_score_)

        return

    model = LogisticRegression(C=0.90909)
    classifier = OneVsRestClassifier(model)
    classifier.fit(train_X, train_y)

    y_pred = classifier.predict(test_X)
    y_prob = classifier.predict_proba(test_X)

    ja, prauc, avg_p, avg_r, avg_f1 = multi_label_metric(
        test_y, y_pred, y_prob)

    # ddi rate: share of predicted medication pairs flagged in ddi_A.
    with open('../../data/ddi_A_final.pkl', 'rb') as f:
        ddi_A = dill.load(f)
    all_cnt = 0
    dd_cnt = 0
    med_cnt = 0
    visit_cnt = len(y_pred)
    for adm in y_pred:
        med_code_set = np.where(adm == 1)[0]
        med_cnt += len(med_code_set)
        for med_i, med_j in combinations(med_code_set, 2):
            all_cnt += 1
            if ddi_A[med_i, med_j] == 1 or ddi_A[med_j, med_i] == 1:
                dd_cnt += 1
    ddi_rate = dd_cnt / all_cnt if all_cnt else 0.0
    print(
        '\tDDI Rate: %.4f, Jaccard: %.4f, PRAUC: %.4f, AVG_PRC: %.4f, AVG_RECALL: %.4f, AVG_F1: %.4f\n'
        % (ddi_rate, ja, prauc, avg_p, avg_r, avg_f1))

    # NOTE(review): the same single-run metrics are repeated 30 times,
    # presumably to match the length of multi-run histories — confirm.
    history = defaultdict(list)
    for _ in range(30):
        history['jaccard'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)

    with open(os.path.join('saved', model_name, 'history.pkl'), 'wb') as f:
        dill.dump(history, f)

    print('avg med', med_cnt / visit_cnt if visit_cnt else float('nan'))
Ejemplo n.º 11
0
def eval(model, data_eval, voc_size, epoch):
    """Evaluate next-visit prediction plus add/delete edit distances.

    For each patient with at least two visits, visits ``0..i-1`` are fed to
    ``model`` to predict visit ``i``. Sigmoid probabilities >= 0.3 count as
    prescribed.

    Add/delete distances compare the predicted and true changes relative to
    a previous set:
    - For the first target visit, the previous set is the first visit's true
      medications.
    - Afterwards it is the model's previous prediction.

    Args:
        model: module called as ``model(visits)``; row ``[0]`` of its sigmoid
            output is used.
        data_eval: sequence of patients; ``visit[2]`` holds true medication indices.
        voc_size: sequence whose element ``[2]`` is the medication vocabulary size.
        epoch: unused; kept for interface compatibility.

    Returns:
        ``(ddi_rate, jaccard, prauc, avg_precision, avg_recall, avg_f1,
        add_distance, delete_distance, avg_med)``.
    """
    model.eval()

    smm_record = []
    ja, prauc, avg_p, avg_r, avg_f1 = [[] for _ in range(5)]
    med_cnt, visit_cnt = 0, 0
    add_list, delete_list = [], []

    for step, input in enumerate(data_eval):
        if len(input) < 2:
            continue
        y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
        add_temp_list, delete_temp_list = [], []

        previous_set = set(input[0][2])
        for i in range(1, len(input)):
            target_output = model(input[:i])

            y_gt_tmp = np.zeros(voc_size[2])
            y_gt_tmp[input[i][2]] = 1
            y_gt.append(y_gt_tmp)

            # prediction prob
            target_output = F.sigmoid(target_output).detach().cpu().numpy()[0]
            y_pred_prob.append(target_output)

            # prediction med set
            y_pred_tmp = target_output.copy()
            y_pred_tmp[y_pred_tmp >= 0.3] = 1
            y_pred_tmp[y_pred_tmp < 0.3] = 0
            y_pred.append(y_pred_tmp)

            # prediction label
            y_pred_label_tmp = np.where(y_pred_tmp == 1)[0]
            y_pred_label.append(y_pred_label_tmp)
            med_cnt += len(y_pred_label_tmp)
            visit_cnt += 1

            # Edit distance between predicted and true add/delete operations
            # (size of the symmetric difference of each pair of sets).
            gt_set = set(np.flatnonzero(y_gt_tmp == 1))
            pred_set = set(y_pred_label_tmp)
            add_temp_list.append(
                len((pred_set - previous_set) ^ (gt_set - previous_set)))
            delete_temp_list.append(
                len((previous_set - pred_set) ^ (previous_set - gt_set)))

            previous_set = pred_set

        if add_temp_list:
            add_list.append(np.mean(add_temp_list))
            delete_list.append(np.mean(delete_temp_list))

        smm_record.append(y_pred_label)
        adm_ja, adm_prauc, adm_avg_p, adm_avg_r, adm_avg_f1 = multi_label_metric(
            np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))

        ja.append(adm_ja)
        prauc.append(adm_prauc)
        avg_p.append(adm_avg_p)
        avg_r.append(adm_avg_r)
        avg_f1.append(adm_avg_f1)
        llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # ddi rate; builtin float replaces np.float (removed in NumPy 1.24).
    ddi_rate = float(ddi_rate_score(smm_record,
                                    path='../data/output/ddi_A_final.pkl'))
    avg_med = med_cnt / visit_cnt if visit_cnt else float('nan')

    llprint(
        '\nDDI Rate: {:.4}, Jaccard: {:.4},  AVG_F1: {:.4}, Add: {:.4}, Delete: {:.4}, AVG_MED: {:.4}\n'
        .format(ddi_rate, np.mean(ja), np.mean(avg_f1),
                np.mean(add_list), np.mean(delete_list), avg_med))

    return ddi_rate, np.mean(ja), np.mean(prauc), np.mean(
        avg_p), np.mean(avg_r), np.mean(avg_f1), np.mean(add_list), np.mean(
            delete_list), avg_med