import os

import numpy as np
import sklearn.metrics
import torch
import torch.optim as optim

# `BERT` is the project's model wrapper (defined/imported elsewhere in this repo);
# the interface sketch below shows the behaviour these functions assume from it.


def run_evaluation_bert(args, checkpoint, test_loader, vocab_size):
    # Build the device the same way as in training (args.device holds the CUDA index).
    device = torch.device("cuda:" + args.device if torch.cuda.is_available() else "cpu")
    model = BERT().to(device)
    # model = nn.DataParallel(model)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    answer_file = open(args.result_path + '/answer.txt', "w")
    # Raw logits are also dumped so predictions can be ensembled later.
    logit_file = open(args.result_path + '/logit.txt', "w")

    for i, batch in enumerate(test_loader):
        text, context = batch.text, batch.context
        text = text.type(torch.LongTensor).to(device)

        output = model.run_eval(text)
        pred = torch.argmax(output, 1).tolist()
        # Test-time inference assumes a batch size of 1 (one tweet per output line).
        assert len(pred) == 1

        if pred[0] == 1:
            label = 'SARCASM'
        elif pred[0] == 0:
            label = 'NOT_SARCASM'
        else:
            raise NotImplementedError("Strange pred.")

        answer_file.write("twitter_{},{}".format(i + 1, label))
        answer_file.write('\n')
        logit_file.write("{},{}".format(output[0][0].item(), output[0][1].item()))
        logit_file.write("\n")

    answer_file.close()
    logit_file.close()
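
# --- Interface sketch (not part of the original file) -----------------------
# The functions in this file call `BERT()` with two behaviours:
#   * forward(text, label) -> (loss, logits) during training/validation,
#   * run_eval(text)       -> logits during test-time inference.
# Below is a minimal sketch of such a wrapper, assuming a HuggingFace
# `transformers` BertForSequenceClassification backbone. The repo's actual
# BERT class may differ; the class name, backbone checkpoint, and
# num_labels=2 here are assumptions, not the original implementation.

import torch.nn as nn
from transformers import BertForSequenceClassification


class BERTSketch(nn.Module):
    """Hypothetical stand-in illustrating the interface assumed by this file."""

    def __init__(self):
        super().__init__()
        # Binary sarcasm classification: SARCASM (1) vs NOT_SARCASM (0).
        self.encoder = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=2)

    def forward(self, text, label):
        # `text` is a batch of token-id sequences, `label` a batch of 0/1 targets.
        out = self.encoder(input_ids=text, labels=label, return_dict=True)
        return out.loss, out.logits

    def run_eval(self, text):
        # Inference only: return the classification logits.
        with torch.no_grad():
            out = self.encoder(input_ids=text, return_dict=True)
        return out.logits
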
def run_training_bert(args, dataset, train_loader, val_loader, vocab_size):
    checkpoint_path = os.path.join(args.checkpoint_path, args.checkpoint)
    device = torch.device("cuda:" + args.device if torch.cuda.is_available() else "cpu")
    model = BERT().to(device)

    # The loss is computed inside the model's forward pass, so no separate criterion is needed.
    # criterion = nn.BCEWithLogitsLoss()

    # Set up the Adam optimizer with a small weight decay.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    model.train()  # turn on training mode

    # Training loop
    print("Starting Training Loop...")
    for epoch in range(args.epochs):
        losses = []
        running_corrects = 0
        # For each batch in the dataloader
        for i, batch in enumerate(train_loader):
            # Format the batch
            text, context, label = batch.text, batch.context, batch.label
            # print(text.tolist()[0])
            # print(label.tolist()[0])
            label = label.type(torch.LongTensor).to(device)
            text = text.type(torch.LongTensor).to(device)

            # The model returns (loss, logits) when labels are provided.
            output = model(text, label)
            loss, _ = output

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

        epoch_loss = sum(losses) / len(losses)
        print('Epoch: {}, Training Loss: {:.4f}'.format(epoch, epoch_loss))

        # Save a checkpoint after every epoch (and on the final epoch).
        if epoch % 1 == 0 or epoch == args.epochs - 1:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'vocab_size': vocab_size,
                    'args': vars(args)
                }, checkpoint_path)

        if args.eval:
            # Evaluate on the validation set, or on the training set if no
            # validation loader is given.
            model.eval()
            with torch.no_grad():
                preds = []
                labels = []
                eval_losses = []
                for i, batch in enumerate(val_loader if val_loader is not None else train_loader):
                    text, context, label = batch.text, batch.context, batch.label
                    label = label.type(torch.LongTensor).to(device)
                    text = text.type(torch.LongTensor).to(device)

                    output = model(text, label)
                    loss, output = output
                    pred = torch.argmax(output, 1).tolist()

                    preds.extend(pred)
                    labels.extend(label.tolist())
                    eval_losses.append(loss.item())

                print("{} Precision: {}, Recall: {}, F1: {}, Loss: {}".format(
                    "Train" if val_loader is None else "Valid",
                    sklearn.metrics.precision_score(np.array(labels).astype('int32'), np.array(preds)),
                    sklearn.metrics.recall_score(np.array(labels).astype('int32'), np.array(preds)),
                    sklearn.metrics.f1_score(np.array(labels).astype('int32'), np.array(preds)),
                    np.average(eval_losses)))
            model.train()  # switch back to training mode for the next epoch
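
# --- Usage sketch (not part of the original file) ----------------------------
# A minimal, hypothetical entry point showing the argument fields these
# functions read (args.device, args.lr, args.epochs, args.checkpoint_path,
# args.checkpoint, args.result_path, args.eval). Flag names and defaults are
# assumptions, and building `dataset`, the iterators, and `vocab_size` is
# project-specific and omitted here.

import argparse


def build_arg_parser():
    parser = argparse.ArgumentParser(description="BERT sarcasm detection")
    parser.add_argument('--device', type=str, default='0',
                        help='CUDA device index, used as "cuda:<device>"')
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--checkpoint_path', type=str, default='checkpoints')
    parser.add_argument('--checkpoint', type=str, default='bert.pt')
    parser.add_argument('--result_path', type=str, default='results')
    parser.add_argument('--eval', action='store_true',
                        help='run validation after every training epoch')
    return parser


# Example:
#   args = build_arg_parser().parse_args()
#   run_training_bert(args, dataset, train_loader, val_loader, vocab_size)
#   run_evaluation_bert(args, torch.load(os.path.join(args.checkpoint_path, args.checkpoint)),
#                       test_loader, vocab_size)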