Example #1
import os

import numpy as np
import torch
import torch.utils.data as Data
from torch.nn import CrossEntropyLoss

# Net and DiagramDataset are project-specific classes assumed to be
# importable from the surrounding repository.


def test_engine(state_dict=None, cfg=None, dataset=None):
    """Evaluate the model on the 'val' split (state_dict handed over from
    training) or on the 'test' split (weights loaded from a saved checkpoint)."""
    net = Net(cfg)
    net.eval()
    net.cuda()
    flag = 'val'
    if state_dict is None:
        flag = 'test'
        dataset = DiagramDataset(cfg)
        net.load_state_dict(
            torch.load(
                os.path.join(cfg.save_path, cfg.csdia_t, 'ban', cfg.version,
                             'epoch' + str(cfg.epoch) + '.pkl')))
    else:
        net.load_state_dict(state_dict)

    criterion = CrossEntropyLoss()
    print('Note: evaluating the model on the ' + flag + ' split')
    dataloader = Data.DataLoader(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    # running totals over the whole split
    loss_sum = 0
    correct_sum = 0
    que_sum = 0
    for step, (que_iter, dia_f_iter, opt_iter, dia_mat_iter, dia_nod_emb_iter,
               ans_iter) in enumerate(dataloader):
        que_iter = que_iter.cuda()
        dia_f_iter = dia_f_iter.cuda()
        opt_iter = opt_iter.cuda()
        dia_mat_iter = dia_mat_iter.cuda()
        dia_nod_emb_iter = dia_nod_emb_iter.cuda()
        ans_iter = ans_iter.cuda()

        with torch.no_grad():
            pred = net(que_iter, dia_f_iter, opt_iter, dia_mat_iter,
                       dia_nod_emb_iter, cfg)

            # predicted option index vs. index of the one-hot gold answer
            _, pred_ix = torch.max(pred, -1)
            _, label_ix = torch.max(ans_iter, -1)

            label_ix = label_ix.squeeze(-1)
            loss = criterion(pred, label_ix)
            # accumulate the loss as a Python float rather than a CUDA tensor
            loss_sum += loss.item()

            correct_sum += label_ix.eq(pred_ix).cpu().sum()
            que_sum += que_iter.shape[0]
    correct_sum = np.array(correct_sum, dtype='float32')
    overall_acc = correct_sum / que_sum

    print(40 * '*', '\n', 'loss: {}'.format(loss_sum / que_sum), '\n',
          'correct sum:', correct_sum, '\n', 'total questions:', que_sum, '\n',
          'overall accuracy:', overall_acc)
    print(40 * '*')
    print('\n')
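
A minimal call sketch for the evaluator above (hedged: cfg here stands for a hypothetical config object carrying only the fields test_engine reads, such as save_path, csdia_t, version, epoch, batch_size and num_workers; net and val_dataset would come from a training run like Example #2 below):

# Stand-alone evaluation: with no state_dict, the checkpoint named by cfg
# is loaded from disk and accuracy is reported on the test split.
test_engine(state_dict=None, cfg=cfg, dataset=None)

# Called from a training loop: pass the in-memory weights plus a validation
# dataset to score the 'val' split instead.
test_engine(net.state_dict(), cfg, val_dataset)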
Example #2
import os

import numpy as np
import torch
import torch.utils.data as Data
from torch.nn import CrossEntropyLoss
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, lr_scheduler

# Net, DiagramDataset and test_engine (Example #1) are assumed to be
# importable from the surrounding repository.


def run_diagram_net(cfg):
    """Train Net on the 'train' split and evaluate on 'val' after every epoch."""
    net = Net(cfg)

    net.cuda()
    net.train()
    criterion = CrossEntropyLoss()
    optimizer = Adam(net.parameters(),
                     lr=cfg.lr,
                     weight_decay=cfg.weight_decay)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    train_dataset = DiagramDataset(cfg)

    train_dataloader = Data.DataLoader(dataset=train_dataset,
                                       batch_size=cfg.batch_size,
                                       shuffle=True,
                                       num_workers=cfg.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
    # switch the config to the validation split for the held-out dataset
    cfg.splits = 'val'
    val_dataset = DiagramDataset(cfg)

    for epoch in range(cfg.max_epochs):
        loss_sum = 0
        correct_sum = 0
        que_sum = 0
        for step, (que_iter, dia_f_iter, opt_iter, dia_mat_iter,
                   dia_nod_emb_iter, ans_iter) in enumerate(train_dataloader):
            que_iter = que_iter.cuda()
            dia_f_iter = dia_f_iter.cuda()
            opt_iter = opt_iter.cuda()
            dia_mat_iter = dia_mat_iter.cuda()
            dia_nod_emb_iter = dia_nod_emb_iter.cuda()
            ans_iter = ans_iter.cuda()

            optimizer.zero_grad()
            # forward pass: question, diagram features, options, adjacency
            # matrices and node embeddings go through the network
            pred = net(que_iter, dia_f_iter, opt_iter, dia_mat_iter,
                       dia_nod_emb_iter, cfg)
            _, label_ix = torch.max(ans_iter, -1)
            _, pred_ix = torch.max(pred, -1)

            label_ix = label_ix.squeeze(-1)
            loss = criterion(pred, label_ix)
            # accumulate the loss as a Python float so the graph is not retained
            loss_sum += loss.item()

            loss.backward()
            # clip gradients to keep updates stable
            clip_grad_norm_(net.parameters(), 10)
            optimizer.step()

            correct_sum += label_ix.eq(pred_ix).cpu().sum()
            que_sum += que_iter.shape[0]
        correct_sum = np.array(correct_sum, dtype='float32')
        overall_acc = correct_sum / que_sum

        print(40 * '=', '\n', 'epoch:', epoch, '\n',
              'loss: {}'.format(loss_sum / que_sum), '\n', 'correct sum:',
              correct_sum, '\n', 'total questions:', que_sum, '\n',
              'accuracy:', overall_acc)
        print(40 * '=')
        print('\n')
        state = net.state_dict()
        # make sure the checkpoint directory exists, then save this epoch
        ckpt_dir = os.path.join(cfg.save_path, cfg.csdia_t, 'mcan',
                                'ckpt_' + cfg.version)
        os.makedirs(ckpt_dir, exist_ok=True)

        torch.save(state, os.path.join(ckpt_dir, 'epoch' + str(epoch) + '.pkl'))
        # evaluate on the validation split with the freshly saved weights
        test_engine(state, cfg, val_dataset)
        scheduler.step()
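
A sketch of how the trainer might be driven (hedged: the real project builds cfg from its own config machinery; SimpleNamespace and every field value below are placeholders, and Net(cfg) will read additional model fields not shown here):

from types import SimpleNamespace

cfg = SimpleNamespace(lr=1e-4, weight_decay=0, batch_size=32, num_workers=4,
                      max_epochs=20, save_path='./ckpt', csdia_t='kb',
                      version='v1', splits='train')
run_diagram_net(cfg)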