def test_text_abs(args, device_id, pt, step):
    """Load an abstractive checkpoint and run the predictor over the test split.

    Args:
        args: Run configuration namespace. Any attribute whose name appears
            in the externally defined ``model_flags`` is overwritten with
            the value stored under the checkpoint's ``'opt'`` entry.
        device_id: Not used by this function.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is loaded instead.
        step: Forwarded unchanged to ``predictor.translate``.
    """
    if args.visible_gpus == '-1':
        device = "cpu"
    else:
        device = "cuda"

    # An explicit checkpoint path takes priority over the configured one.
    test_from = pt if pt != '' else args.test_from
    message = 'Loading checkpoint from %s' % test_from
    logger.info(message)

    # map_location hands each storage back unchanged.
    checkpoint = torch.load(
        test_from, map_location=lambda storage, loc: storage)

    # Restore the model-defining options that were saved with the checkpoint.
    saved_options = vars(checkpoint['opt'])
    for flag_name, flag_value in saved_options.items():
        if flag_name in model_flags:
            setattr(args, flag_name, flag_value)
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_dataset = load_dataset(args, 'test', shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_dataset, args.test_batch_size, device,
        shuffle=False, is_test=True)

    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)

    # Map the symbol names used downstream onto ids from the BERT vocabulary.
    vocab = tokenizer.vocab
    symbol_tokens = (
        ('BOS', '[unused0]'),
        ('EOS', '[unused1]'),
        ('PAD', '[PAD]'),
        ('EOQ', '[unused2]'),
    )
    symbols = {name: vocab[token] for name, token in symbol_tokens}

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)
def validate(args, device_id, pt, step):
    """Load an abstractive checkpoint and evaluate it on the validation split.

    Args:
        args: Run configuration namespace. Any attribute whose name appears
            in the externally defined ``model_flags`` is overwritten with
            the value stored under the checkpoint's ``'opt'`` entry.
        device_id: Forwarded to ``build_trainer``.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is loaded instead.
        step: Forwarded unchanged to ``trainer.validate``.

    Returns:
        The result of calling ``xent()`` on the statistics object returned
        by ``trainer.validate``.
    """
    if args.visible_gpus == '-1':
        device = "cpu"
    else:
        device = "cuda"

    # An explicit checkpoint path takes priority over the configured one.
    test_from = pt if pt != '' else args.test_from
    message = 'Loading checkpoint from %s' % test_from
    logger.info(message)

    # map_location hands each storage back unchanged.
    checkpoint = torch.load(
        test_from, map_location=lambda storage, loc: storage)

    # Restore the model-defining options that were saved with the checkpoint.
    saved_options = vars(checkpoint['opt'])
    for flag_name, flag_value in saved_options.items():
        if flag_name in model_flags:
            setattr(args, flag_name, flag_value)
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    # Validation uses the regular batch size and is_test=False.
    valid_dataset = load_dataset(args, 'valid', shuffle=False)
    valid_iter = data_loader.Dataloader(
        args, valid_dataset, args.batch_size, device,
        shuffle=False, is_test=False)

    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir)

    # Map the symbol names used downstream onto ids from the BERT vocabulary.
    vocab = tokenizer.vocab
    symbol_tokens = (
        ('BOS', '[unused0]'),
        ('EOS', '[unused1]'),
        ('PAD', '[PAD]'),
        ('EOQ', '[unused2]'),
    )
    symbols = {name: vocab[token] for name, token in symbol_tokens}

    valid_loss = abs_loss(
        model.generator, symbols, model.vocab_size,
        train=False, device=device)

    # No optimizer is supplied (None) because nothing is being trained here.
    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()
def test_ext(args, device_id, pt, step):
    """Load an extractive checkpoint and run the trainer's test pass.

    Args:
        args: Run configuration namespace. Any attribute whose name appears
            in the externally defined ``model_flags`` is overwritten with
            the value stored under the checkpoint's ``'opt'`` entry.
        device_id: Forwarded to ``build_trainer``.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is loaded instead.
        step: Forwarded unchanged to ``trainer.test``.
    """
    if args.visible_gpus == '-1':
        device = "cpu"
    else:
        device = "cuda"

    # An explicit checkpoint path takes priority over the configured one.
    test_from = pt if pt != '' else args.test_from
    message = 'Loading checkpoint from %s' % test_from
    logger.info(message)

    # map_location hands each storage back unchanged.
    checkpoint = torch.load(
        test_from, map_location=lambda storage, loc: storage)

    # Restore the model-defining options that were saved with the checkpoint.
    saved_options = vars(checkpoint['opt'])
    for flag_name, flag_value in saved_options.items():
        if flag_name in model_flags:
            setattr(args, flag_name, flag_value)
    print(args)

    model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    test_dataset = load_dataset(args, 'test', shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_dataset, args.test_batch_size, device,
        shuffle=False, is_test=True)

    # No optimizer is supplied (None) because nothing is being trained here.
    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
def train_iter_fct():
    """Build and return a new shuffled data loader over the training split.

    Takes no arguments: ``args`` and ``device`` are read from the
    surrounding scope, which is not visible in this excerpt.
    """
    train_dataset = load_dataset(args, 'train', shuffle=True)
    return data_loader.Dataloader(
        args, train_dataset, args.batch_size, device,
        shuffle=True, is_test=False)
def baseline(args, cal_lead=False, cal_oracle=False):
    """Run a baseline test pass on the test split with no model.

    Args:
        args: Run configuration namespace, passed to the dataset loader,
            the data loader and ``build_trainer``.
        cal_lead: When truthy, ``trainer.test`` is called with
            ``cal_lead=True``. Checked first, so it wins if both flags
            are set.
        cal_oracle: When truthy (and ``cal_lead`` is not),
            ``trainer.test`` is called with ``cal_oracle=True``.

    If neither flag is set, the loader and trainer are still built but no
    test pass is run.
    """
    test_dataset = load_dataset(args, 'test', shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_dataset, args.batch_size, 'cpu',
        shuffle=False, is_test=True)

    # The trainer is built with device id '-1' and None for the remaining
    # positional arguments (no model is involved).
    trainer = build_trainer(args, '-1', None, None, None)

    if cal_lead:
        trainer.test(test_iter, 0, cal_lead=True)
        return
    if cal_oracle:
        trainer.test(test_iter, 0, cal_oracle=True)