def validate(args, device_id, pt, step):
    """Evaluate one checkpoint on the 'valid' dataset split.

    Args:
        args: Run configuration. Attributes whose names appear in
            ``model_flags`` are overwritten in place with the values stored
            in the checkpoint's ``opt``.
        device_id: Device index forwarded to ``build_trainer``.
        pt: Checkpoint path. When it is the empty string,
            ``args.test_from`` is used instead.
        step: Step number forwarded to ``trainer._report_step``.

    Returns:
        The result of ``ppl()`` on the stats object produced by
        ``trainer.validate`` (presumably perplexity -- confirm against the
        stats class).
    """
    device = "cuda" if args.visible_gpus != '-1' else "cpu"
    test_from = pt if pt != '' else args.test_from

    def _keep_storage(storage, loc):
        # Hand the storage back unchanged instead of remapping it to ``loc``.
        return storage

    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=_keep_storage)

    # Restore the model-defining flags the checkpoint was saved with.
    saved_opt = vars(checkpoint['opt'])
    for flag, value in saved_opt.items():
        if flag in model_flags:
            setattr(args, flag, value)
    print(args)

    # Vocabulary and the ids of the special pieces.
    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    symbols = {
        'BOS': spm.PieceToId('<S>'),
        'EOS': spm.PieceToId('</S>'),
        'PAD': word_padding_idx,
        'EOT': spm.PieceToId('<T>'),
        'EOP': spm.PieceToId('<P>'),
        'EOQ': spm.PieceToId('<Q>'),
    }
    vocab_size = len(spm)

    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    model.eval()

    valid_iter = data_loader.AbstractiveDataloader(
        args,
        load_dataset(args, 'valid', shuffle=False),
        symbols,
        args.batch_size,
        device,
        shuffle=False,
        is_test=False,
    )

    # No optimizer is passed: this trainer is only used for validation.
    trainer = build_trainer(args, device_id, model, symbols, vocab_size, None)
    stats = trainer.validate(valid_iter)
    trainer._report_step(0, step, valid_stats=stats)
    return stats.ppl()
def _checkpoint_step(path):
    """Return the training step encoded at the end of a checkpoint path.

    The path is expected to end in ``_<step>``, optionally followed by the
    ``.pt`` extension, e.g. ``model_step_1000.pt`` -> ``1000``.

    Raises:
        ValueError: If the trailing token is not an integer.
    """
    suffix = '.pt'
    # str.strip('.pt') would strip any of the characters '.', 'p', 't' from
    # both ends; remove the literal suffix instead.
    stem = path[:-len(suffix)] if path.endswith(suffix) else path
    token = stem.split('_')[-1]
    try:
        return int(token)
    except ValueError:
        raise ValueError(
            'Cannot parse training step from checkpoint path %r' % path)


def train(args, device_id):
    """Build the model, optimizer and trainer, then run training.

    Seeds ``torch`` and ``random`` with ``args.seed``. When
    ``args.train_from`` is non-empty, the checkpoint is loaded, attributes
    of ``args`` named in ``model_flags`` are overwritten with the values
    saved in the checkpoint's ``opt``, and the optimizer's step counter is
    set to the step number parsed from the checkpoint file name.

    Args:
        args: Run configuration (mutated in place when resuming).
        device_id: Device index. When it is >= 0 the CUDA device is
            selected and seeded.

    Raises:
        ValueError: If resuming and the checkpoint path does not end in
            ``_<step>`` (optionally followed by ``.pt``).
    """
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    checkpoint = None
    resume_step = None
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        # Restore the model-defining flags the checkpoint was saved with.
        opt = vars(checkpoint['opt'])
        for k in opt:
            if k in model_flags:
                setattr(args, k, opt[k])
        # Parse now so a badly named checkpoint fails before the model and
        # optimizer are built.
        resume_step = _checkpoint_step(args.train_from)

    # Vocabulary and the ids of the special pieces.
    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(args.vocab_path)
    word_padding_idx = spm.PieceToId('<PAD>')
    symbols = {
        'BOS': spm.PieceToId('<S>'),
        'EOS': spm.PieceToId('</S>'),
        'PAD': word_padding_idx,
        'EOT': spm.PieceToId('<T>'),
        'EOP': spm.PieceToId('<P>'),
        'EOQ': spm.PieceToId('<Q>'),
    }
    print(symbols)
    vocab_size = len(spm)

    def train_iter_fct():
        # Each call builds a new shuffled dataloader over the 'train' split.
        return data_loader.AbstractiveDataloader(
            args,
            load_dataset(args, 'train', shuffle=True),
            symbols,
            args.batch_size,
            device,
            shuffle=True,
            is_test=False,
        )

    model = Summarizer(args, word_padding_idx, vocab_size, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    if args.train_from != '':
        # Continue the optimizer's step count from the resumed checkpoint.
        optim._step = resume_step
    logger.info(model)

    trainer = build_trainer(args, device_id, model, symbols, vocab_size, optim)
    trainer.train(train_iter_fct, args.train_steps)