Example #1
import argparse

import torch
from torch import cuda

import dialog0.Seq2Seq.IO as IO
import dialog0.Utils as utils
# InferDataset, Infer, create_seq2seq_tag_model, create_tag_sampler and
# inference_file are project-specific helpers whose import paths are not
# shown in this snippet.


def load_fields(vocab_path):
    fields = IO.load_fields(torch.load(vocab_path))
    # fields = dict([(k, f) for (k, f) in fields.items()
    #               if k in train.examples[0].__dict__])
    # train.fields = fields

    print(' * vocabulary size. source = %d; target = %d; tag = %d' %
          (len(fields['src'].vocab), len(fields['tgt'].vocab), len(fields['tag'].vocab)))

    return fields


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-test_data", type=str)
    parser.add_argument("-test_out", type=str)
    parser.add_argument("-seq2seq", type=str)
    parser.add_argument("-sampler", type=str)
    parser.add_argument("-vocab", type=str)
    parser.add_argument("-config", type=str)
    parser.add_argument("-dump_beam", default="", type=str)
    parser.add_argument('-gpuid', default=[], nargs='+', type=int)
    parser.add_argument("-beam_size", type=int)
    parser.add_argument("-topk_tag", type=int)
    parser.add_argument("-decode_max_length", type=int)
    parser.add_argument("-num_cluster", type=int)

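    # note: despite its name, -tensorboard is opened later as a plain text
    # output file, not as a TensorBoard writer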
    parser.add_argument("-tensorboard", type=str, default="")
    args = parser.parse_args()
    config = utils.load_config(args.config)

    use_cuda = False
    if args.gpuid:
        cuda.set_device(args.gpuid[0])
        use_cuda = True
    fields = IO.load_fields(torch.load(args.vocab))

    infer_dataset = InferDataset(data_path=args.test_data,
                                 fields=[('src', fields["src"])])

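    # decode one example per batch, keeping the test file's original order
    # (no sorting or shuffling across the dataset)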
    data_iter = IO.OrderedIterator(dataset=infer_dataset,
                                   device=args.gpuid[0] if args.gpuid else -1,
                                   batch_size=1,
                                   train=False,
                                   sort=False,
                                   sort_within_batch=True,
                                   shuffle=False)
    # Build model.

    seq2seq_model = create_seq2seq_tag_model(config, fields)
    sampler_model = create_tag_sampler(config, fields)

    print('Loading parameters ...')
    if args.seq2seq and args.sampler:
        seq2seq_model.load_checkpoint(args.seq2seq)
        sampler_model.load_checkpoint(args.sampler)
    if use_cuda:
        seq2seq_model = seq2seq_model.cuda()
        sampler_model = sampler_model.cuda()

    infer = Infer(seq2seq_model=seq2seq_model,
                  sampler_model=sampler_model,
                  fields=fields,
                  beam_size=args.beam_size,
                  n_best=1,
                  max_length=args.decode_max_length,
                  global_scorer=None,
                  cuda=use_cuda,
                  beam_trace=bool(args.dump_beam))
    writer = None
    if args.tensorboard:
        writer = open(args.tensorboard, 'w', encoding='utf8')

    inference_file(infer, data_iter, args.test_out, fields, args.topk_tag,
                   args.num_cluster, use_cuda, writer)
    if writer is not None:
        writer.close()


if __name__ == '__main__':
    main()
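A minimal invocation sketch for the inference script above, assuming it is saved as infer.py; every path and hyperparameter value below is a placeholder rather than a setting taken from the original project.

import subprocess

subprocess.run([
    "python", "infer.py",
    "-test_data", "data/test.txt",
    "-test_out", "out/test.pred.txt",
    "-seq2seq", "checkpoints/seq2seq.pt",
    "-sampler", "checkpoints/sampler.pt",
    "-vocab", "data/demo.vocab.pkl",
    "-config", "config.yml",
    "-gpuid", "0",
    "-beam_size", "5",
    "-topk_tag", "5",
    "-decode_max_length", "30",
    "-num_cluster", "4",
], check=True)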
Example #3
def make_train_data_iter(train_data, config):
    # `IO` and `args` are assumed to be module-level here (imports and argparse
    # setup as in the other examples); fall back to CPU (device -1) when no
    # GPU id was given.
    return IO.OrderedIterator(
        dataset=train_data,
        batch_size=config['Seq2Seq']['Trainer']['batch_size'],
        device=args.gpuid[0] if args.gpuid else -1,
        repeat=False)
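A hedged usage sketch for make_train_data_iter, assuming train_data is a SeqDataset built as in Example #4 and that IO.OrderedIterator yields torchtext-style batches with one attribute per field; num_epochs and the training-step placeholders are not from the original project.

train_iter = make_train_data_iter(train_data, config)
num_epochs = 10  # placeholder value

for epoch in range(num_epochs):
    for batch in train_iter:
        # torchtext-style batches expose one attribute per declared field
        src, tgt, tag = batch.src, batch.tgt, batch.tag
        # forward pass, loss computation and optimizer step would go here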
Example #4
import argparse

import torch

import dialog0.Seq2Seq.IO as IO
import dialog0.Utils as utils
from dialog0.Seq2Seq.Dataset import SeqDataset
from dialog0.Seq2SeqWithRL.Dataset import RLDataset  # imported but not used below

parser = argparse.ArgumentParser()
parser.add_argument('-train_data', type=str)
parser.add_argument('-save_data', type=str)
parser.add_argument('-config', type=str)
args = parser.parse_args()

config = utils.load_config(args.config)

if config['Misc']['random_seed'] > 0:
    torch.manual_seed(config['Misc']['random_seed'])

fields = IO.get_fields()
print("Building Training...")
train = SeqDataset(data_path=args.train_data,
                   fields=[('src', fields["src"]), ('tgt', fields["tgt"]),
                           ('tag', fields["tag"])])
print("Building Vocab...")
IO.build_vocab(train, config)

print("Saving fields")
torch.save(IO.save_vocab(fields), args.save_data + '.vocab.pkl')
# train.fields = []
# torch.save(train, open(args.save_data+'.train.pkl', 'wb'))
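To check the artifact produced above, the saved vocabulary file can be reloaded the same way Example #1 does; the path expression simply mirrors the save call and is otherwise a placeholder.

fields = IO.load_fields(torch.load(args.save_data + '.vocab.pkl'))
print(' * vocabulary size. source = %d; target = %d; tag = %d' %
      (len(fields['src'].vocab), len(fields['tgt'].vocab), len(fields['tag'].vocab)))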