Example #1
File: nn_ops.py Project: yuguo68/pytorch
 def forward(self):
     a = torch.randn(3, 2)
     b = torch.rand(3, 2)
     c = torch.rand(3)
     log_probs = torch.randn(50, 16, 20).log_softmax(2).detach()
     targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
     input_lengths = torch.full((16, ), 50, dtype=torch.long)
     target_lengths = torch.randint(10, 30, (16, ), dtype=torch.long)
     return len((
         F.binary_cross_entropy(torch.sigmoid(a), b),
         F.binary_cross_entropy_with_logits(a, b),  # takes raw logits; a sigmoid here would be applied twice
         F.poisson_nll_loss(a, b),
         F.cosine_embedding_loss(a, b, c),
         F.cross_entropy(a, b),
         F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
         # F.gaussian_nll_loss(a, b, torch.ones(5, 1)), # ENTER is not supported in mobile module
         F.hinge_embedding_loss(a, b),
         F.kl_div(a, b),
         F.l1_loss(a, b),
         F.mse_loss(a, b),
         F.margin_ranking_loss(c, c, c),
         F.multilabel_margin_loss(self.x, self.y),
         F.multilabel_soft_margin_loss(self.x, self.y),
         F.multi_margin_loss(self.x, torch.tensor([3])),
         F.nll_loss(a, torch.tensor([1, 0, 1])),
         F.huber_loss(a, b),
         F.smooth_l1_loss(a, b),
         F.soft_margin_loss(a, b),
         F.triplet_margin_loss(a, b, -b),
         # F.triplet_margin_with_distance_loss(a, b, -b), # can't take variable number of arguments
     ))
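Every example on this page leans on the same target convention, so a minimal self-contained sketch may help (not taken from any project above; the tensors and the manual recomputation are illustrative): F.multilabel_margin_loss reads its target as class indices, front-loaded and terminated by the first -1.

import torch
import torch.nn.functional as F

# (N=1, C=4) scores and an index-coded target; only the indices before the
# first -1 count as positives, so the trailing 1 below is ignored
x = torch.tensor([[0.1, 0.2, 0.4, 0.8]])
y = torch.tensor([[3, 0, -1, 1]])
loss = F.multilabel_margin_loss(x, y)

# manual recomputation for this single sample:
# sum over (positive j, non-positive i) of max(0, 1 - (x[j] - x[i])), divided by C
pos = [3, 0]
neg = [i for i in range(4) if i not in pos]
manual = sum(max(0.0, 1.0 - (x[0, j] - x[0, i]).item())
             for j in pos for i in neg) / 4
assert abs(loss.item() - manual) < 1e-6  # both come out to 0.85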
Example #2
    def forward(self, x, y, reduction='mean'):
        """
            y: labels have standard {0,1} form and will be converted to indices
        """
        b, c = x.size()
        idx = (torch.arange(c) + 1).type_as(x)
        y_idx, _ = (idx * y).sort(-1, descending=True)
        y_idx = (y_idx - 1).long()

        return F.multilabel_margin_loss(x, y_idx, reduction=reduction)
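A quick illustration of the conversion above, with made-up labels: multiplying by 1-based positions and sorting in descending order pushes the positive class indices to the front and turns every zero into the -1 padding the loss expects.

import torch

y = torch.tensor([[0., 1., 1., 0., 1.]])        # {0,1} labels, c=5
idx = (torch.arange(5) + 1).type_as(y)          # [1., 2., 3., 4., 5.]
y_idx, _ = (idx * y).sort(-1, descending=True)  # [5., 3., 2., 0., 0.]
y_idx = (y_idx - 1).long()
print(y_idx)                                    # tensor([[ 4,  2,  1, -1, -1]])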
Example #3
    def get_loss(self, y_pred, y_true, **kwargs):

        loss_bce_target = np.zeros((1, MEDICATION_COUNT))
        loss_bce_target[:, y_true] = 1
        loss_multi_target = np.full((1, MEDICATION_COUNT), -1)
        for idx, item in enumerate(y_true):
            loss_multi_target[0][idx] = item

        loss_bce = F.binary_cross_entropy_with_logits(y_pred, torch.FloatTensor(loss_bce_target).to(self.device))
        loss_multi = F.multilabel_margin_loss(torch.sigmoid(y_pred),
                                              torch.LongTensor(loss_multi_target).to(self.device))
        loss = LOSS_PROPORTION_BCE * loss_bce + LOSS_PROPORTION_MULTI * loss_multi
        return loss
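The snippet above depends on constants defined elsewhere; here is a self-contained version of the same target construction, with MEDICATION_COUNT, the label list, and the logits as stand-in values:

import numpy as np
import torch
import torch.nn.functional as F

MEDICATION_COUNT = 8                         # stand-in vocabulary size
y_true = [2, 5, 6]                           # stand-in ground-truth indices
y_pred = torch.randn(1, MEDICATION_COUNT)    # stand-in logits

loss_multi_target = np.full((1, MEDICATION_COUNT), -1)
loss_multi_target[0, :len(y_true)] = y_true  # front-load indices, -1 padding
loss = F.multilabel_margin_loss(torch.sigmoid(y_pred),
                                torch.LongTensor(loss_multi_target))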
Example #4
 def test_multilabel_margin_loss(self):
     inp = torch.randn(1024,
                       device='cuda',
                       dtype=self.dtype,
                       requires_grad=True)
     target = torch.randint(0,
                            10, (1024, ),
                            dtype=torch.long,
                            device='cuda')
     output = F.multilabel_margin_loss(inp,
                                       target,
                                       size_average=None,
                                       reduce=None,
                                       reduction='mean')
Example #5
def train(epoch):
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = network(data)
    loss = F.multilabel_margin_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append(
        (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
Example #6
def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += F.multilabel_margin_loss(output, target, reduction='sum').item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
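Note that the two loops above pass (N,)-shaped class labels straight to F.multilabel_margin_loss, which requires the target to have the same shape as the output. A hedged fix; the expansion helper is my own addition, not part of the original script:

import torch

def to_margin_target(labels, num_classes):
    # expand (N,) class labels into the (N, C) index-coded, -1-padded
    # target that multilabel_margin_loss expects
    t = torch.full((labels.size(0), num_classes), -1, dtype=torch.long)
    t[:, 0] = labels
    return t

# e.g. inside the loops above:
# loss = F.multilabel_margin_loss(output, to_margin_target(target, 10))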
Example #7
    def configure_criterion(self, y, t):

        criterion = F.cross_entropy(y, t)

        if self.hparams.criterion == "cross_entropy":
            criterion = F.cross_entropy(y, t)
        elif self.hparams.criterion == "binary_cross_entropy":
            criterion = F.binary_cross_entropy(y, t)
        elif self.hparams.criterion == "binary_cross_entropy_with_logits":
            criterion = F.binary_cross_entropy_with_logits(y, t)
        elif self.hparams.criterion == "poisson_nll_loss":
            criterion = F.poisson_nll_loss(y, t)
        elif self.hparams.criterion == "hinge_embedding_loss":
            criterion = F.hinge_embedding_loss(y, t)
        elif self.hparams.criterion == "kl_div":
            criterion = F.kl_div(y, t)
        elif self.hparams.criterion == "l1_loss":
            criterion = F.l1_loss(y, t)
        elif self.hparams.criterion == "mse_loss":
            criterion = F.mse_loss(y, t)
        elif self.hparams.criterion == "margin_ranking_loss":
            criterion = F.margin_ranking_loss(y, t)
        elif self.hparams.criterion == "multilabel_margin_loss":
            criterion = F.multilabel_margin_loss(y, t)
        elif self.hparams.criterion == "multilabel_soft_margin_loss":
            criterion = F.multilabel_soft_margin_loss(y, t)
        elif self.hparams.criterion == "multi_margin_loss":
            criterion = F.multi_margin_loss(y, t)
        elif self.hparams.criterion == "nll_loss":
            criterion = F.nll_loss(y, t)
        elif self.hparams.criterion == "smooth_l1_loss":
            criterion = F.smooth_l1_loss(y, t)
        elif self.hparams.criterion == "soft_margin_loss":
            criterion = F.soft_margin_loss(y, t)

        return criterion
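The chain above can also be written as a table lookup; a sketch assuming the same self.hparams.criterion field and the same F alias, with cross_entropy kept as the fallback:

_CRITERIA = {
    "cross_entropy": F.cross_entropy,
    "binary_cross_entropy": F.binary_cross_entropy,
    "binary_cross_entropy_with_logits": F.binary_cross_entropy_with_logits,
    "poisson_nll_loss": F.poisson_nll_loss,
    "hinge_embedding_loss": F.hinge_embedding_loss,
    "kl_div": F.kl_div,
    "l1_loss": F.l1_loss,
    "mse_loss": F.mse_loss,
    "margin_ranking_loss": F.margin_ranking_loss,
    "multilabel_margin_loss": F.multilabel_margin_loss,
    "multilabel_soft_margin_loss": F.multilabel_soft_margin_loss,
    "multi_margin_loss": F.multi_margin_loss,
    "nll_loss": F.nll_loss,
    "smooth_l1_loss": F.smooth_l1_loss,
    "soft_margin_loss": F.soft_margin_loss,
}

def configure_criterion(self, y, t):
    return _CRITERIA.get(self.hparams.criterion, F.cross_entropy)(y, t)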
Example #8
def main():
    if not os.path.exists(os.path.join("saved", model_name)):
        os.makedirs(os.path.join("saved", model_name))

    data_path = '../data/records.pkl'
    voc_path = '../data/voc.pkl'
    ehr_adj_path = '../data/ehr_adj.pkl'
    ddi_adj_path = '../data/ddi_A.pkl'
    device = torch.device('cuda:0')

    ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    # data_eval = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    EPOCH = 30
    LR = 0.001
    EVAL = True

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = GMNN(voc_size, ehr_adj, ddi_adj, emb_dim=64, device=device)
    if EVAL:
        model.load_state_dict(
            torch.load(
                open(os.path.join("saved", model_name, resume_name), 'rb')))
    model.to(device=device)

    optimizer = Adam(list(model.parameters()), lr=LR)

    if EVAL:
        eval(model, data_eval, voc_size, 0)
    else:
        for epoch in range(EPOCH):
            loss_record1 = []
            loss_record2 = []
            start_time = time.time()
            model.train()
            for step, input in enumerate(data_train):
                input1_hidden, input2_hidden, target_hidden = None, None, None
                loss = 0
                for adm in input:
                    loss1_target = np.zeros((1, voc_size[2]))
                    loss1_target[:, adm[2]] = 1

                    loss2_target = adm[2] + [adm[2][0]]

                    loss3_target = np.full((1, voc_size[2]), -1)
                    for idx, item in enumerate(adm[2]):
                        loss3_target[0][idx] = item

                    target_output1, target_output2, [
                        input1_hidden, input2_hidden, target_hidden
                    ], batch_pos_loss, batch_neg_loss = model(
                        adm, [input1_hidden, input2_hidden, target_hidden])

                    loss1 = F.binary_cross_entropy_with_logits(
                        target_output1,
                        torch.FloatTensor(loss1_target).to(device))
                    loss2 = F.cross_entropy(
                        target_output2,
                        torch.LongTensor(loss2_target).to(device))

                    # loss = 9*loss1/10 + loss2/10
                    loss3 = F.multilabel_margin_loss(
                        F.sigmoid(target_output1),
                        torch.LongTensor(loss3_target).to(device))
                    loss += loss1 + 0.1 * loss3 + 0.01 * batch_neg_loss

                    loss_record1.append(loss.item())
                    loss_record2.append(loss3.item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                llprint('\rTrain--Epoch: %d, Step: %d/%d' %
                        (epoch, step, len(data_train)))

            eval(model, data_eval, voc_size, epoch)

            end_time = time.time()
            elapsed_time = (end_time - start_time) / 60
            llprint(
                '\tEpoch: %d, Loss1: %.4f, Loss2: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n'
                % (epoch, np.mean(loss_record1), np.mean(loss_record2),
                   elapsed_time, elapsed_time * (EPOCH - epoch - 1) / 60))

            torch.save(
                model.state_dict(),
                open(
                    os.path.join(
                        'saved', model_name, 'Epoch_%d_Loss1_%.4f.model' %
                        (epoch, np.mean(loss_record1))), 'wb'))
            print('')

        # test
        torch.save(
            model.state_dict(),
            open(os.path.join('saved', model_name, 'final.model'), 'wb'))
Example #9
def multilabel_margin(y_pred, y_true):
    return F.multilabel_margin_loss(y_pred, y_true)
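A hypothetical call to the thin wrapper above (shapes and values are made up):

import torch

scores = torch.randn(4, 10)
labels = torch.full((4, 10), -1, dtype=torch.long)
labels[:, 0] = torch.randint(0, 10, (4,))  # one positive class per row
print(multilabel_margin(scores, labels))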
Example #10
File: MICRON.py Project: ycq091044/MICRON
def main():

    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    np.random.seed(1203)
    np.random.shuffle(data)

    split_point = int(len(data) * 3 / 5)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))

    model = MICRON(voc_size, ddi_adj, emb_dim=args.dim, device=device)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))

    if args.Test:
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()
        label_list, prob_list = eval(model, data_eval, voc_size, 0, 1)

        threshold1, threshold2 = [], []
        for i in range(label_list.shape[1]):
            _, _, boundary = roc_curve(label_list[:, i],
                                       prob_list[:, i],
                                       pos_label=1)
            # boundary1 should be in [0.5, 0.9], boundary2 should be in [0.1, 0.5]
            threshold1.append(
                min(
                    0.9,
                    max(0.5, boundary[max(0,
                                          round(len(boundary) * 0.05) - 1)])))
            threshold2.append(
                max(
                    0.1,
                    min(
                        0.5, boundary[min(round(len(boundary) * 0.95),
                                          len(boundary) - 1)])))
        print(np.mean(threshold1), np.mean(threshold2))
        threshold1 = np.ones(voc_size[2]) * np.mean(threshold1)
        threshold2 = np.ones(voc_size[2]) * np.mean(threshold2)
        eval(model, data_test, voc_size, 0, 0, threshold1, threshold2)
        print('test time: {}'.format(time.time() - tic))

        return

    model.to(device=device)
    print('parameters', get_n_params(model))
    # exit()
    optimizer = RMSprop(list(model.parameters()),
                        lr=args.lr,
                        weight_decay=args.weight_decay)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    weight_list = [[0.25, 0.25, 0.25, 0.25]]

    EPOCH = 40
    for epoch in range(EPOCH):
        t = 0
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))

        sample_counter = 0
        mean_loss = np.array([0, 0, 0, 0])

        model.train()
        for step, input in enumerate(data_train):
            loss = 0
            if len(input) < 2: continue
            for adm_idx, adm in enumerate(input):
                if adm_idx == 0: continue
                # sample_counter += 1
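                # note: with the increment above left commented out,
                # sample_counter stays 0, so the fixed 0.25 weights in
                # weight_list[-1] are always used below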
                seq_input = input[:adm_idx + 1]

                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                loss_bce_target_last = np.zeros((1, voc_size[2]))
                loss_bce_target_last[:, input[adm_idx - 1][2]] = 1

                loss_multi_target = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(adm[2]):
                    loss_multi_target[0][idx] = item

                loss_multi_target_last = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(input[adm_idx - 1][2]):
                    loss_multi_target_last[0][idx] = item

                result, result_last, _, loss_ddi, loss_rec = model(seq_input)

                loss_bce = 0.75 * F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device)) + \
                    (1 - 0.75) * F.binary_cross_entropy_with_logits(result_last, torch.FloatTensor(loss_bce_target_last).to(device))
                loss_multi = 5e-2 * (0.75 * F.multilabel_margin_loss(F.sigmoid(result), torch.LongTensor(loss_multi_target).to(device)) + \
                    (1 - 0.75) * F.multilabel_margin_loss(F.sigmoid(result_last), torch.LongTensor(loss_multi_target_last).to(device)))

                y_pred_tmp = F.sigmoid(result).detach().cpu().numpy()[0]
                y_pred_tmp[y_pred_tmp >= 0.5] = 1
                y_pred_tmp[y_pred_tmp < 0.5] = 0
                y_label = np.where(y_pred_tmp == 1)[0]
                current_ddi_rate = ddi_rate_score(
                    [[y_label]], path='../data/output/ddi_A_final.pkl')

                # l2 = 0
                # for p in model.parameters():
                #     l2 = l2 + (p ** 2).sum()

                if sample_counter == 0:
                    lambda1, lambda2, lambda3, lambda4 = weight_list[-1]
                else:
                    current_loss = np.array([
                        loss_bce.detach().cpu().numpy(),
                        loss_multi.detach().cpu().numpy(),
                        loss_ddi.detach().cpu().numpy(),
                        loss_rec.detach().cpu().numpy()
                    ])
                    current_ratio = (current_loss -
                                     np.array(mean_loss)) / np.array(mean_loss)
                    instant_weight = np.exp(current_ratio) / sum(
                        np.exp(current_ratio))
                    lambda1, lambda2, lambda3, lambda4 = instant_weight * 0.75 + np.array(
                        weight_list[-1]) * 0.25
                    # update weight_list
                    weight_list.append([lambda1, lambda2, lambda3, lambda4])
                # update mean_loss
                mean_loss = (mean_loss * (sample_counter - 1) + np.array([loss_bce.detach().cpu().numpy(), \
                    loss_multi.detach().cpu().numpy(), loss_ddi.detach().cpu().numpy(), loss_rec.detach().cpu().numpy()])) / sample_counter
                # lambda1, lambda2, lambda3, lambda4 = weight_list[-1]
                if current_ddi_rate > 0.08:
                    loss += lambda1 * loss_bce + lambda2 * loss_multi + \
                                 lambda3 * loss_ddi +  lambda4 * loss_rec
                else:
                    loss += lambda1 * loss_bce + lambda2 * loss_multi + \
                                lambda4 * loss_rec

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic,
            time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['add'].append(add)
        history['delete'].append(delete)
        history['med'].append(avg_med)

        if epoch >= 5:
            print(
                'ddi: {}, Med: {}, Ja: {}, F1: {}, Add: {}, Delete: {}'.format(
                    np.mean(history['ddi_rate'][-5:]),
                    np.mean(history['med'][-5:]), np.mean(history['ja'][-5:]),
                    np.mean(history['avg_f1'][-5:]),
                    np.mean(history['add'][-5:]),
                    np.mean(history['delete'][-5:])))

        torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
            'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb'))

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    dill.dump(
        history,
        open(
            os.path.join('saved', args.model_name,
                         'history_{}.pkl'.format(args.model_name)), 'wb'))
Example #11
def main():

    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    np.random.seed(1203)
    np.random.shuffle(data)

    split_point = int(len(data) * 3 / 5)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    print(voc_size)
    model = DualNN(voc_size, ddi_adj, emb_dim=args.dim, device=device)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))

    if args.Test:
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()
        label_list, prob_add, prob_delete = eval(model, data_eval, voc_size, 0,
                                                 1)

        threshold1, threshold2 = [], []
        for i in range(label_list.shape[1]):
            _, _, boundary_add = roc_curve(label_list[:, i],
                                           prob_add[:, i],
                                           pos_label=1)
            _, _, boundary_delete = roc_curve(label_list[:, i],
                                              prob_delete[:, i],
                                              pos_label=0)
            threshold1.append(boundary_add[min(round(len(boundary_add) * 0.05),
                                               len(boundary_add) - 1)])
            threshold2.append(boundary_delete[min(
                round(len(boundary_delete) * 0.05),
                len(boundary_delete) - 1)])
        # threshold1 = np.ones(voc_size[2]) * np.mean(threshold1)
        # threshold2 = np.ones(voc_size[2]) * np.mean(threshold2)
        print(np.mean(threshold1), np.mean(threshold2))
        eval(model, data_test, voc_size, 0, 0, threshold1, threshold2)
        print('test time: {}'.format(time.time() - tic))

        return

    model.to(device=device)
    print('parameters', get_n_params(model))
    # exit()
    optimizer = RMSprop(list(model.parameters()),
                        lr=args.lr,
                        weight_decay=args.weight_decay)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    EPOCH = 40
    for epoch in range(EPOCH):
        t = 0
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))

        model.train()
        for step, input in enumerate(data_train):
            if len(input) < 2: continue
            loss = 0
            for adm_idx, adm in enumerate(input):
                if adm_idx == 0: continue

                seq_input = input[:adm_idx + 1]

                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                loss_bce_target_last = np.zeros((1, voc_size[2]))
                loss_bce_target_last[:, input[adm_idx - 1][2]] = 1

                add_target = np.zeros((1, voc_size[2]))
                add_target[:, np.where(loss_bce_target == 1)[1]] = 1
                delete_target = np.zeros((1, voc_size[2]))
                delete_target[:, np.where(loss_bce_target == 0)[1]] = 1

                loss_multi_target = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(adm[2]):
                    loss_multi_target[0][idx] = item

                loss_multi_target_last = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(input[adm_idx - 1][2]):
                    loss_multi_target_last[0][idx] = item

                loss_multi_add_target = np.full((1, voc_size[2]), -1)
                for i, item in enumerate(np.where(add_target == 1)[0]):
                    loss_multi_add_target[0][i] = item

                loss_multi_delete_target = np.full((1, voc_size[2]), -1)
                for i, item in enumerate(np.where(delete_target == 1)[0]):
                    loss_multi_delete_target[0][i] = item

                add_result, delete_result = model(seq_input)

                loss_bce = F.binary_cross_entropy_with_logits(add_result, torch.FloatTensor(add_target).to(device)) + \
                    F.binary_cross_entropy_with_logits(delete_result, torch.FloatTensor(delete_target).to(device))
                loss_multi = F.multilabel_margin_loss(F.sigmoid(add_result), torch.LongTensor(loss_multi_add_target).to(device)) + \
                    F.multilabel_margin_loss(F.sigmoid(delete_result), torch.LongTensor(loss_multi_delete_target).to(device))

                # l2 = 0
                # for p in model.parameters():
                #     l2 = l2 + (p ** 2).sum()

                loss += 0.95 * loss_bce + 0.05 * loss_multi

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        print()
        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic,
            time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['add'].append(add)
        history['delete'].append(delete)
        history['med'].append(avg_med)

        if epoch >= 5:
            print(
                'ddi: {}, Med: {}, Ja: {}, F1: {}, Add: {}, Delete: {}'.format(
                    np.mean(history['ddi_rate'][-5:]),
                    np.mean(history['med'][-5:]), np.mean(history['ja'][-5:]),
                    np.mean(history['avg_f1'][-5:]),
                    np.mean(history['add'][-5:]),
                    np.mean(history['delete'][-5:])))

        torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
            'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb'))

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    dill.dump(
        history,
        open(
            os.path.join('saved', args.model_name,
                         'history_{}.pkl'.format(args.model_name)), 'wb'))
Example #12
def main():

    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ehr_adj_path = '../data/output/ehr_adj_final.pkl'
    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    # np.random.seed(2048)
    # np.random.shuffle(data)
    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = GAMENet(voc_size,
                    ehr_adj,
                    ddi_adj,
                    emb_dim=args.dim,
                    device=device,
                    ddi_in_memory=args.ddi)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))

    if args.Test:
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()
        result = []
        for _ in range(10):
            test_sample = np.random.choice(data_test,
                                           round(len(data_test) * 0.8),
                                           replace=True)
            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
                model, test_sample, voc_size, 0)
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med])

        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        outstring = ""
        for m, s in zip(mean, std):
            outstring += "{:.4f} $\pm$ {:.4f} & ".format(m, s)

        print(outstring)
        print('test time: {}'.format(time.time() - tic))
        return

    model.to(device=device)
    print('parameters', get_n_params(model))
    optimizer = Adam(list(model.parameters()), lr=args.lr)

    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    EPOCH = 50
    for epoch in range(EPOCH):
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))
        prediction_loss_cnt, neg_loss_cnt = 0, 0
        model.train()
        for step, input in enumerate(data_train):
            for idx, adm in enumerate(input):
                seq_input = input[:idx + 1]
                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                loss_multi_target = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(adm[2]):
                    loss_multi_target[0][idx] = item

                target_output1, loss_ddi = model(seq_input)

                loss_bce = F.binary_cross_entropy_with_logits(
                    target_output1,
                    torch.FloatTensor(loss_bce_target).to(device))
                loss_multi = F.multilabel_margin_loss(
                    F.sigmoid(target_output1),
                    torch.LongTensor(loss_multi_target).to(device))
                if args.ddi:
                    target_output1 = F.sigmoid(
                        target_output1).detach().cpu().numpy()[0]
                    target_output1[target_output1 >= 0.5] = 1
                    target_output1[target_output1 < 0.5] = 0
                    y_label = np.where(target_output1 == 1)[0]
                    current_ddi_rate = ddi_rate_score(
                        [[y_label]], path='../data/output/ddi_A_final.pkl')
                    if current_ddi_rate <= args.target_ddi:
                        loss = 0.9 * loss_bce + 0.1 * loss_multi
                        prediction_loss_cnt += 1
                    else:
                        rnd = np.exp(
                            (args.target_ddi - current_ddi_rate) / args.T)
                        if np.random.rand(1) < rnd:
                            loss = loss_ddi
                            neg_loss_cnt += 1
                        else:
                            loss = 0.9 * loss_bce + 0.1 * loss_multi
                            prediction_loss_cnt += 1
                else:
                    loss = 0.9 * loss_bce + 0.1 * loss_multi

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        args.T *= args.decay_weight

        print()
        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic,
            time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)

        if epoch >= 5:
            print('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
                np.mean(history['ddi_rate'][-5:]),
                np.mean(history['med'][-5:]), np.mean(history['ja'][-5:]),
                np.mean(history['avg_f1'][-5:]),
                np.mean(history['prauc'][-5:])))

        torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
            'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb'))

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    dill.dump(
        history,
        open(
            os.path.join('saved', args.model_name,
                         'history_{}.pkl'.format(args.model_name)), 'wb'))
Example #13
    def forward(self, word_emb, char_index, text_len, speaker_ids, genre,
                is_training, gold_starts, gold_ends, cluster_ids):
        training_num = 0.0
        if is_training == 1:
            training_num = 1.0

        self.dropout = 1 - (training_num * self.config["dropout_rate"])  # 0.2
        self.lexical_dropout = 1 - (
            training_num * self.config["lexical_dropout_rate"])  # 0.5

        num_sentences = word_emb.shape[
            0]  # number of sentences to predict from
        max_sentence_length = word_emb.shape[
            1]  # maybe caused by applying padding to the dataset to have all sentences in the same shape

        text_emb_list = [word_emb]  # 3D tensor added in an array

        if self.config["char_embedding_size"] > 0:  # true is 8
            char_emb = torch.index_select(
                self.char_embeddings, 0,
                char_index.view(-1)).view(num_sentences, max_sentence_length,
                                          -1,
                                          self.config["char_embedding_size"])
            # [num_sentences, max_sentence_length, max_word_length, emb]
            # [a vector of embedding 8 for each character for each word for each sentence for all sentences]
            # (according to longest word and longest sentence)

            flattened_char_emb = char_emb.view([
                num_sentences * max_sentence_length,
                util.shape(char_emb, 2),
                util.shape(char_emb, 3)
            ])
            # [num_sentences * max_sentence_length, max_word_length, emb]

            flattened_aggregated_char_emb = self.char_cnn(flattened_char_emb)

            # [num_sentences * max_sentence_length, emb] character level CNN

            aggregated_char_emb = flattened_aggregated_char_emb.view([
                num_sentences, max_sentence_length,
                util.shape(flattened_aggregated_char_emb, 1)
            ])
            # [num_sentences, max_sentence_length, emb]
            text_emb_list.append(aggregated_char_emb)
        text_emb = torch.cat(text_emb_list, 2)
        text_emb = F.dropout(text_emb, self.lexical_dropout)

        text_len_mask = self.sequence_mask(text_len,
                                           max_len=max_sentence_length)
        text_len_mask = text_len_mask.view(num_sentences * max_sentence_length)

        text_outputs = self.encode_sentences(text_emb, text_len, text_len_mask)
        text_outputs = F.dropout(text_outputs, self.dropout)

        genre_emb = self.genre_tensor[genre]  # [emb]

        sentence_indices = torch.unsqueeze(torch.arange(num_sentences),
                                           1).repeat(1, max_sentence_length)
        # [num_sentences, max_sentence_length]

        # TODO make sure self.flatten_emb_by_sentence works as expected
        flattened_sentence_indices = self.flatten_emb_by_sentence(
            sentence_indices, text_len_mask)  # [num_words]
        flattened_text_emb = self.flatten_emb_by_sentence(
            text_emb, text_len_mask)  # [num_words]

        candidate_starts, candidate_ends = coref_ops.coref_kernels_spans(
            sentence_indices=flattened_sentence_indices,
            max_width=self.max_mention_width)

        candidate_mention_emb = self.get_mention_emb(
            flattened_text_emb, text_outputs, candidate_starts,
            candidate_ends)  # [num_candidates, emb]

        # mention scoring is now an nn module; previously:
        # candidate_mention_scores = self.get_mention_scores(candidate_mention_emb)  # [num_mentions, 1]
        candidate_mention_scores = self.mention(candidate_mention_emb)
        candidate_mention_scores = torch.squeeze(candidate_mention_scores,
                                                 1)  # [num_mentions]

        k = int(
            np.floor(
                float(text_outputs.shape[0]) * self.config["mention_ratio"]))
        predicted_mention_indices = coref_ops.coref_kernels_extract_mentions(
            candidate_mention_scores, candidate_starts, candidate_ends,
            k)  # ([k], [k])
        # predicted_mention_indices.set_shape([None])

        mention_starts = torch.index_select(
            candidate_starts, 0,
            predicted_mention_indices.type(torch.LongTensor))  # [num_mentions]
        mention_ends = torch.index_select(
            candidate_ends, 0,
            predicted_mention_indices.type(torch.LongTensor))  # [num_mentions]
        mention_emb = torch.index_select(
            candidate_mention_emb, 0,
            predicted_mention_indices.type(
                torch.LongTensor))  # [num_mentions, emb]
        mention_scores = torch.index_select(
            candidate_mention_scores, 0,
            predicted_mention_indices.type(torch.LongTensor))  # [num_mentions]

        mention_start_emb = torch.index_select(
            text_outputs, 0,
            mention_starts.type(torch.LongTensor))  # [num_mentions, emb]
        mention_end_emb = torch.index_select(
            text_outputs, 0,
            mention_ends.type(torch.LongTensor))  # [num_mentions, emb]
        mention_speaker_ids = torch.index_select(
            speaker_ids, 0,
            mention_starts.type(torch.LongTensor))  # [num_mentions]

        max_antecedents = self.config["max_antecedents"]
        antecedents, antecedent_labels, antecedents_len = coref_ops.coref_kernels_antecedents(
            mention_starts, mention_ends, gold_starts, gold_ends, cluster_ids,
            max_antecedents)
        # ([num_mentions, max_ant], [num_mentions, max_ant + 1], [num_mentions]
        antecedent_scores = self.get_antecedent_scores(
            mention_emb, mention_scores, antecedents, antecedents_len,
            mention_starts, mention_ends, mention_speaker_ids,
            genre_emb)  # [num_mentions, max_ant + 1]
        loss = self.softmax_loss(antecedent_scores,
                                 antecedent_labels)  # [num_mentions]
        loss2 = F.multilabel_margin_loss(
            antecedent_scores, antecedent_labels.type(torch.LongTensor))
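        # note: loss2 is computed here but never added to the returned loss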
        loss = torch.sum(loss)  # []
        return [
            candidate_starts, candidate_ends, candidate_mention_scores,
            mention_starts, mention_ends, antecedents, antecedent_scores
        ], loss
Example #14
def main():

    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    ddi_mask_path = '../data/output/ddi_mask_H.pkl'
    molecule_path = '../data/output/atc3toSMILES.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    ddi_mask_H = dill.load(open(ddi_mask_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))
    molecule = dill.load(open(molecule_path, 'rb'))

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    MPNNSet, N_fingerprint, average_projection = buildMPNN(
        molecule, med_voc.idx2word, 2, device)
    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))

    model = SafeDrugModel(voc_size,
                          ddi_adj,
                          ddi_mask_H,
                          MPNNSet,
                          N_fingerprint,
                          average_projection,
                          emb_dim=args.dim,
                          device=device)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))

    if args.Test:
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()

        ddi_list, ja_list, prauc_list, f1_list, med_list = [], [], [], [], []
        # ###
        # for threshold in np.linspace(0.00, 0.20, 30):
        #     print ('threshold = {}'.format(threshold))
        #     ddi, ja, prauc, _, _, f1, avg_med = eval(model, data_test, voc_size, 0, threshold)
        #     ddi_list.append(ddi)
        #     ja_list.append(ja)
        #     prauc_list.append(prauc)
        #     f1_list.append(f1)
        #     med_list.append(avg_med)
        # total = [ddi_list, ja_list, prauc_list, f1_list, med_list]
        # with open('ablation_ddi.pkl', 'wb') as infile:
        #     dill.dump(total, infile)
        # ###
        result = []
        for _ in range(10):
            test_sample = np.random.choice(data_test,
                                           round(len(data_test) * 0.8),
                                           replace=True)
            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
                model, test_sample, voc_size, 0)
            result.append([ddi_rate, ja, avg_f1, prauc, avg_med])

        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        outstring = ""
        for m, s in zip(mean, std):
            outstring += "{:.4f} $\pm$ {:.4f} & ".format(m, s)

        print(outstring)

        print('test time: {}'.format(time.time() - tic))
        return

    model.to(device=device)
    # print('parameters', get_n_params(model))
    # exit()
    optimizer = Adam(list(model.parameters()), lr=args.lr)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    EPOCH = 50
    for epoch in range(EPOCH):
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))

        model.train()
        for step, input in enumerate(data_train):

            loss = 0
            for idx, adm in enumerate(input):

                seq_input = input[:idx + 1]
                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                loss_multi_target = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(adm[2]):
                    loss_multi_target[0][idx] = item

                result, loss_ddi = model(seq_input)

                loss_bce = F.binary_cross_entropy_with_logits(
                    result,
                    torch.FloatTensor(loss_bce_target).to(device))
                loss_multi = F.multilabel_margin_loss(
                    F.sigmoid(result),
                    torch.LongTensor(loss_multi_target).to(device))

                result = F.sigmoid(result).detach().cpu().numpy()[0]
                result[result >= 0.5] = 1
                result[result < 0.5] = 0
                y_label = np.where(result == 1)[0]
                current_ddi_rate = ddi_rate_score(
                    [[y_label]], path='../data/output/ddi_A_final.pkl')

                if current_ddi_rate <= args.target_ddi:
                    loss = 0.95 * loss_bce + 0.05 * loss_multi
                else:
                    # proportional control: beta decays from 1 toward 0 as
                    # the current DDI rate overshoots the target
                    beta = max(
                        0, 1 + (args.target_ddi - current_ddi_rate) / args.kp)
                    loss = beta * (0.95 * loss_bce +
                                   0.05 * loss_multi) + (1 - beta) * loss_ddi

                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        print()
        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic,
            time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['med'].append(avg_med)

        if epoch >= 5:
            print('ddi: {}, Med: {}, Ja: {}, F1: {}, PRAUC: {}'.format(
                np.mean(history['ddi_rate'][-5:]),
                np.mean(history['med'][-5:]), np.mean(history['ja'][-5:]),
                np.mean(history['avg_f1'][-5:]),
                np.mean(history['prauc'][-5:])))

        torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
            'Epoch_{}_TARGET_{:.2}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, args.target_ddi, ja, ddi_rate)), 'wb'))

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    dill.dump(
        history,
        open(
            os.path.join('saved', args.model_name,
                         'history_{}.pkl'.format(args.model_name)), 'wb'))
Example #15
def multilabel_margin_loss(input, target, *args, **kwargs):
    return F.multilabel_margin_loss(input.F, target, *args, **kwargs)
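This reads like a wrapper for feature-carrying tensors where input.F is the dense feature matrix (as in MinkowskiEngine's SparseTensor); a hedged stand-in usage, with the namespace object as a mock:

from types import SimpleNamespace
import torch

fake_sparse = SimpleNamespace(F=torch.randn(4, 10))  # mock for input.F
target = torch.full((4, 10), -1, dtype=torch.long)
target[:, 0] = torch.randint(0, 10, (4,))
print(multilabel_margin_loss(fake_sparse, target))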
Example #16
def main():
    if not os.path.exists(os.path.join("saved", model_name)):
        os.makedirs(os.path.join("saved", model_name))

    data_path = '../data/records_final.pkl'
    voc_path = '../data/voc_final.pkl'

    ehr_adj_path = '../data/ehr_adj_final.pkl'
    ddi_adj_path = '../data/ddi_A_final.pkl'
    device = torch.device('cuda:0')

    ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    EPOCH = 30
    LR = 0.001
    TEST = True
    Neg_Loss = False
    TARGET_DDI = 0.001

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = GAMENet(voc_size,
                    ehr_adj,
                    ddi_adj,
                    emb_dim=64,
                    device=device,
                    with_memory=True)
    if TEST:
        model.load_state_dict(
            torch.load(
                open(os.path.join("saved", model_name, resume_name), 'rb')))
    model.to(device=device)

    optimizer = Adam(list(model.parameters()), lr=LR)

    if TEST:
        eval(model, data_test, voc_size, 0)
    else:
        history = defaultdict(list)
        for epoch in range(EPOCH):
            loss_record1 = []
            start_time = time.time()
            model.train()
            for step, input in enumerate(data_train):
                input1_hidden, input2_hidden, target_hidden = None, None, None
                loss = 0
                prev_target = None
                for adm in input:
                    loss1_target = np.zeros((1, voc_size[2]))
                    loss1_target[:, adm[2]] = 1

                    loss3_target = np.full((1, voc_size[2]), -1)
                    for idx, item in enumerate(adm[2]):
                        loss3_target[0][idx] = item

                    target_output1, [
                        input1_hidden, input2_hidden, target_hidden
                    ], batch_neg_loss = model(
                        adm, prev_target,
                        [input1_hidden, input2_hidden, target_hidden])
                    prev_target = adm[2]

                    loss1 = F.binary_cross_entropy_with_logits(
                        target_output1,
                        torch.FloatTensor(loss1_target).to(device))

                    # loss = 9*loss1/10 + loss2/10
                    loss3 = F.multilabel_margin_loss(
                        F.sigmoid(target_output1),
                        torch.LongTensor(loss3_target).to(device))
                    # loss += loss1 + 0.1*loss3 + 0.01*batch_neg_loss
                    if Neg_Loss:
                        # neg_loss_weight = 0.0007 * (2 ** (epoch // 5))
                        # if neg_loss_weight > 0.01:
                        #     # decay stop:
                        #     neg_loss_weight = 0.01
                        # loss += 0.9*loss1 + 0.02*loss3 + neg_loss_weight*batch_neg_loss
                        loss = 0.001 * batch_neg_loss
                    else:
                        loss = 0.9 * loss1 + 0.03 * loss3

                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

                    loss_record1.append(loss.item())

                llprint('\rTrain--Epoch: %d, Step: %d/%d' %
                        (epoch, step, len(data_train)))

            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(
                model, data_eval, voc_size, epoch)
            history['ja'].append(ja)
            history['ddi_rate'].append(ddi_rate)
            history['avg_p'].append(avg_p)
            history['avg_r'].append(avg_r)
            history['avg_f1'].append(avg_f1)
            history['prauc'].append(prauc)

            end_time = time.time()
            elapsed_time = (end_time - start_time) / 60
            llprint(
                '\tEpoch: %d, Loss: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n'
                % (epoch, np.mean(loss_record1), elapsed_time, elapsed_time *
                   (EPOCH - epoch - 1) / 60))

            torch.save(
                model.state_dict(),
                open(
                    os.path.join(
                        'saved', model_name,
                        'Epoch_%d_JA_%.4f_DDI_%.4f.model' %
                        (epoch, ja, ddi_rate)), 'wb'))
            print('')

        dill.dump(history,
                  open(os.path.join('saved', model_name, 'history.pkl'), 'wb'))

        # test
        torch.save(
            model.state_dict(),
            open(os.path.join('saved', model_name, 'final.model'), 'wb'))
Example #17
def main():
    if not os.path.exists(os.path.join("saved", model_name)):
        os.makedirs(os.path.join("saved", model_name))

    data_path = '../data/records_final.pkl'
    voc_path = '../data/voc_final.pkl'

    ehr_adj_path = '../data/ehr_adj_final.pkl'
    ddi_adj_path = '../data/ddi_A_final.pkl'
    device = torch.device('cuda:0')

    ehr_adj = dill.load(open(ehr_adj_path, 'rb'))
    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))
    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    split_point = int(len(data) * 2 / 3)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    EPOCH = 40
    LR = 0.0002
    TEST = args.eval
    Neg_Loss = args.ddi
    DDI_IN_MEM = args.ddi
    TARGET_DDI = 0.05
    T = 0.5
    decay_weight = 0.85

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))
    model = GAMENet(voc_size,
                    ehr_adj,
                    ddi_adj,
                    emb_dim=64,
                    device=device,
                    ddi_in_memory=DDI_IN_MEM)
    if TEST:
        model.load_state_dict(torch.load(open(resume_name, 'rb')))
    model.to(device=device)

    print('parameters', get_n_params(model))
    optimizer = Adam(list(model.parameters()), lr=LR)

    if TEST:
        eval(model, data_test, voc_size, 0)
    else:
        history = defaultdict(list)
        best_epoch = 0
        best_ja = 0
        for epoch in range(EPOCH):
            loss_record1 = []
            start_time = time.time()
            model.train()
            prediction_loss_cnt = 0
            neg_loss_cnt = 0
            for step, input in enumerate(data_train):
                for idx, adm in enumerate(input):
                    seq_input = input[:idx + 1]
                    loss1_target = np.zeros((1, voc_size[2]))
                    loss1_target[:, adm[2]] = 1
                    loss3_target = np.full((1, voc_size[2]), -1)
                    for idx, item in enumerate(adm[2]):
                        loss3_target[0][idx] = item

                    target_output1, batch_neg_loss = model(seq_input)

                    loss1 = F.binary_cross_entropy_with_logits(
                        target_output1,
                        torch.FloatTensor(loss1_target).to(device))
                    loss3 = F.multilabel_margin_loss(
                        F.sigmoid(target_output1),
                        torch.LongTensor(loss3_target).to(device))
                    if Neg_Loss:
                        target_output1 = F.sigmoid(
                            target_output1).detach().cpu().numpy()[0]
                        target_output1[target_output1 >= 0.5] = 1
                        target_output1[target_output1 < 0.5] = 0
                        y_label = np.where(target_output1 == 1)[0]
                        current_ddi_rate = ddi_rate_score([[y_label]])
                        if current_ddi_rate <= TARGET_DDI:
                            loss = 0.9 * loss1 + 0.01 * loss3
                            prediction_loss_cnt += 1
                        else:
                            rnd = np.exp((TARGET_DDI - current_ddi_rate) / T)
                            if np.random.rand(1) < rnd:
                                loss = batch_neg_loss
                                neg_loss_cnt += 1
                            else:
                                loss = 0.9 * loss1 + 0.01 * loss3
                                prediction_loss_cnt += 1
                    else:
                        loss = 0.9 * loss1 + 0.01 * loss3

                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

                    loss_record1.append(loss.item())

                llprint(
                    '\rTrain--Epoch: %d, Step: %d/%d, L_p cnt: %d, L_neg cnt: %d'
                    % (epoch, step, len(data_train), prediction_loss_cnt,
                       neg_loss_cnt))
            # annealing
            T *= decay_weight

            ddi_rate, ja, prauc, avg_p, avg_r, avg_f1 = eval(
                model, data_eval, voc_size, epoch)

            history['ja'].append(ja)
            history['ddi_rate'].append(ddi_rate)
            history['avg_p'].append(avg_p)
            history['avg_r'].append(avg_r)
            history['avg_f1'].append(avg_f1)
            history['prauc'].append(prauc)

            end_time = time.time()
            elapsed_time = (end_time - start_time) / 60
            llprint(
                '\tEpoch: %d, Loss: %.4f, One Epoch Time: %.2fm, Appro Left Time: %.2fh\n'
                % (epoch, np.mean(loss_record1), elapsed_time, elapsed_time *
                   (EPOCH - epoch - 1) / 60))

            torch.save(
                model.state_dict(),
                open(
                    os.path.join(
                        'saved', model_name,
                        'Epoch_%d_JA_%.4f_DDI_%.4f.model' %
                        (epoch, ja, ddi_rate)), 'wb'))
            print('')
            if epoch != 0 and best_ja < ja:
                best_epoch = epoch
                best_ja = ja

        dill.dump(history,
                  open(os.path.join('saved', model_name, 'history.pkl'), 'wb'))

        # test
        torch.save(
            model.state_dict(),
            open(os.path.join('saved', model_name, 'final.model'), 'wb'))

        print('best_epoch:', best_epoch)