Пример #1
0
    .split(','))
# Format template for one progress row. Each entry below is one column of the
# log; the columns are separated by a single space in the final template.
log_template = ' '.join([
    '{:>6.0f}',
    '{:>5.0f}',
    '{:>9.0f}',
    '{:>5.0f}/{:<5.0f} {:>7.0f}%',
    '{:>8.6f}',
    '{:12.4f}',
    '{}',
])

# Path prefix (inside args.save_path) built for the best-model snapshot files.
best_snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')

# Print the column header once, before any training output.
print(header)

# Main training loop (excerpt): one pass over train_loader per epoch.
# `early_stop`, `best_dev_acc` and `iterations` are defined outside this
# excerpt, and the loop body continues past the last line shown here.
for epoch in range(1, args.epochs + 1):
    # Abort before starting a new epoch once the early-stop flag is set.
    if early_stop:
        print("Early stopping. Epoch: {}, Best Dev. Acc: {}".format(
            epoch, best_dev_acc))
        break

    # Running counters for this epoch's training accuracy.
    n_correct, n_total = 0, 0

    for batch_idx, batch in enumerate(train_loader.next_batch()):
        iterations += 1  # batch counter; not reset between epochs in this excerpt
        # NOTE(review): presumably model is a torch.nn.Module and optimizer a
        # torch.optim optimizer -- training mode is (re)enabled and gradients
        # are cleared on every batch.
        model.train()
        optimizer.zero_grad()

        pos_score, neg_score = model(batch)

        # Count the columns (dim 1) in which pos_score is greater than
        # neg_score in every row (dim 0): the per-column number of "wins"
        # must equal the number of rows for the column to count as correct.
        # NOTE(review): presumably dim 0 indexes negative samples and dim 1
        # indexes examples in the batch -- confirm against the model.
        n_correct += (torch.sum(torch.gt(pos_score, neg_score),
                                0).data == neg_score.size(0)).sum()
        n_total += pos_score.size(1)
        train_acc = 100. * n_correct / n_total  # running accuracy, in percent

        # One 1.0 per element of pos_score, moved to the GPU when args.cuda
        # is set.
        # NOTE(review): presumably the target tensor for a ranking loss that
        # is computed after this excerpt -- confirm.
        ones = torch.autograd.Variable(
            torch.ones(pos_score.size(0) * pos_score.size(1)))
        if args.cuda:
            ones = ones.cuda()
Пример #2
0
#-*- coding: utf-8 -*-

# Author: QuYingqi
# mail: [email protected]
# Created Time: 2017-11-09
import sys
import torch
from seqRankingLoader import SeqRankingLoader
import numpy as np
sys.path.append('../vocab')

# Manual check of SeqRankingLoader: decode and print the first few rows of the
# first batch (question words, positive relation, negative relations) so the
# output can be inspected by eye.

# NOTE: torch.load unpickles the stored objects -- only use it on trusted
# files such as these locally built vocabularies.
rel_vocab = torch.load('../vocab/vocab.rel.pt')
neg_range = len(rel_vocab)
print(neg_range)
word_vocab = torch.load('../vocab/vocab.word.pt')
loader = SeqRankingLoader('data/test.relation_ranking.pt', 5, neg_range, 0)
batch_size = loader.batch_size

# Upper bound on the number of rows printed from the first batch.
num_to_show = 5

for batch in loader.next_batch(False):
    seqs, pos_rel, neg_rel = batch
    # seqs and neg_rel are transposed so that their first axis is the one
    # iterated together with pos_rel below.
    # NOTE(review): presumably that axis is the example index within the
    # batch -- confirm against SeqRankingLoader.
    seqs_trans = np.transpose(seqs.cpu().data.numpy())
    pos_rel_trans = pos_rel.cpu().data.numpy()
    neg_rel_trans = np.transpose(neg_rel.cpu().data.numpy())

    # zip stops at the shortest input, so a batch with fewer than num_to_show
    # rows is printed in full instead of raising IndexError.
    rows = zip(seqs_trans[:num_to_show],
               pos_rel_trans[:num_to_show],
               neg_rel_trans[:num_to_show])
    for seq_ids, pos_id, neg_ids in rows:
        question = ' '.join(word_vocab.convert_to_word(seq_ids))
        print(question)
        pos_rel_ = rel_vocab.convert_to_word([pos_id])
        print(pos_rel_)
        neg_rel_ = ' | '.join(rel_vocab.convert_to_word(neg_ids))
        print(neg_rel_)

    # Only the first batch is inspected; stopping here avoids asking the
    # loader for a second batch that would be thrown away.
    break