Example #1
import os

import torch


# BERT is the project's own wrapper module (defined elsewhere in this repo);
# a sketch of the assumed interface is given after Example #2.
def run_evaluation_bert(args, checkpoint, test_loader, vocab_size):
    device = args.device
    model = BERT().to(device)
    # model = nn.DataParallel(model)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # inference mode: disables dropout

    answer_file = open(os.path.join(args.result_path, 'answer.txt'), "w")
    # Raw logits are also written out so several runs can be ensembled later.
    logit_file = open(os.path.join(args.result_path, 'logit.txt'), "w")
    with torch.no_grad():  # no gradients are needed at inference time
        for i, batch in enumerate(test_loader):
            text, context = batch.text, batch.context  # context is unused here
            text = text.type(torch.LongTensor).to(device)
            output = model.run_eval(text)
            pred = torch.argmax(output, 1).tolist()
            assert len(pred) == 1, "evaluation expects batch size 1"
            if pred[0] == 1:
                label = 'SARCASM'
            elif pred[0] == 0:
                label = 'NOT_SARCASM'
            else:
                raise ValueError("Unexpected prediction: {}".format(pred[0]))
            answer_file.write("twitter_{},{}\n".format(i + 1, label))
            logit_file.write("{},{}\n".format(output[0][0], output[0][1]))
    answer_file.close()
    logit_file.close()
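A minimal invocation sketch, assuming a checkpoint written by run_training_bert in Example #2. SimpleNamespace stands in for the real argparse namespace, and build_test_loader is a hypothetical placeholder for whatever constructs the torchtext-style iterator in this project:

from types import SimpleNamespace

import torch

# Hypothetical driver code; only the attributes read by run_evaluation_bert
# (device, result_path) need to exist on args.
checkpoint = torch.load('checkpoints/bert.pt', map_location='cpu')
args = SimpleNamespace(
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    result_path='results')
test_loader = build_test_loader(batch_size=1)  # placeholder; batch size 1 (see assert)
run_evaluation_bert(args, checkpoint, test_loader,
                    vocab_size=checkpoint['vocab_size'])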
Example #2
import os

import numpy as np
import sklearn.metrics
import torch
import torch.optim as optim


def run_training_bert(args, dataset, train_loader, val_loader, vocab_size):
    checkpoint_path = os.path.join(args.checkpoint_path, args.checkpoint)
    device = torch.device(("cuda:" + args.device)
                          if torch.cuda.is_available() else "cpu")

    model = BERT().to(device)

    # The loss is computed inside the model's forward pass, so no separate
    # criterion is needed here.
    # criterion = nn.BCEWithLogitsLoss()
    # Set up the Adam optimizer with a small weight decay for regularization.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    model.train()  # turn on training mode
    # Training Loop
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(args.epochs):
        # For each batch in the dataloader
        losses = []
        running_corrects = 0
        for i, batch in enumerate(train_loader):
            # Format batch: the model consumes the text and gold label;
            # context is carried along but unused here.
            text, context, label = batch.text, batch.context, batch.label
            label = label.type(torch.LongTensor).to(device)
            text = text.type(torch.LongTensor).to(device)

            # The model returns a (loss, logits) pair; only the loss is
            # needed for the backward pass.
            loss, _ = model(text, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
        epoch_loss = sum(losses) / len(losses)
        print('Epoch: {}, Training Loss: {:.4f}'.format(epoch, epoch_loss))
        # Save a checkpoint after every epoch (`epoch % 1 == 0` is always
        # true; the modulus is kept so the save interval is easy to change).
        if epoch % 1 == 0 or epoch == args.epochs - 1:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'vocab_size': vocab_size,
                    'args': vars(args)
                }, checkpoint_path)
            if args.eval:
                model.eval()  # switch to inference mode for validation
                with torch.no_grad():
                    preds = []
                    labels = []
                    eval_losses = []
                    # Fall back to the training set when no validation
                    # split is available.
                    eval_loader = (val_loader
                                   if val_loader is not None else train_loader)
                    for i, batch in enumerate(eval_loader):
                        text, context, label = batch.text, batch.context, batch.label
                        label = label.type(torch.LongTensor).to(device)
                        text = text.type(torch.LongTensor).to(device)
                        loss, output = model(text, label)
                        pred = torch.argmax(output, 1).tolist()
                        preds.extend(pred)
                        labels.extend(label.tolist())
                        eval_losses.append(loss.item())
                    print("{} Precision: {}, Recall: {}, F1: {}, Loss: {}".
                          format(
                              "Train" if val_loader is None else "Valid",
                              sklearn.metrics.precision_score(
                                  np.array(labels).astype('int32'),
                                  np.array(preds)),
                              sklearn.metrics.recall_score(
                                  np.array(labels).astype('int32'),
                                  np.array(preds)),
                              sklearn.metrics.f1_score(
                                  np.array(labels).astype('int32'),
                                  np.array(preds)), np.average(eval_losses)))
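Both examples assume a project-specific BERT wrapper whose forward(text, label) returns a (loss, logits) pair and whose run_eval(text) returns logits alone. The class below is a minimal sketch of that assumed interface, built on the Hugging Face transformers library purely for illustration; the real module in the repo may differ.

import torch.nn as nn
from transformers import BertModel


class BERT(nn.Module):
    """Minimal sketch of the interface the two functions above rely on."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.encoder = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, text, label):
        # text: (batch, seq_len) token ids; label: (batch,) class ids.
        hidden = self.encoder(text).last_hidden_state[:, 0]  # [CLS] vector
        logits = self.classifier(hidden)
        return self.loss_fn(logits, label), logits

    def run_eval(self, text):
        # Inference-only path used by run_evaluation_bert.
        hidden = self.encoder(text).last_hidden_state[:, 0]
        return self.classifier(hidden)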