def test_engine(state_dict=None, cfg=None, dataset=None):
    """Evaluate the diagram-QA model on the 'val' or 'test' split.

    Args:
        state_dict: optional in-memory weights from a training run. When
            None, the checkpoint named by ``cfg`` is loaded from disk and
            the 'test' split is used; ``dataset`` is then rebuilt here.
        cfg: configuration object (paths, batch size, worker count, ...).
        dataset: dataset to evaluate on; only used when ``state_dict``
            is provided (validation during training).

    Prints the summed loss, correct count, question count, and overall
    accuracy over the whole split.
    """
    net = Net(cfg)
    net.eval()
    net.cuda()

    flag = 'val'
    if state_dict is None:  # fix: compare to None with `is`, not `==`
        flag = 'test'
        dataset = DiagramDataset(cfg)
        # NOTE(review): this load path uses 'ban' while run_diagram_net saves
        # under 'mcan/ckpt_<version>' — confirm which layout is intended.
        ckpt_file = os.path.join(
            cfg.save_path, cfg.csdia_t, 'ban', cfg.version,
            'epoch' + cfg.epoch + '.pkl')
        net.load_state_dict(torch.load(ckpt_file))
    else:
        net.load_state_dict(state_dict)

    criterion = CrossEntropyLoss()
    print('Note: begin to test the model using ' + flag + ' split')

    dataloader = Data.DataLoader(
        dataset=dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    loss_sum = 0.0
    correct_sum = 0
    que_sum = 0
    for step, (que_iter, dia_f_iter, opt_iter, dia_mat_iter,
               dia_nod_emb_iter, ans_iter) in enumerate(dataloader):
        que_iter = que_iter.cuda()
        dia_f_iter = dia_f_iter.cuda()
        opt_iter = opt_iter.cuda()
        dia_mat_iter = dia_mat_iter.cuda()
        dia_nod_emb_iter = dia_nod_emb_iter.cuda()
        ans_iter = ans_iter.cuda()

        with torch.no_grad():
            pred = net(que_iter, dia_f_iter, opt_iter, dia_mat_iter,
                       dia_nod_emb_iter, cfg)

        _, pred_ix = torch.max(pred, -1)
        # ans_iter is one-hot over the options; argmax recovers the label index.
        _, label_ix = torch.max(ans_iter, -1)
        label_ix = label_ix.squeeze(-1)

        loss = criterion(pred, label_ix)
        # .item() detaches to a Python float so we don't accumulate tensors.
        loss_sum += loss.item()
        correct_sum += label_ix.eq(pred_ix).sum().item()
        que_sum += que_iter.shape[0]

    overall_acc = correct_sum / que_sum
    print(40 * '*', '\n',
          'loss: {}'.format(loss_sum / que_sum), '\n',
          'correct sum:', correct_sum, '\n',
          'total questions:', que_sum, '\n',
          'overall accuracy:', overall_acc)
    print(40 * '*')
    print('\n')
def run_diagram_net(cfg):
    """Train the diagram-QA model, validating and checkpointing every epoch.

    Args:
        cfg: configuration object providing hyperparameters (``lr``,
            ``weight_decay``, ``batch_size``, ``max_epochs``, ...) and
            checkpoint path components (``save_path``, ``csdia_t``,
            ``version``). ``cfg.splits`` is mutated to 'val' to build the
            validation dataset, matching the original behavior.

    Side effects: creates the checkpoint directory, saves one ``.pkl``
    state dict per epoch, and calls ``test_engine`` on the val split.
    """
    net = Net(cfg)
    net.cuda()
    net.train()

    criterion = CrossEntropyLoss()
    optimizer = Adam(net.parameters(), lr=cfg.lr,
                     weight_decay=cfg.weight_decay)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    train_dataset = DiagramDataset(cfg)
    train_dataloader = Data.DataLoader(dataset=train_dataset,
                                       batch_size=cfg.batch_size,
                                       shuffle=True,
                                       num_workers=cfg.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
    cfg.splits = 'val'
    val_dataset = DiagramDataset(cfg)

    # Create the checkpoint directory once, up front. Replaces the original
    # per-epoch exists/mkdir/listdir chain, which was race-prone and repeated
    # every epoch; the resulting path is identical.
    ckpt_dir = os.path.join(cfg.save_path, cfg.csdia_t, 'mcan',
                            'ckpt_' + cfg.version)
    os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(cfg.max_epochs):
        loss_sum = 0.0
        correct_sum = 0
        que_sum = 0
        for step, (que_iter, dia_f_iter, opt_iter, dia_mat_iter,
                   dia_nod_emb_iter, ans_iter) in enumerate(train_dataloader):
            que_iter = que_iter.cuda()
            dia_f_iter = dia_f_iter.cuda()
            opt_iter = opt_iter.cuda()
            dia_mat_iter = dia_mat_iter.cuda()
            dia_nod_emb_iter = dia_nod_emb_iter.cuda()
            ans_iter = ans_iter.cuda()

            optimizer.zero_grad()
            pred = net(que_iter, dia_f_iter, opt_iter, dia_mat_iter,
                       dia_nod_emb_iter, cfg)

            # ans_iter is one-hot; CrossEntropyLoss needs class indices.
            _, label_ix = torch.max(ans_iter, -1)
            _, pred_ix = torch.max(pred, -1)
            label_ix = label_ix.squeeze(-1)

            loss = criterion(pred, label_ix)
            # .item() avoids holding a CUDA tensor per step just to log a sum.
            loss_sum += loss.item()
            loss.backward()
            # Clip to limit exploding gradients before the optimizer step.
            clip_grad_norm_(net.parameters(), 10)
            optimizer.step()

            correct_sum += label_ix.eq(pred_ix).sum().item()
            que_sum += que_iter.shape[0]

        overall_acc = correct_sum / que_sum
        print(40 * '=', '\n',
              'epoch:', epoch, '\n',
              'loss: {}'.format(loss_sum / que_sum), '\n',
              'correct sum:', correct_sum, '\n',
              'total questions:', que_sum, '\n',
              'accuracy:', overall_acc)
        print(40 * '=')
        print('\n')

        state = net.state_dict()
        torch.save(state,
                   os.path.join(ckpt_dir, 'epoch' + str(epoch) + '.pkl'))
        # Validate with the freshly-saved weights on the val split.
        test_engine(state, cfg, val_dataset)
        scheduler.step()