import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader as DL  # `DL` is assumed to alias DataLoader
import matplotlib.pyplot as plt


def test(dset_test, batch_size, baseGRU):
    test_loader = DL(dset_test, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
    accs = []
    baseGRU.eval()
    with torch.no_grad():  # no gradients are needed during evaluation
        for batch_idx, data in enumerate(test_loader):
            contexts, questions, answers = data
            b_size = len(contexts)

            # Flatten each sample and make the tensors sequence-first: (seq_len, batch).
            c = contexts.view(b_size, -1).long()
            q = questions.view(b_size, -1).long()
            c.transpose_(0, 1)
            q.transpose_(0, 1)

            hidden1, hidden2 = baseGRU.initHidden(b_size)
            out = baseGRU(c, q, hidden1, hidden2)

            # The top-scoring vocabulary index is the predicted answer.
            topv, topi = out.data.topk(1)
            topi = topi.view(1, -1).squeeze(0)
            acc = torch.mean((topi == answers).float())
            accs.append(acc.item())

            if batch_idx % 10 == 0:
                visualizeSample(dset_test, contexts[0], questions[0], answers[0], topi[0])
            #if batch_idx == 2: break

    accuracy = sum(accs) / len(accs)
    print("test accuracy is: %f" % accuracy)
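# `pad_collate` is passed to every DataLoader in this file but is not defined here.
# A minimal sketch of what it is assumed to do, for reference only: each sample is
# assumed to be (context_ids, question_ids, answer_id), and sequences are right-padded
# with index 0 (assumed to be <PAD>) to the longest one in the batch. The real collate
# function may pad differently (e.g. per-sentence for the DMN input module).
def pad_collate(batch):
    contexts, questions, answers = zip(*batch)
    max_c = max(len(c) for c in contexts)
    max_q = max(len(q) for q in questions)
    c_pad = torch.zeros(len(batch), max_c, dtype=torch.long)  # 0 assumed to be <PAD>
    q_pad = torch.zeros(len(batch), max_q, dtype=torch.long)
    for i, (c, q) in enumerate(zip(contexts, questions)):
        c_pad[i, :len(c)] = torch.as_tensor(c, dtype=torch.long)
        q_pad[i, :len(q)] = torch.as_tensor(q, dtype=torch.long)
    return c_pad, q_pad, torch.as_tensor(answers, dtype=torch.long)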
def train(dset_train, batch_size, epoch, dmn):
    optim = torch.optim.Adam(dmn.parameters())
    losses = []
    accs = []
    for e in range(epoch):
        train_loader = DL(dset_train, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
        dmn.train()
        print('epoch %d is in progress' % e)
        for batch_idx, data in enumerate(train_loader):
            contexts, questions, answers = data
            optim.zero_grad()

            # Move the batch to the GPU; the model is assumed to be on the GPU as well.
            c = Variable(contexts.long().cuda())
            q = Variable(questions.long().cuda())
            answers = Variable(answers.cuda())

            loss, acc, _ = dmn.get_loss(c, q, answers)
            losses.append(loss.item())
            accs.append(acc.item())
            loss.backward()
            optim.step()
            #if batch_idx == 50: break

        # Plot the running loss/accuracy curves every 16 epochs.
        if e % 16 == 0:
            plt.figure()
            plt.plot(losses)
            plt.title('training loss')
            plt.show()
            plt.figure()
            plt.plot(accs)
            plt.title('training accuracy')
            plt.show()
def test(dset_test, batch_size, dmn):
    test_loader = DL(dset_test, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
    accs = []
    dmn.eval()
    with torch.no_grad():  # no gradients are needed during evaluation
        for batch_idx, data in enumerate(test_loader):
            contexts, questions, answers = data
            # The inputs must live on the same device as the model; train() above
            # moves batches with .cuda(), so the same is done here.
            c = Variable(contexts.long().cuda())
            q = Variable(questions.long().cuda())
            answers = Variable(answers.cuda())

            _, acc, topi = dmn.get_loss(c, q, answers)
            accs.append(acc.item())

            if batch_idx % 10 == 0:
                visualizeSample(dset_test, contexts[0], questions[0], answers[0], topi[0])
            #if batch_idx == 2: break

    accuracy = sum(accs) / len(accs)
    print("test accuracy is: %f" % accuracy)
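# `visualizeSample` is also external to this file. A minimal sketch, assuming the
# dataset exposes dset.QA.IVOCAB, an id-to-word mapping inverting dset.QA.VOCAB
# (the IVOCAB attribute name is an assumption) and that index 0 is padding.
def visualizeSample(dset, context, question, answer, prediction):
    def decode(ids):
        return ' '.join(dset.QA.IVOCAB[int(i)] for i in ids.view(-1) if int(i) != 0)
    print('context   :', decode(context))
    print('question  :', decode(question))
    print('answer    :', dset.QA.IVOCAB[int(answer)])
    print('prediction:', dset.QA.IVOCAB[int(prediction)])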
def train(dset_train, batch_size, epoch, baseGRU):
    criterion = nn.NLLLoss()
    optim = torch.optim.Adam(baseGRU.parameters())
    losses = []
    accs = []
    for e in range(epoch):
        train_loader = DL(dset_train, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
        print('epoch %d is in progress' % e)
        for batch_idx, data in enumerate(train_loader):
            contexts, questions, answers = data
            b_size = len(contexts)
            optim.zero_grad()

            # Flatten each sample and make the tensors sequence-first: (seq_len, batch).
            c = contexts.view(b_size, -1).long()
            q = questions.view(b_size, -1).long()
            c.transpose_(0, 1)
            q.transpose_(0, 1)

            hidden1, hidden2 = baseGRU.initHidden(b_size)
            out = baseGRU(c, q, hidden1, hidden2)

            # Track accuracy from the top-scoring predictions (detached from the graph).
            topv, topi = out.data.topk(1)
            topi = topi.view(1, -1).squeeze(0)
            acc = torch.mean((topi == answers).float())

            loss = criterion(out, answers)
            losses.append(loss.item())
            accs.append(acc.item())
            loss.backward()
            optim.step()
            #if batch_idx == 50: break

        # Plot the running loss/accuracy curves every 16 epochs.
        if e % 16 == 0:
            plt.figure()
            plt.plot(losses)
            plt.title('training loss')
            plt.show()
            plt.figure()
            plt.plot(accs)
            plt.title('training accuracy')
            plt.show()
def train(dset, batch_size, epochs, dmn):
    early_stopping_cnt = 0
    early_stopping_flag = False
    best_acc = 0
    best_state = dmn.state_dict()  # initialized so load_state_dict below always has a state
    optim = torch.optim.Adam(dmn.parameters())

    for epoch in range(epochs):
        dset.set_mode('train')
        train_loader = DL(dset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
        dmn.train()

        if not early_stopping_flag:
            total_acc = 0
            cnt = 0
            for batch_idx, data in enumerate(train_loader):
                optim.zero_grad()
                contexts, questions, answers = data
                b_size = contexts.size()[0]
                contexts = Variable(contexts.long())
                questions = Variable(questions.long())
                answers = Variable(answers)

                loss, acc, _ = dmn.get_loss(contexts, questions, answers)
                loss.backward()
                total_acc += acc * b_size
                cnt += b_size

                if batch_idx % 20 == 0:
                    print('[Epoch %d] [Training] loss : %f, acc : %f, batch_idx : %d'
                          % (epoch, loss.item(), total_acc / cnt, batch_idx))
                optim.step()

            # Validate after every training epoch.
            dset.set_mode('valid')
            valid_loader = DL(dset, batch_size=batch_size, shuffle=False, collate_fn=pad_collate)
            dmn.eval()

            total_acc = 0
            cnt = 0
            for batch_idx, data in enumerate(valid_loader):
                contexts, questions, answers = data
                b_size = contexts.size()[0]
                contexts = Variable(contexts.long())
                questions = Variable(questions.long())
                answers = Variable(answers)

                _, acc, _ = dmn.get_loss(contexts, questions, answers)
                # Weight by the actual batch size: the last batch may be smaller.
                total_acc += acc * b_size
                cnt += b_size
            total_acc = total_acc / cnt

            # Keep the best model seen so far; stop if validation accuracy has not
            # improved for more than 20 consecutive epochs.
            if total_acc > best_acc:
                best_acc = total_acc
                best_state = dmn.state_dict()
                early_stopping_cnt = 0
            else:
                early_stopping_cnt += 1
                if early_stopping_cnt > 20:
                    early_stopping_flag = True

            print('[Epoch %d] [Validate] Accuracy : %f' % (epoch, total_acc))
            if total_acc == 1.0:
                break
        else:
            print('Early Stopping at Epoch %d, Valid Accuracy : %f' % (epoch, best_acc))
            break

    dmn.load_state_dict(best_state)
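# A sketch of how the DMN pieces above are assumed to fit together. `BabiDataset`
# and `DMN` are placeholders for the dataset and model classes defined elsewhere;
# their constructor signatures are assumptions, not confirmed by this file. Note
# that `train` and `test` are each defined more than once above, so this assumes
# the DMN variants (the early-stopping `train` and the `test` taking `dmn`) are
# the ones in scope.
if __name__ == '__main__':
    dset = BabiDataset(task_id=1)                              # hypothetical constructor
    dmn = DMN(hidden_size=80, vocab_size=len(dset.QA.VOCAB))   # hypothetical signature
    train(dset, batch_size=100, epochs=256, dmn=dmn)
    dset.set_mode('test')                                      # assumes a 'test' mode exists
    test(dset, batch_size=100, dmn=dmn)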