Example #1
import torch


def run_evaluation_bert(args, checkpoint, test_loader, vocab_size):
    device = args.device
    model = BERT().to(device)
    # model = nn.DataParallel(model)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    answer_file = open(args.result_path + '/answer.txt', "w")
    # Logits are written alongside the labels for later ensembling
    logit_file = open(args.result_path + '/logit.txt', "w")
    for i, batch in enumerate(test_loader):
        text, context = batch.text, batch.context  # context is unused during evaluation
        text = text.type(torch.LongTensor).to(device)
        with torch.no_grad():  # no gradients needed at evaluation time
            output = model.run_eval(text)
        pred = torch.argmax(output, 1).tolist()
        assert len(pred) == 1  # evaluation assumes batch_size == 1
        if pred[0] == 1:
            label = 'SARCASM'
        elif pred[0] == 0:
            label = 'NOT_SARCASM'
        else:
            raise ValueError("unexpected prediction: {}".format(pred[0]))
        answer_file.write("twitter_{},{}".format(i+1, label))
        answer_file.write('\n')
        logit_file.write("{},{}".format(output[0][0], output[0][1]))
        logit_file.write("\n")
    answer_file.close()
    logit_file.close()
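A minimal driver for this helper, as a hedged sketch: the checkpoint path and the already-prepared args and test_loader are assumptions, not part of the original example.

import torch

# Hypothetical invocation; 'checkpoints/best.pt' is a placeholder path.
checkpoint = torch.load('checkpoints/best.pt', map_location=args.device)
run_evaluation_bert(args, checkpoint, test_loader, vocab_size=None)  # vocab_size is unused above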
Example #2
import os

import torch
import torch.nn as nn
import torch.optim as optim


def main(args):
    train_loader, test_loader = load_data(args)

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')

    args.vocab_len = len(args.vocab['stoi'])

    model = BERT(args.vocab_len, args.max_len, args.heads, args.embedding_dim,
                 args.N)
    if args.cuda:
        model = model.cuda()

    if args.task:
        # Downstream fine-tuning: short schedule with a small learning rate
        print('Start Down Stream Task')
        args.epochs = 3
        args.lr = 3e-5

        state_dict = torch.load(args.checkpoints)
        model.load_state_dict(state_dict['model_state_dict'])

        criterion = {'mlm': None, 'nsp': nn.CrossEntropyLoss()}  # no MLM loss on the downstream task

        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

        for epoch in range(1, args.epochs + 1):
            train_mlm_loss, train_nsp_loss, train_loss, train_mlm_acc, train_nsp_acc = _train(
                epoch, train_loader, model, optimizer, criterion, args)
            test_mlm_loss, test_nsp_loss, test_loss, test_mlm_acc, test_nsp_acc = _eval(
                epoch, test_loader, model, criterion, args)
            save_checkpoint(model, optimizer, args, epoch)
    else:
        print('Start Pre-training')
        criterion = {
            'mlm': nn.CrossEntropyLoss(ignore_index=0),  # skip index 0 (padding)
            'nsp': nn.CrossEntropyLoss()
        }
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

        for epoch in range(1, args.epochs + 1):
            train_mlm_loss, train_nsp_loss, train_loss, train_mlm_acc, train_nsp_acc = _train(
                epoch, train_loader, model, optimizer, criterion, args)
            test_mlm_loss, test_nsp_loss, test_loss, test_mlm_acc, test_nsp_acc = _eval(
                epoch, test_loader, model, criterion, args)
            save_checkpoint(model, optimizer, args, epoch)
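save_checkpoint is not defined in this snippet; a minimal sketch consistent with the 'model_state_dict' and 'optimizer_state_dict' keys loaded in the other examples might look as follows. The file layout under checkpoints/ is an assumption.

import os

import torch


def save_checkpoint(model, optimizer, args, epoch):
    # Persist model and optimizer state under the keys the loading code expects.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join('checkpoints', 'epoch_{}.pt'.format(epoch)))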
Example #3
    # Load the dataset and build the vocabulary from the training split

    train_dataset = MyDataset(args.train_file)
    word2vec = train_dataset.symbol2id
    # Reuse the training vocabulary so test-set ids stay consistent
    test_dataset = MyDataset(args.test_file,
                             vocab=(train_dataset.symbol2id,
                                    train_dataset.id2sybmol))

    test_loader = DataLoader(test_dataset,
                             batch_size=hyperparams['batch_size'])
    num_tokens = len(train_dataset.id2sybmol)
    print('num tokens', num_tokens)
    print('size', test_dataset.seq.size())

    model = BERT(num_tokens).to(device)

    if args.load:
        # print("Model's state_dict:")
        # for param_tensor in torch.load('./model.pt', map_location=torch.device(device)):
        #     print(param_tensor, "\t", model.state_dict()[param_tensor].size())
        model.load_state_dict(
            torch.load('./model.pt', map_location=torch.device(device)))
    if args.train:
        train(model, train_dataset, experiment, hyperparams)
    if args.save:
        torch.save(model.state_dict(), './model.pt')
    if args.test:
        test(model, test_loader, experiment, hyperparams)
    if args.analysis:
        embedding_analysis(model, experiment, train_dataset, test_dataset)
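The args object above carries train_file, test_file, and the boolean mode flags; a hedged argparse sketch consistent with those attributes (the flag names and help strings are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('train_file', help='path to the training corpus')
parser.add_argument('test_file', help='path to the test corpus')
parser.add_argument('--load', action='store_true', help='load ./model.pt before running')
parser.add_argument('--train', action='store_true', help='train the model')
parser.add_argument('--save', action='store_true', help='save weights to ./model.pt')
parser.add_argument('--test', action='store_true', help='evaluate on the test set')
parser.add_argument('--analysis', action='store_true', help='run the embedding analysis')
args = parser.parse_args()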
Example #4
        # Count element-wise matches between predictions and targets
        for j in range(len(preds)):
            total += 1
            if preds[j] == target[j]:
                total_correct += 1

    return total_correct / total


if __name__ == '__main__':
    mnli = BERTMNLI(TRAIN_DATA_DIR, bert_type=BERT_TYPE)
    match = BERTMNLI(MATCH_DATA_DIR, bert_type=BERT_TYPE)
    mismatch = BERTMNLI(MISMATCH_DATA_DIR, bert_type=BERT_TYPE)

    # Resume from a checkpoint fine-tuned on DNLI
    checkpoint = torch.load('storage/bert-base-dnli.pt')
    model = BERT(bert_type=BERT_TYPE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    criterion = nn.CrossEntropyLoss()

    best_acc = 0

    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = train(mnli, model, criterion, optimizer, device)
        match_acc = eval(match, model, device)
        mismatch_acc = eval(mismatch, model, device)
        print(f'Epoch {epoch}: train loss {train_loss:.4f}, '
              f'matched acc {match_acc:.4f}, mismatched acc {mismatch_acc:.4f}')
        best_acc = max(best_acc, match_acc)  # track the best matched accuracy
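The head of the accuracy helper whose tail opens this example is cut off; a sketch of a complete version follows, with the batch format and argmax decoding as assumptions (the original name eval is kept even though it shadows the builtin).

import torch


def eval(dataset, model, device):
    """Classification accuracy over (inputs, targets) batches."""
    model.eval()
    total, total_correct = 0, 0
    with torch.no_grad():
        for inputs, target in dataset:
            logits = model(inputs.to(device))
            preds = torch.argmax(logits, dim=-1).tolist()
            target = target.tolist()
            for j in range(len(preds)):
                total += 1
                if preds[j] == target[j]:
                    total_correct += 1

    return total_correct / total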