def validate(args, device_id, pt, step):
    """Run validation for one checkpoint and return ``stats.xent()``.

    Args:
        args: Run options. ``visible_gpus``, ``test_from``,
            ``bert_config_path`` and ``batch_size`` are read here. Attributes
            whose names appear in ``model_flags`` are overwritten in place
            with the values stored in the checkpoint.
        device_id: Passed through unchanged to ``build_trainer``.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is used instead.
        step: Passed through unchanged to ``trainer.validate``.

    Returns:
        The result of calling ``xent()`` on the statistics object returned
        by ``trainer.validate``.
    """
    device = "cuda" if args.visible_gpus != "-1" else "cpu"
    test_from = args.test_from if pt == "" else pt

    logger.info("Loading checkpoint from %s" % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Copy the checkpoint options listed in model_flags onto args before the
    # model is constructed from args.
    for name, value in vars(checkpoint["opt"]).items():
        if name in model_flags:
            setattr(args, name, value)
    print(args)

    bert_config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(
        args, device, load_pretrained_bert=False, bert_config=bert_config
    )
    model.load_cp(checkpoint)
    model.eval()

    valid_data = load_dataset(args, "valid", shuffle=False)
    valid_iter = data_loader.Dataloader(
        args, valid_data, args.batch_size, device, shuffle=False, is_test=False
    )

    trainer = build_trainer(args, device_id, model, None)
    return trainer.validate(valid_iter, step).xent()
def test_ext(args, device_id, pt, step):
    """Load an ``ExtSummarizer`` checkpoint and run ``trainer.test`` on it.

    Args:
        args: Run options. ``visible_gpus``, ``test_from`` and
            ``test_batch_size`` are read here. Attributes whose names appear
            in ``model_flags`` are overwritten in place with the values
            stored in the checkpoint.
        device_id: Passed through unchanged to ``build_trainer``.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is used instead.
        step: Passed through unchanged to ``trainer.test``.

    Returns:
        None.
    """
    if args.visible_gpus == '-1':
        device = "cpu"
    else:
        device = "cuda"

    test_from = args.test_from if pt == '' else pt
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Only options named in model_flags are taken from the checkpoint; every
    # other attribute of args is left untouched.
    saved_opt = vars(checkpoint['opt'])
    for name in saved_opt:
        if name not in model_flags:
            continue
        setattr(args, name, saved_opt[name])
    print(args)

    model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    test_data = load_dataset(args, 'test', shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_data, args.test_batch_size, device, shuffle=False, is_test=True
    )

    trainer = build_trainer(args, device_id, model, None)
    trainer.test(test_iter, step)
def train_iter_fct():
    """Return a newly built ``Dataloader`` over the shuffled "train" dataset.

    Takes no parameters: ``args`` and ``device`` are resolved from a scope
    outside this function (not visible in this excerpt).
    """
    train_data = load_dataset(args, "train", shuffle=True)
    return data_loader.Dataloader(
        args, train_data, args.batch_size, device, shuffle=True, is_test=False
    )
def baseline(args, cal_lead=False, cal_oracle=False):
    """Run ``trainer.test`` on the "test" dataset in a model-free baseline mode.

    Args:
        args: Run options; ``batch_size`` is read here.
        cal_lead: When truthy, ``trainer.test`` is called with
            ``cal_lead=True``. Takes precedence over ``cal_oracle``.
        cal_oracle: When truthy (and ``cal_lead`` is falsy),
            ``trainer.test`` is called with ``cal_oracle=True``.

    Returns:
        None. When both flags are falsy the trainer is built but never run.
    """
    # NOTE(review): `device` and `device_id` are not parameters or locals;
    # presumably module-level globals set elsewhere - confirm.
    test_data = load_dataset(args, "test", shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_data, args.batch_size, device, shuffle=False, is_test=True
    )

    # No model and no optimizer are handed to the trainer.
    trainer = build_trainer(args, device_id, None, None)

    if cal_lead:
        trainer.test(test_iter, 0, cal_lead=True)
        return
    if cal_oracle:
        trainer.test(test_iter, 0, cal_oracle=True)
def baseline(args, cal_lead=False, cal_oracle=False):
    """Run ``trainer.test`` on the "test" dataset in a model-free baseline mode.

    The data loader is built for ``'cpu'`` and the trainer is built with the
    device id ``'-1'``, regardless of ``args``.

    Args:
        args: Run options; ``batch_size`` is read here.
        cal_lead: When truthy, ``trainer.test`` is called with
            ``cal_lead=True``. Takes precedence over ``cal_oracle``.
        cal_oracle: When truthy (and ``cal_lead`` is falsy),
            ``trainer.test`` is called with ``cal_oracle=True``.

    Returns:
        None. When both flags are falsy the trainer is built but never run.
    """
    test_data = load_dataset(args, 'test', shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_data, args.batch_size, 'cpu', shuffle=False, is_test=True
    )

    # The three trailing None arguments are passed positionally, as before.
    trainer = build_trainer(args, '-1', None, None, None)

    if cal_lead:
        trainer.test(test_iter, 0, cal_lead=True)
        return
    if cal_oracle:
        trainer.test(test_iter, 0, cal_oracle=True)
def validate(args, device_id, pt, step):
    """Validate an ``AbsSummarizer`` checkpoint and return ``stats.xent()``.

    Args:
        args: Run options. ``visible_gpus``, ``test_from``, ``batch_size``
            and ``temp_dir`` are read here. Attributes whose names appear in
            ``model_flags`` are overwritten in place with the values stored
            in the checkpoint.
        device_id: Passed through unchanged to ``build_trainer``.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is used instead.
        step: Passed through unchanged to ``trainer.validate``.

    Returns:
        The result of calling ``xent()`` on the statistics object returned
        by ``trainer.validate``.
    """
    device = "cuda" if args.visible_gpus != '-1' else "cpu"
    test_from = args.test_from if pt == '' else pt

    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Copy the checkpoint options listed in model_flags onto args before the
    # model is constructed from args.
    for name, value in vars(checkpoint['opt']).items():
        if name in model_flags:
            setattr(args, name, value)
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_data = load_dataset(args, 'valid', shuffle=False)
    valid_iter = data_loader.Dataloader(
        args, valid_data, args.batch_size, device, shuffle=False, is_test=False
    )

    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir
    )
    # Each symbol role is looked up in the tokenizer vocabulary by its token.
    role_tokens = (
        ('BOS', '[unused0]'),
        ('EOS', '[unused1]'),
        ('PAD', '[PAD]'),
        ('EOQ', '[unused2]'),
    )
    symbols = {role: tokenizer.vocab[token] for role, token in role_tokens}

    valid_loss = abs_loss(
        model.generator, symbols, model.vocab_size, train=False, device=device
    )
    trainer = build_trainer(args, device_id, model, None, valid_loss)
    return trainer.validate(valid_iter, step).xent()
def test(args, test_from, step):
    """Load a ``Summarizer`` checkpoint and run ``trainer.test`` on it.

    Args:
        args: Run options. ``visible_gpus``, ``bert_config_path`` and
            ``batch_size`` are read here. Attributes whose names appear in
            ``model_flags`` are overwritten in place with the values stored
            in the checkpoint.
        test_from: Path of the checkpoint file to load.
        step: Passed through unchanged to ``trainer.test``.

    Returns:
        None.
    """
    if args.visible_gpus == '-1':
        device = "cpu"
    else:
        device = "cuda"

    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

    # Only options named in model_flags are taken from the checkpoint.
    saved_opt = vars(checkpoint['opt'])
    for name in saved_opt:
        if name not in model_flags:
            continue
        setattr(args, name, saved_opt[name])
    print(args)

    # The model is built from the JSON config and then populated through
    # load_cp, with load_pretrained_bert switched off.
    bert_config = BertConfig.from_json_file(args.bert_config_path)
    model = Summarizer(
        args, device, load_pretrained_bert=False, bert_config=bert_config
    )
    model.load_cp(checkpoint)
    model.eval()

    test_data = load_dataset(args, 'test', shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_data, args.batch_size, device, shuffle=False, is_test=True
    )

    trainer = build_trainer(args, model, None)
    trainer.test(test_iter, step)
def test_abs(args, device_id, pt, step):
    """Load an ``AbsSummarizer`` checkpoint and run the predictor over "test".

    Args:
        args: Run options. ``visible_gpus``, ``test_from``,
            ``test_batch_size`` and ``temp_dir`` are read here. Attributes
            whose names appear in ``model_flags`` are overwritten in place
            with the values stored in the checkpoint.
        device_id: Accepted but not used inside this function.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is used instead.
        step: Passed through unchanged to ``predictor.translate``.

    Returns:
        None.
    """
    device = "cuda" if args.visible_gpus != '-1' else "cpu"
    test_from = args.test_from if pt == '' else pt

    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    # NOTE(review): this writes the entire loaded checkpoint object to
    # stdout; it looks like leftover debugging - confirm before removing.
    print(checkpoint)

    # Copy the checkpoint options listed in model_flags onto args before the
    # model is constructed from args.
    for name, value in vars(checkpoint['opt']).items():
        if name in model_flags:
            setattr(args, name, value)
    print(args)

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    test_data = load_dataset(args, 'test', shuffle=False)
    test_iter = data_loader.Dataloader(
        args, test_data, args.test_batch_size, device, shuffle=False, is_test=True
    )

    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased', do_lower_case=True, cache_dir=args.temp_dir
    )
    # Each symbol role is looked up in the tokenizer vocabulary by its token.
    role_tokens = (
        ('BOS', '[unused0]'),
        ('EOS', '[unused1]'),
        ('PAD', '[PAD]'),
        ('EOQ', '[unused2]'),
    )
    symbols = {role: tokenizer.vocab[token] for role, token in role_tokens}

    predictor = build_predictor(args, tokenizer, symbols, model, logger)
    predictor.translate(test_iter, step)