示例#1
0
def main():
    """Build the Seq2Seq+RL model and optionally load its parameters.

    Command-line options:
        -vocab:  path to a vocabulary file, read with ``torch.load`` and
                 converted with ``IO.load_fields``.
        -model:  optional checkpoint path passed to ``model.load_checkpoint``.
        -config: config file path passed to ``utils.load_config``.
        -gpuid:  zero or more GPU ids; when given, the first id is selected
                 with ``cuda.set_device`` and the model is moved to CUDA.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-vocab", type=str)
    parser.add_argument("-model", type=str)
    parser.add_argument("-config", type=str)
    # ``args.gpuid`` is read below; without this declaration the attribute
    # does not exist and the script dies with AttributeError. Declared with
    # the same default/nargs used by the other scripts in this file.
    parser.add_argument('-gpuid', default=[], nargs='+', type=int)

    args = parser.parse_args()
    config = utils.load_config(args.config)

    use_cuda = False
    if args.gpuid:
        cuda.set_device(args.gpuid[0])
        use_cuda = True
    # NOTE(review): torch.load unpickles the file; only use trusted vocab files.
    fields = IO.load_fields(torch.load(args.vocab))

    # Build model: a Seq2Seq tag model plus a tag sampler, combined into the
    # RL model.
    seq2seq_model = create_seq2seq_tag_model(config, fields)
    sampler_model = create_tag_sampler(config, fields)
    model = create_seq2seq_rl_model(config, fields, sampler_model,
                                    seq2seq_model)

    print('Loading parameters ...')
    if args.model:
        model.load_checkpoint(args.model)
    if use_cuda:
        model = model.cuda()
示例#2
0
def build_or_load_model(vocab, index_2_latent_sentence, device):
    """Create the sampler model and restore the newest checkpoint, if any.

    Args:
        vocab: forwarded to ``create_sampler_model``.
        index_2_latent_sentence: forwarded to ``create_sampler_model``.
        device: forwarded to ``create_sampler_model``.

    Returns:
        The model. When ``utils.latest_checkpoint(config.out_dir)`` finds a
        checkpoint, its weights are loaded via ``model.load_checkpoint``.
    """
    print('Building model...')
    model = create_sampler_model(
        vocab=vocab,
        index_2_latent_sentence=index_2_latent_sentence,
        device=device)

    # NOTE(review): ``config`` is not a parameter; it is looked up as a
    # module-level global -- confirm it is defined before this is called.
    latest_ckpt = utils.latest_checkpoint(config.out_dir)

    # latest_checkpoint can yield a falsy value when nothing has been saved
    # yet (the other build_or_load_model variants guard on it the same way);
    # only restore when there is something to restore.
    if latest_ckpt:
        print('Loading model from %s...' % latest_ckpt)
        # The returned start epoch is not used by this variant.
        model.load_checkpoint(latest_ckpt)

    print('\n')
    print(model)
    print('\n')

    return model
示例#3
0
def build_or_load_model(vocab, device):
    """Create the Seq2Seq model and restore the newest checkpoint, if any.

    Args:
        vocab: forwarded to ``create_seq2seq_model``.
        device: forwarded to ``create_seq2seq_model``.

    Returns:
        The model, with weights loaded from the latest checkpoint found in
        ``config.out_dir`` when one exists.
    """
    model = create_seq2seq_model(vocab=vocab, device=device)
    # NOTE(review): ``config`` is a module-level global, not a parameter.
    ckpt = utils.latest_checkpoint(config.out_dir)

    if ckpt:
        print('Loading model from %s...' % ckpt)
        # The returned start epoch is not needed by this variant's callers.
        model.load_checkpoint(ckpt)
    else:
        # The model was already built above; this only reports that no
        # checkpoint was restored.
        print('Building model...')

    print('\n')
    print(model)
    print('\n')

    return model
示例#4
0
def build_or_load_model(config, fields):
    """Create the Seq2Seq tag model and optionally restore a checkpoint.

    Checkpoint selection uses ``config['Seq2Seq']['Trainer']``:
      * if ``start_epoch_at`` is not None, load
        ``<out_dir>/checkpoint_epoch<start_epoch_at>.pkl``;
      * otherwise load the latest checkpoint in ``out_dir``, if there is one.

    Args:
        config: nested config mapping; must contain
            ``config['Seq2Seq']['Trainer']['out_dir']`` and
            ``config['Seq2Seq']['Trainer']['start_epoch_at']``.
        fields: forwarded to ``create_seq2seq_tag_model``.

    Returns:
        ``(model, start_epoch_at)``, where ``start_epoch_at`` is the value
        returned by ``model.load_checkpoint`` or 0 when nothing was loaded.
    """
    model = create_seq2seq_tag_model(config, fields)

    trainer_cfg = config['Seq2Seq']['Trainer']
    out_dir = trainer_cfg['out_dir']
    requested_epoch = trainer_cfg['start_epoch_at']

    start_epoch_at = 0
    if requested_epoch is not None:
        # Resume from an explicitly requested epoch.
        ckpt = os.path.join(out_dir, 'checkpoint_epoch%d.pkl' % requested_epoch)
    else:
        # Fall back to the most recent checkpoint; only scanned when needed.
        ckpt = utils.latest_checkpoint(out_dir)

    if ckpt:
        print('Loading model from %s...' % ckpt)
        start_epoch_at = model.load_checkpoint(ckpt)
    else:
        # The model was already built above; no checkpoint was restored.
        print('Building model...')
    print(model)

    return model, start_epoch_at
示例#5
0
def main():
    """Decode ``-test_data`` with the Seq2Seq+RL model and write ``-test_out``.

    Command-line options include the vocab/model/config paths, optional GPU
    ids (``-gpuid``), decoding settings (``-beam_size``, ``-topk_tag``,
    ``-num_cluster``, ``-decode_max_length``) and an optional ``-dump_beam``
    path. When ``-dump_beam`` is given, it is opened for writing and passed to
    ``inference_file``, and beam tracing is enabled.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-test_data", type=str)
    parser.add_argument("-test_out", type=str)
    parser.add_argument("-vocab", type=str)
    parser.add_argument("-model", type=str)
    parser.add_argument("-config", type=str)
    parser.add_argument("-dump_beam", default="", type=str)
    parser.add_argument('-gpuid', default=[], nargs='+', type=int)
    parser.add_argument("-beam_size", type=int)
    parser.add_argument("-topk_tag", type=int)
    parser.add_argument("-num_cluster", type=int)
    parser.add_argument("-decode_max_length", type=int)
    args = parser.parse_args()
    config = utils.load_config(args.config)

    use_cuda = False
    if args.gpuid:
        cuda.set_device(args.gpuid[0])
        use_cuda = True
    # NOTE(review): torch.load unpickles the file; only use trusted vocab files.
    fields = IO.load_fields(torch.load(args.vocab))

    infer_dataset = InferDataset(
        data_path=args.test_data,
        fields=[('src', fields["src"])])

    # args.gpuid defaults to [], so indexing [0] unconditionally raised
    # IndexError on CPU-only runs.
    # NOTE(review): -1 is assumed to mean "CPU" for IO.InferIterator (the
    # legacy torchtext Iterator convention) -- confirm.
    iter_device = args.gpuid[0] if use_cuda else -1
    data_iter = IO.InferIterator(
        dataset=infer_dataset, device=iter_device,
        batch_size=1, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    # Build model: a Seq2Seq tag model plus a tag sampler, combined into the
    # RL model.
    seq2seq_model = create_seq2seq_tag_model(config, fields)
    sampler_model = create_tag_sampler(config, fields)
    model = create_seq2seq_rl_model(config, fields, sampler_model, seq2seq_model)

    print('Loading parameters ...')
    if args.model:
        model.load_checkpoint(args.model)
    if use_cuda:
        model = model.cuda()

    infer = Infer(model=model,
                  fields=fields,
                  beam_size=args.beam_size,
                  n_best=1,
                  max_length=args.decode_max_length,
                  global_scorer=None,
                  cuda=use_cuda,
                  beam_trace=bool(args.dump_beam))

    writer = None
    if args.dump_beam:
        writer = open(args.dump_beam, 'w', encoding='utf8')

    try:
        inference_file(infer, data_iter, args.test_out, fields, args.topk_tag,
                       args.num_cluster, use_cuda, writer)
    finally:
        # The beam dump file was previously never closed, so buffered output
        # could be lost. close() is a no-op if inference_file already closed it.
        if writer is not None:
            writer.close()
示例#6
0
from dialog0.Seq2Seq.ModelHelper import create_seq2seq_tag_model
from dialog0.TagSampler.ModelHelper import create_tag_sampler
from dialog0.Seq2SeqWithRL.Dataset import RLDataset
from tensorboardX import SummaryWriter
# Command-line interface for the RL training script. All path-like options
# are plain strings; -gpuid takes zero or more integer GPU ids.
parser = argparse.ArgumentParser()
for _opt in ('-train_data', '-config', '-vocab', '-seq2seq', '-sampler', '-log_dir'):
    parser.add_argument(_opt, type=str)
del _opt  # keep the module namespace identical to an unrolled version
parser.add_argument('-gpuid', default=[], nargs='+', type=int)

args = parser.parse_args()
config = utils.load_config(args.config)

# Select the first requested GPU, if any were given.
if args.gpuid:
    cuda.set_device(args.gpuid[0])

# TensorBoard writer for -log_dir. The (misspelled) name ``summery_writer`` is
# kept as-is because other module code may refer to it.
summery_writer = SummaryWriter(args.log_dir)

def report_func(global_step, epoch, batch, num_batches,
                start_time, lr, report_stats):
    """
    This is the user-defined batch-level traing progress
    report function.
    Args:
        epoch(int): current epoch count.
        batch(int): current batch count.
        num_batches(int): total number of batches.
        start_time(float): last report time.