# Example #1
def validate(args, device_id, pt, step):
    """Evaluate a checkpoint on the ``'valid'`` split and return its perplexity.

    Runs on CPU when ``args.visible_gpus == '-1'``, otherwise on CUDA.

    Args:
        args: Run configuration. For every key of the checkpoint's ``'opt'``
            that appears in ``model_flags``, the stored value is copied onto
            ``args`` (``args`` is mutated in place).
        device_id: Forwarded to ``build_trainer``.
        pt: Path of the checkpoint to evaluate. When empty, ``args.test_from``
            is used instead.
        step: Step number passed to ``trainer._report_step`` together with the
            validation statistics.

    Returns:
        ``stats.ppl()`` of the statistics returned by ``trainer.validate``.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    # An explicitly passed checkpoint path takes precedence over the
    # configured default.
    test_from = pt if pt else args.test_from
    logger.info('Loading checkpoint from %s', test_from)
    # map_location returning the storage unchanged keeps every tensor on CPU,
    # whatever device it was saved from.
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Rebuild the model with the flag values stored in the checkpoint.
    opt = vars(checkpoint['opt'])
    for k, v in opt.items():
        if k in model_flags:
            setattr(args, k, v)

    spm = sentencepiece.SentencePieceProcessor()
    # NOTE(review): no model is loaded into ``spm`` in this function, so
    # PieceToId()/len() below do not query a trained vocabulary — confirm a
    # ``spm.Load(...)`` call has not been lost.
    word_padding_idx = spm.PieceToId('<PAD>')
    symbols = {'BOS': spm.PieceToId('<S>'), 'EOS': spm.PieceToId('</S>'), 'PAD': word_padding_idx,
               'EOT': spm.PieceToId('<T>'), 'EOP': spm.PieceToId('<P>'), 'EOQ': spm.PieceToId('<Q>')}

    vocab_size = len(spm)
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)

    valid_iter = data_loader.AbstractiveDataloader(args, load_dataset(args, 'valid', shuffle=False), symbols,
                                                   args.batch_size, device, shuffle=False, is_test=False)

    # No optimizer is needed for evaluation.
    trainer = build_trainer(args, device_id, model, symbols, vocab_size, None)
    stats = trainer.validate(valid_iter)
    trainer._report_step(0, step, valid_stats=stats)
    return stats.ppl()
# Example #2
def test(args, pt, step):
    """Load a checkpoint and run the predictor over ``args.dataset``.

    Runs on CPU when ``args.visible_gpus == '-1'``, otherwise on CUDA.

    Args:
        args: Run configuration. For every key of the checkpoint's ``'opt'``
            that appears in ``model_flags``, the stored value is copied onto
            ``args`` (``args`` is mutated in place).
        pt: Path of the checkpoint to load. When empty, ``args.test_from`` is
            used instead.
        step: Forwarded to ``predictor.translate``.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    # An explicitly passed checkpoint path takes precedence over the
    # configured default.
    test_from = pt if pt else args.test_from
    logger.info('Loading checkpoint from %s', test_from)
    # map_location returning the storage unchanged keeps every tensor on CPU,
    # whatever device it was saved from.
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Rebuild the model with the flag values stored in the checkpoint.
    opt = vars(checkpoint['opt'])
    for k, v in opt.items():
        if k in model_flags:
            setattr(args, k, v)

    spm = sentencepiece.SentencePieceProcessor()
    # NOTE(review): no model is loaded into ``spm`` in this function, so
    # PieceToId()/len() below do not query a trained vocabulary — confirm a
    # ``spm.Load(...)`` call has not been lost.
    word_padding_idx = spm.PieceToId('<PAD>')
    symbols = {'BOS': spm.PieceToId('<S>'), 'EOS': spm.PieceToId('</S>'), 'PAD': word_padding_idx,
               'EOT': spm.PieceToId('<T>'), 'EOP': spm.PieceToId('<P>'), 'EOQ': spm.PieceToId('<Q>')}

    vocab_size = len(spm)
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)

    test_iter = data_loader.AbstractiveDataloader(args, load_dataset(args, args.dataset, shuffle=False), symbols,
                                                  args.valid_batch_size, device, shuffle=False, is_test=True)
    # The SentencePiece processor doubles as the vocabulary for decoding.
    predictor = build_predictor(args, spm, symbols, model, logger=logger)
    predictor.translate(test_iter, step)
# Example #3
def train(args, device_id):
    """Build a Summarizer and train it for ``args.train_steps`` steps.

    Runs on CPU when ``args.visible_gpus == '-1'``, otherwise on CUDA.

    When ``args.train_from`` is non-empty, that checkpoint is loaded. Its
    ``'opt'`` values for every flag in ``model_flags`` are copied onto
    ``args``. The checkpoint is passed to ``Summarizer`` and
    ``model_builder.build_optim``. The optimizer's ``_step`` is set to the
    integer after the last ``_`` of the path, with a trailing ``.pt`` removed.

    Args:
        args: Run configuration (mutated in place when resuming, see above).
        device_id: When ``>= 0``, made the current CUDA device. Always
            forwarded to ``build_trainer``.
    """
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d', device_id)
    logger.info('Device %s', device)
    # Request deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        # NOTE(review): the original body of this branch was missing; setting
        # the current CUDA device is a reconstruction — confirm whether
        # per-device seeding was also intended here.
        torch.cuda.set_device(device_id)

    # Defaults for a fresh (non-resumed) run.
    checkpoint = None
    ckpt_step = None
    if args.train_from != '':
        logger.info('Loading checkpoint from %s', args.train_from)
        # map_location returning the storage unchanged keeps every tensor on
        # CPU, whatever device it was saved from.
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k, v in opt.items():
            if k in model_flags:
                setattr(args, k, v)
        # Remove the '.pt' suffix (str.strip would remove a *set* of
        # characters from both ends) and take the text after the last '_'.
        stem = args.train_from
        if stem.endswith('.pt'):
            stem = stem[:-len('.pt')]
        ckpt_step = stem.split('_')[-1]

    spm = sentencepiece.SentencePieceProcessor()
    # NOTE(review): no model is loaded into ``spm`` in this function, so
    # PieceToId()/len() below do not query a trained vocabulary — confirm a
    # ``spm.Load(...)`` call has not been lost.
    word_padding_idx = spm.PieceToId('<PAD>')
    symbols = {
        'BOS': spm.PieceToId('<S>'),
        'EOS': spm.PieceToId('</S>'),
        'PAD': word_padding_idx,
        'EOT': spm.PieceToId('<T>'),
        'EOP': spm.PieceToId('<P>'),
        'EOQ': spm.PieceToId('<Q>'),
    }
    vocab_size = len(spm)

    def train_iter_fct():
        """Return a new data loader over the training split."""
        # NOTE(review): the original argument list was truncated; rebuilt by
        # analogy with the validation loader, using the 'train' split with
        # shuffling enabled — confirm split name, batch size and shuffle.
        return data_loader.AbstractiveDataloader(args, load_dataset(args, 'train', shuffle=True), symbols,
                                                 args.batch_size, device, shuffle=True, is_test=False)

    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    if ckpt_step is not None:
        # Resume the optimizer's step counter from the checkpoint file name.
        optim._step = int(ckpt_step)
    trainer = build_trainer(args, device_id, model, symbols, vocab_size, optim)

    trainer.train(train_iter_fct, args.train_steps)
