Example #1
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    if (pt != ""):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info("Loading checkpoint from %s" % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint["opt"])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args,
                       device,
                       load_pretrained_bert=False,
                       bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     "valid",
                                                     shuffle=False),
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)
    trainer = build_trainer(args, device_id, model, None)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
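
Since validate() returns the validation cross-entropy, a common pattern is to sweep every saved checkpoint and keep the best one. A minimal sketch, assuming the checkpoints live in args.model_path and follow a model_step_<N>.pt naming scheme (both are assumptions, not shown in the snippet above):

import glob
import os
import re

def pick_best_checkpoint(args, device_id):
    # Sort checkpoints by modification time so the sweep follows training order.
    cp_files = sorted(glob.glob(os.path.join(args.model_path, 'model_step_*.pt')),
                      key=os.path.getmtime)
    results = []
    for cp in cp_files:
        # The training step is assumed to be encoded in the filename,
        # e.g. model_step_20000.pt -> 20000.
        step = int(re.search(r'model_step_(\d+)\.pt$', cp).group(1))
        results.append((validate(args, device_id, cp, step), cp))
    # Lower cross-entropy is better.
    return min(results)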
Example #2
def test_ext(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)
    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
Example #3
def train_iter_fct():
    return data_loader.Dataloader(args,
                                  load_dataset(args, "train",
                                               shuffle=True),
                                  args.batch_size,
                                  device,
                                  shuffle=True,
                                  is_test=False)
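
This factory is meant to be handed to the trainer rather than called once: the training loop invokes it again whenever the current iterator is exhausted, so each pass over the data gets a fresh shuffle. A rough sketch of the surrounding call site; the optimizer helper, trainer.train signature, and args.train_steps are assumptions, not part of the snippet above:

def train_ext_sketch(args, device_id):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    def train_iter_fct():
        # Rebuilt from scratch on every call.
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train', shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = ExtSummarizer(args, device, checkpoint=None)
    optim = model_builder.build_optim(args, model, None)   # assumed helper
    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)        # assumed signature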
Example #4
def baseline(args, cal_lead=False, cal_oracle=False):
    # NOTE: `device` and `device_id` are not parameters of this function; the
    # excerpt assumes they are module-level names, e.g. derived from
    # args.visible_gpus as in the other examples.
    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    "test",
                                                    shuffle=False),
                                       args.batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)

    trainer = build_trainer(args, device_id, None, None)

    if (cal_lead):
        trainer.test(test_iter, 0, cal_lead=True)
    elif (cal_oracle):
        trainer.test(test_iter, 0, cal_oracle=True)
Example #5
def baseline(args, cal_lead=False, cal_oracle=False):
    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.batch_size,
                                       'cpu',
                                       shuffle=False,
                                       is_test=True)

    trainer = build_trainer(args, '-1', None, None, None)
    if (cal_lead):
        trainer.test(test_iter, 0, cal_lead=True)
    elif (cal_oracle):
        trainer.test(test_iter, 0, cal_oracle=True)
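
The cal_lead and cal_oracle flags select the two standard extractive baselines: LEAD copies the first few source sentences (three, in the reference implementation) as the summary, while ORACLE picks the sentences with the best ROUGE overlap against the reference. A toy illustration of the LEAD idea, independent of the Trainer:

def lead_summary(src_sents, n=3):
    # LEAD-n baseline: the summary is simply the first n source sentences.
    return ' '.join(src_sents[:n])

doc = ['First sentence of the article .',
       'Second sentence .',
       'Third sentence .',
       'A later sentence that LEAD-3 ignores .']
print(lead_summary(doc))   # prints the first three sentences only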
Example #6
def validate(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args,
                                        load_dataset(args,
                                                     'valid',
                                                     shuffle=False),
                                        args.batch_size,
                                        device,
                                        shuffle=False,
                                        is_test=False)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],
        'EOS': tokenizer.vocab['[unused1]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused2]']
    }

    valid_loss = abs_loss(model.generator,
                          symbols,
                          model.vocab_size,
                          train=False,
                          device=device)

    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
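
The symbols dict maps the decoder's control tokens onto vocabulary slots that pretrained BERT never uses ([unused0]-[unused2]), plus the regular [PAD] token; abs_loss and the beam-search decoder look these ids up (e.g. PAD as the loss padding index). A quick way to inspect the reserved ids; the import path depends on whether the fork uses pytorch_transformers or transformers:

from transformers import BertTokenizer   # pytorch_transformers in older forks

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
for name, tok in [('BOS', '[unused0]'), ('EOS', '[unused1]'),
                  ('PAD', '[PAD]'), ('EOQ', '[unused2]')]:
    print(name, tok, tokenizer.vocab[tok])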
Example #7
def test(args, test_from, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if k in model_flags:
            setattr(args, k, opt[k])
    print(args)

    config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(args, device, load_pretrained_bert=False, bert_config=config)
    model.load_cp(checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                       args.batch_size, device,
                                       shuffle=False, is_test=True)
    # NOTE: unlike the other examples, build_trainer is called here without a
    # device_id argument, presumably against a fork with a different signature.
    trainer = build_trainer(args, model, None)
    trainer.test(test_iter, step)
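
A note on the map_location lambda used throughout these snippets: it simply keeps every deserialized tensor on the CPU, so a checkpoint saved on GPU can be loaded on a machine without one. The string form is equivalent on current PyTorch:

import torch

# Both calls load all tensors to CPU regardless of the device they were saved on.
checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
checkpoint = torch.load(test_from, map_location='cpu')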
Example #8
def test_abs(args, device_id, pt, step):
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    if (pt != ''):
        test_from = pt
    else:
        test_from = args.test_from
    logger.info('Loading checkpoint from %s' % test_from)

    checkpoint = torch.load(test_from,
                            map_location=lambda storage, loc: storage)
    print(checkpoint)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_iter = data_loader.Dataloader(args,
                                       load_dataset(args,
                                                    'test',
                                                    shuffle=False),
                                       args.test_batch_size,
                                       device,
                                       shuffle=False,
                                       is_test=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=True,
                                              cache_dir=args.temp_dir)
    symbols = {
        'BOS': tokenizer.vocab['[unused0]'],
        'EOS': tokenizer.vocab['[unused1]'],
        'PAD': tokenizer.vocab['[PAD]'],
        'EOQ': tokenizer.vocab['[unused2]']
    }
    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
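
For completeness, a hypothetical driver for test_abs. The Namespace fields below only mirror the args attributes this snippet reads directly; a real run also needs the module-level model_flags plus whatever AbsSummarizer and build_predictor consume in the given fork (beam-search settings, result paths, and so on), and the checkpoint path is purely illustrative:

import argparse

args = argparse.Namespace(
    visible_gpus='-1',                            # CPU-only run
    test_from='../models/model_step_148000.pt',   # illustrative path
    test_batch_size=500,
    temp_dir='../temp',
)
# The training step is conventionally encoded in the checkpoint filename.
step = int(args.test_from.split('.')[-2].split('_')[-1])
test_abs(args, device_id=-1, pt='', step=step)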