import datetime

import torch
from sklearn.metrics import accuracy_score

import model  # local module providing the ESIM implementation
# dev_evaluate / test_evaluate are assumed to be defined in this same module


def training(device, w2v_model, train_iter, dev_iter, test_iter, batch_size,
             num_epoch, lr, weight_decay, ckp, max_acc):

    embedding_matrix = w2v_model.wv.vectors
    # vocabulary size and embedding width, taken from the pretrained matrix
    input_size, hidden_size = embedding_matrix.shape
    loss_func = torch.nn.CrossEntropyLoss()
    net = model.ESIM(input_size, hidden_size, 4, embedding_matrix).to(device)
    #net.load_state_dict(torch.load(ckp, map_location='cpu'))
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)

    for epoch in range(num_epoch):
        net.train()
        train_l, n = 0.0, 0
        start = datetime.datetime.now()
        out_epoch, label_epoch = [], []
        for batch in train_iter:

            seq1 = batch.sentence1
            seq2 = batch.sentence2
            label = batch.label
            # mask out padding positions (the pad token id is assumed to be 1)
            mask1 = (seq1 == 1)
            mask2 = (seq2 == 1)
            out = net(seq1.to(device), seq2.to(device), mask1.to(device),
                      mask2.to(device))

            loss = loss_func(out, label.squeeze(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prediction = out.argmax(dim=1).data.cpu().numpy().tolist()
            label = label.view(1, -1).squeeze().data.cpu().numpy().tolist()

            out_epoch.extend(prediction)
            label_epoch.extend(label)

            train_l += loss.item()
            n += 1

        train_acc = accuracy_score(label_epoch, out_epoch)

        dev_loss, dev_acc, max_acc = dev_evaluate(device, net, dev_iter,
                                                  max_acc, ckp)
        test_loss, test_acc = test_evaluate(device, net, test_iter)
        end = datetime.datetime.now()

        print(
            'epoch %d, train_acc %f, dev_acc %f, test_acc %f, max_acc %f, time %s'
            % (epoch + 1, train_acc, dev_acc, test_acc, max_acc, end - start))
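
# Usage sketch (illustrative, not from the original snippet): the gensim model
# path and the train/dev/test iterators are assumptions.
if __name__ == '__main__':
    import gensim

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    w2v_model = gensim.models.Word2Vec.load('w2v.model')  # hypothetical path
    # train_iter / dev_iter / test_iter are assumed to come from a torchtext
    # pipeline like the one in Example #4 below.
    training(device, w2v_model, train_iter, dev_iter, test_iter,
             batch_size=64, num_epoch=10, lr=4e-4, weight_decay=1e-4,
             ckp='best.pth', max_acc=0.0)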
Example #2
import argparse

import torch

import model
import train
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define hyperparameters
parser = argparse.ArgumentParser()
# local run
#parser.add_argument('--data_path', type=str, default="/Users/ren/Desktop/nlp相关/实验1/aclImdb/")  # data path
parser.add_argument('--data_path', type=str, default="data/")  # data path
parser.add_argument('--embed_size', type=int, default=300)  # embedding width
parser.add_argument('--hidden_size', type=int, default=128)
parser.add_argument('--seq_len', type=int, default=20)  # sequence length; longer inputs are truncated, shorter ones padded
parser.add_argument('--batch_size', type=int, default=64)  # batch size
parser.add_argument('--bidirectional', type=bool, default=True)  # use a bidirectional RNN
parser.add_argument('--classification_num', type=int, default=4)  # number of classes
parser.add_argument('--lr', type=float, default=1e-3)  # learning rate
parser.add_argument('--dropout', type=float, default=0.5)  # dropout rate
parser.add_argument('--num_epochs', type=int, default=100)  # number of training epochs
parser.add_argument('--vocab_size', type=int, default=0)  # vocabulary size (set after the vocab is built)
parser.add_argument('--if_vail', type=bool, default=True)  # whether to evaluate on a validation set
parser.add_argument('--word2vec_path', type=str, default="/Users/ren/Desktop/nlp相关/glove_to_word2vec.txt")  # pretrained word-vector path
#parser.add_argument('--word2vec_path', type=str, default="/data/renhongjie/zouye1_new/data/glove_to_word2vec.txt")  # pretrained word-vector path
parser.add_argument('--save_path', type=str, default="best3.pth")  # checkpoint save path
parser.add_argument('--weight_decay', type=float, default=1e-4)  # weight decay
args = parser.parse_args()
if args.if_vail:
    train_iter, test_iter, vail_iter, weight = utils.data(args)
else:
    train_iter, test_iter, weight = utils.data(args)
net = model.ESIM(args, weight=weight)
if args.if_vail:
    train.train(args, device, net, train_iter, test_iter, vail_iter)
else:
    train.train(args, device, net, train_iter, test_iter, None)
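
# Note: argparse's type=bool converts via bool(str), so "--if_vail False" on
# the command line would still parse as True. A common fix (a sketch, not part
# of the original snippet) is an explicit converter:
#
#   def str2bool(v):
#       return str(v).lower() in ('true', '1', 'yes')
#
#   parser.add_argument('--if_vail', type=str2bool, default=True)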
Example #3
    args.vocab_size = len(text_field.vocab)
    args.target_size = len(label_field.vocab)
    args.weight_matrix = text_field.vocab.vectors
    print(label_field.vocab.itos)

    print("\nParameters:")
    for attr, value in sorted(args.__dict__.items()):
        print("\t{}={}".format(attr.upper(), value))

    train_iter = data.Iterator(dataset=train_data,
                               batch_size=args.batch_size,
                               shuffle=True)
    valid_iter = data.Iterator(dataset=valid_data,
                               batch_size=args.batch_size,
                               shuffle=False)
    test_iter = data.Iterator(dataset=test_data,
                              batch_size=args.batch_size,
                              shuffle=False)

    if args.train:
        print("Start training...")
        esim = model.ESIM(args)
        if args.cuda:
            esim = esim.cuda()
        train.train(train_iter, valid_iter, esim, args)
    else:
        print("\nStart predicting...")
        esim = torch.load("./model/v6.0/model_{}.pkl".format(5)).cuda()
        train.predict(test_iter, esim, args)
        print("Finished.")
Example #4
word2id = {word: idx
           for idx, word in enumerate(w2v_model.wv.index2word)}  # word -> id
id2word = {idx: word
           for idx, word in enumerate(w2v_model.wv.index2word)}  # id -> word
feature_pad = 0
label2id = {'neutral': 0, 'entailment': 1, 'contradiction': 2, '-': 3}
id2label = {idx: label for label, idx in label2id.items()}

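# Map each tokenized line to word ids; out-of-vocabulary tokens fall back to
# feature_pad (0), which doubles as the padding index.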
test_feature1 = [[word2id.get(word, feature_pad) for word in line]
                 for line in test_feature1_line]
test_feature2 = [[word2id.get(word, feature_pad) for word in line]
                 for line in test_feature2_line]

embedding_matrix = w2v_model.wv.vectors
input_size, hidden_size = embedding_matrix.shape
loss_func = torch.nn.CrossEntropyLoss()
net = model.ESIM(input_size, hidden_size, 4, embedding_matrix).to(device)

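# The sentences are already integer id sequences, so the fields skip vocab
# construction (use_vocab=False) and pad with feature_pad directly.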
sentence1_field = Field(sequential=True,
                        use_vocab=False,
                        batch_first=True,
                        fix_length=50,
                        pad_token=feature_pad)
sentence2_field = Field(sequential=True,
                        use_vocab=False,
                        batch_first=True,
                        fix_length=50,
                        pad_token=feature_pad)
fields = [('sentence1', sentence1_field), ('sentence2', sentence2_field)]
# Build the Iterator for the test set
test_examples = []
for index in range(len(test_feature1_line)):