def validate_logp(model, data_loader):
    # Dev-set log-probability: only the master rank evaluates, other ranks return None.
    if not utils.is_master():
        return None
    model.eval()
    num_tokens = 0
    total_loss = 0
    # Skip DDP gradient synchronisation during evaluation.
    with model.no_sync() if isinstance(model, DDP) else contextlib.ExitStack():
        for i, batch in enumerate(data_loader):
            batch = utils.move_cuda(batch)
            logits = model(batch)
            # cross_entropy expects (N, C, ...) logits, hence the transpose.
            nll_loss = F.cross_entropy(logits.transpose(2, 1).float(), batch['y'],
                                       reduction='sum')
            total_loss += nll_loss.float().item()
            num_tokens += batch['true_tokens_y']
    # Mean log-probability per target token (negated summed NLL).
    return -total_loss / num_tokens
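
# Hedged usage sketch (not part of the original file): wiring validate_logp into a
# periodic evaluation hook. `model`, `dev_loader`, and `writer` are assumed to be a
# (possibly DDP-wrapped) model, a batched dev iterator, and a SummaryWriter.
def log_dev_logp(model, dev_loader, writer, step):
    score = validate_logp(model, dev_loader)  # mean log-prob per target token
    if score is not None:                     # non-master ranks receive None
        writer.add_scalar('dev/logp', score, step)
    return score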
def fb(self, batch: dict, dummy=False):
    # One forward/backward pass. A dummy batch contributes zero gradient but keeps
    # all DDP ranks performing the same number of backward passes.
    self.model.train()
    batch = utils.move_cuda(batch)
    with self._amp_autocast():
        loss, nll_loss = self.criterion(self.model(batch), batch['y'])
    if dummy:
        loss = loss * 0.0
    self.grad_scaler.scale(loss).backward()
    loss = loss.item()
    nll_loss = nll_loss.item()
    if self.args.profile:
        utils.profile_nan(self.model)
    return loss, nll_loss
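
# Hedged sketch of how fb could drive gradient accumulation (assumed wiring, not the
# repo's actual update loop). Missing sub-batches are replayed as dummy batches so
# every DDP rank runs the same number of backward passes. `trainer.optimizer` and
# `trainer.grad_scaler` are assumed to be the objects used inside fb.
def accumulated_step(trainer, sub_batches, accumulation_steps):
    total_loss = 0.0
    for i in range(accumulation_steps):
        dummy = i >= len(sub_batches)
        batch = sub_batches[min(i, len(sub_batches) - 1)]
        loss, _ = trainer.fb(batch, dummy=dummy)
        total_loss += loss
    trainer.grad_scaler.step(trainer.optimizer)  # unscales grads, then optimizer.step()
    trainer.grad_scaler.update()
    trainer.optimizer.zero_grad()
    return total_loss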
def main(args):
    logger.info('Loading checkpoints ...')
    model, vocabularies = load(args.checkpoints, verbose=args.verbose)
    s_vocab, t_vocab = vocabularies
    model = utils.move_cuda(model)

    def infer(batch):
        batch = utils.move_cuda(batch)
        return model.infer(
            batch['x'], args.k, args.penalty, args.alpha, args.step_wise_penalty,
            args.min_len_a, args.min_len_b, args.max_len_a, args.max_len_b,
            topk=args.topk
        )

    translator = Translator(infer, args.bpe, args.reverse, args.topk, s_vocab, args.verbose)
    meter = utils.SpeedMeter()

    logger.info('Building iterator ...')
    it = get_iterator(args, s_vocab)

    n_tok = 0
    n_snt = 0
    meter.start()
    logger.info('Start generation ...')
    for batch in it:
        translator.translate(batch)
        n_tok += batch['size_x']
        n_snt += len(batch['index'])
        meter.stop(batch['size_x'])

    sys.stderr.write(
        f'Sentences = {n_snt}, Tokens = {n_tok}, \n'
        f'Time = {datetime.timedelta(seconds=meter.duration)}, \n'
        f'Speed = {meter.avg:.2f} tok/s, {n_snt / meter.duration:.2f} snt/s\n'
    )
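
# Illustrative sketch only: the decoding-related fields main() reads from `args`,
# spelled out as an explicit Namespace. The values below are placeholders, not the
# repo's defaults, and the data/vocab fields consumed by get_iterator are omitted.
import argparse

example_args = argparse.Namespace(
    checkpoints=['checkpoints/model_best.pt'], verbose=False,
    bpe=True, reverse=False,
    k=4, topk=1, penalty=0.0, alpha=0.6, step_wise_penalty=False,
    min_len_a=0.0, min_len_b=1, max_len_a=1.2, max_len_b=10,
)
# main(example_args)  # would still need the iterator/vocab fields to run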
def validate_bleu(model, data_loader, beam_size: int = 1):
    # Dev-set BLEU: only the master rank decodes, other ranks return None.
    if not utils.is_master():
        return None
    model.eval()
    with model.no_sync() if isinstance(model, DDP) else contextlib.ExitStack():
        hyps = []
        refs = []
        # DDP wraps the model, so reach through .module for the raw infer method.
        infer = model.module.infer if isinstance(model, DDP) else model.infer
        for i, batch in enumerate(data_loader):
            batch = utils.move_cuda(batch)
            batch_hyps = infer(batch['x'], beam_size)
            # Keep only the top hypothesis from each k-best list.
            batch_hyps = [k_best[0] for k_best in batch_hyps]
            hyps.extend([h['tokens'] for h in batch_hyps])
            refs.extend(batch['refs'])
    results = {'hyp': hyps, 'ref': refs}
    bleu = utils.bleu(results['ref'], results['hyp'])
    return bleu
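
# Hedged usage sketch (assumed wiring): comparing greedy and beam-search dev BLEU on
# the same loader. `model`, `dev_loader`, and `writer` are assumed to match the
# objects passed to validate_bleu / validate_logp above.
def log_dev_bleu(model, dev_loader, writer, step, beam_sizes=(1, 4)):
    scores = {}
    for beam in beam_sizes:
        bleu = validate_bleu(model, dev_loader, beam_size=beam)
        if bleu is not None:                  # only the master rank returns a score
            scores[beam] = bleu
            writer.add_scalar(f'dev/bleu_beam{beam}', bleu, step)
    return scores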
def run(args):
    disable_numba_logging()
    # Dump the full training configuration next to the checkpoints (master rank only).
    if utils.is_master():
        with Path(args.checkpoint, 'config.json').open('w') as w:
            w.write(json.dumps(args.__dict__, indent=4, sort_keys=True))
    utils.seed(args.seed)
    sv, tv = load_vocab(args)

    logger.info('Building model')
    model = utils.move_cuda(models.build(args, [sv, tv]))
    criterion = utils.move_cuda(losses.CrossEntropyLoss(args.label_smoothing, ignore_index=tv.pad_id))
    optimizer = utils.move_cuda(optim.build(args, model.parameters()))
    lr_scheduler = lr_schedulers.build(args, args.lr, optimizer)

    state_dict = utils.checkpoint.load_latest(args.checkpoint)
    if not state_dict:
        logger.info(f'Model: \n{model}')
    elif args.finetune:
        raise ValueError('Fine-tuning is not available when resuming training from an existing checkpoint.')

    logger.info(f'FP16: {args.fp16}')
    if args.fp16 and args.fp16 != 'none':
        if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] < 7:
            logger.warning('Target device does not support acceleration with --fp16')
        if args.fp16 == 'half':
            model = model.half()
            criterion = criterion.half()

    train_itr = build_train_itr(args, sv, tv)
    trainer = Trainer(args, model, criterion, [sv, tv], optimizer, lr_scheduler, train_itr)
    logger.info(f'Max sentences = {args.max_sentences}, max tokens = {args.max_tokens}')

    state_dict = state_dict or try_prepare_finetune(args, model)
    stat_parameters(model)

    # Restore the training process when resuming from a checkpoint.
    if state_dict:
        logger.info('Resuming from given checkpoint')
        trainer.load_state_dict(state_dict, no_load_model=bool(args.finetune))
        del state_dict

    # Wrap the model for distributed data parallelism if a process group is active.
    if dist.is_initialized():
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device(),
            find_unused_parameters=True
        )
        trainer.model = model
    check_devices(args.dist_world_size)

    def after_epoch_callback():
        logger.info(f'Finished epoch {trainer.epoch}.')

    # Fall back to a no-op writer when tensorboard is unavailable or on non-master ranks.
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        logger.warning('Tensorboard is not available.')
        SummaryWriter = utils.DummySummaryWriter
    if not utils.is_master():
        SummaryWriter = utils.DummySummaryWriter
    writer = SummaryWriter(log_dir=Path(args.checkpoint, 'tensorboard'),
                           purge_step=trainer.global_step or None)

    validate_callback = None
    if args.dev:
        validate_callback = get_validator(args, model, sv, tv)

    trainer.timer.start()
    with writer:
        with torch.autograd.profiler.record_function('train_loop'):
            trainer.train(validate_callback=validate_callback,
                          before_epoch_callback=None,
                          after_epoch_callback=after_epoch_callback,
                          summary_writer=writer)

    logger.info(f'Training finished @ {time.strftime("%b %d, %Y, %H:%M:%S", time.localtime())}, '
                f'took {datetime.timedelta(seconds=trainer.elapse // 1)}')
    logger.info(f'Best validation score: {trainer.best_score}, @ {trainer.best_at}')
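
# Hedged sketch (an assumption about how run() might be launched; the repo's own
# entry point is not shown here): one process per GPU, initialised with the NCCL
# backend. MASTER_ADDR / MASTER_PORT are expected in the environment.
def launch(rank, world_size, args):
    torch.cuda.set_device(rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    try:
        run(args)
    finally:
        dist.destroy_process_group()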