Code example #1
def validate(args, device_id, pt, step):
    # Run on CPU unless GPUs are visible.
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    # Prefer an explicitly passed checkpoint path, otherwise fall back to args.test_from.
    if pt != '':
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    # Load the checkpoint with every tensor mapped to CPU storage.
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Override the current args with any model flags stored in the checkpoint,
    # so the model is rebuilt with the configuration it was trained with.
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    # Load the SentencePiece vocabulary and look up the ids of the special tokens
    # (BOS/EOS, padding, and the segment markers <T>, <P>, <Q>).
    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    symbols = {'BOS': spm.PieceToId('<S>'), 'EOS': spm.PieceToId('</S>'), 'PAD': word_padding_idx,
               'EOT': spm.PieceToId('<T>'), 'EOP': spm.PieceToId('<P>'), 'EOQ': spm.PieceToId('<Q>')}

    vocab_size = len(spm)
    # Rebuild the model from the checkpoint and switch it to evaluation mode.
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    model.eval()

    # Iterate over the validation split without shuffling.
    valid_iter = data_loader.AbstractiveDataloader(args, load_dataset(args, 'valid', shuffle=False), symbols,
                                                   args.batch_size, device, shuffle=False, is_test=False)

    # No optimizer is needed for validation, so None is passed to the trainer.
    trainer = build_trainer(args, device_id, model, symbols, vocab_size, None)
    stats = trainer.validate(valid_iter)
    trainer._report_step(0, step, valid_stats=stats)
    return stats.ppl()
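
A minimal sketch of how `validate` might be driven over a directory of saved checkpoints to pick the one with the lowest validation perplexity. The function name `pick_best_checkpoint`, the `args.model_path` field, and the `model_step_<N>.pt` naming pattern are assumptions for illustration, not part of the example above.

import glob

def pick_best_checkpoint(args, device_id):
    # Hypothetical driver: evaluate every checkpoint and keep the lowest-perplexity one.
    best_ppl, best_ckpt = float('inf'), None
    for pt in sorted(glob.glob(args.model_path + '/model_step_*.pt')):
        # Assumed naming pattern 'model_step_<N>.pt'; adjust to the real checkpoint names.
        step = int(pt.split('_')[-1].replace('.pt', ''))
        ppl = validate(args, device_id, pt, step)
        if ppl < best_ppl:
            best_ppl, best_ckpt = ppl, pt
    return best_ckpt, best_ppl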
Code example #2
File: train_abstractive.py  Project: yaolu/hiersumm
def train(args, device_id):
    init_logger(args.log_file)
    logger.info(str(args))

    # Run on CPU unless GPUs are visible.
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    # Seed Python and PyTorch RNGs and make cuDNN deterministic for reproducibility.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        # Pin this process to its GPU and seed the CUDA RNG as well.
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        # Resume training: load the checkpoint (mapped to CPU) and override args
        # with the model flags it was trained with.
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if k in model_flags:
                setattr(args, k, opt[k])
        # Recover the global step from a checkpoint name such as '..._20000.pt'.
        ckpt_step = args.train_from.strip('.pt').split('_')[-1]
    else:
        checkpoint = None

    # Load the SentencePiece vocabulary and look up the ids of the special tokens.
    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    symbols = {
        'BOS': spm.PieceToId('<S>'),
        'EOS': spm.PieceToId('</S>'),
        'PAD': word_padding_idx,
        'EOT': spm.PieceToId('<T>'),
        'EOP': spm.PieceToId('<P>'),
        'EOQ': spm.PieceToId('<Q>')
    }
    print(symbols)
    vocab_size = len(spm)

    def train_iter_fct():
        # The trainer calls this whenever it needs a fresh, shuffled pass over the
        # training split, so a new dataloader is built on each call.
        return data_loader.AbstractiveDataloader(args,
                                                 load_dataset(args,
                                                              'train',
                                                              shuffle=True),
                                                 symbols,
                                                 args.batch_size,
                                                 device,
                                                 shuffle=True,
                                                 is_test=False)

    # Build the model and optimizer, restoring their states from the checkpoint
    # when resuming, then train for args.train_steps steps.
    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    if args.train_from != '':
        # Continue the optimizer's step counter from where the checkpoint left off.
        optim._step = int(ckpt_step)
    logger.info(model)
    trainer = build_trainer(args, device_id, model, symbols, vocab_size, optim)

    trainer.train(train_iter_fct, args.train_steps)
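
For completeness, a hedged sketch of calling `train` directly. Only `log_file`, `visible_gpus`, `seed`, `train_from`, `vocab_path`, `batch_size`, and `train_steps` are actually read in the snippet above, so every value below is a placeholder, and any further field expected by the dataloader, `Summarizer`, or `build_optim` is omitted here.

from argparse import Namespace

# Hypothetical single-GPU invocation; values are placeholders, not project defaults.
args = Namespace(
    visible_gpus='0',        # '-1' would force CPU
    seed=666,
    log_file='train.log',
    train_from='',           # empty string starts training from scratch
    vocab_path='spm.model',  # SentencePiece model file
    batch_size=4096,
    train_steps=100000,
)
train(args, device_id=0)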