def test_grouped_iterator(self):
    # test correctness
    x = list(range(10))
    itr = iterators.GroupedIterator(x, 1)
    self.assertEqual(list(itr), [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]])
    itr = iterators.GroupedIterator(x, 4)
    self.assertEqual(list(itr), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]])
    itr = iterators.GroupedIterator(x, 5)
    self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

    # test that the GroupedIterator also works correctly as a CountingIterator
    x = list(range(30))
    ref = list(iterators.GroupedIterator(x, 3))
    itr = iterators.GroupedIterator(x, 3)
    self.test_counting_iterator_index(ref, itr)
def test_grouped_iterator(self):
    # test correctness
    x = list(range(10))
    itr = iterators.GroupedIterator(x, 1)
    self.assertEqual(list(itr), [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]])
    itr = iterators.GroupedIterator(x, 4)
    self.assertEqual(list(itr), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]])
    itr = iterators.GroupedIterator(x, 5)
    self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

    # test CountingIterator functionality
    x = list(range(30))
    ref = list(iterators.GroupedIterator(x, 3))
    itr = iterators.GroupedIterator(x, 3)
    self.test_counting_iterator(ref, itr)
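# The two tests above pin down the chunking contract of GroupedIterator: wrap
# an iterable and yield lists of up to `chunk_size` consecutive elements, with
# a short final chunk when the length does not divide evenly. A minimal
# standalone sketch of that contract (an illustration only, not fairseq's
# implementation, which additionally behaves as a CountingIterator and tracks
# how many chunks have been consumed):
import itertools

def grouped(iterable, chunk_size):
    """Yield lists of up to `chunk_size` consecutive items from `iterable`."""
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, chunk_size))
        if not chunk:
            return
        yield chunk

assert list(grouped(range(10), 4)) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert list(grouped(range(10), 5)) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]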
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
                "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is None:  # OOM, overflow, ...
            continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
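# The `update_freq` expression above recurs in nearly every variant below: it
# treats args.update_freq as a per-epoch schedule of gradient-accumulation
# steps that saturates at its last entry. A small worked example (the schedule
# [4, 2, 1] is an assumption for illustration, not from the snippets):
def freq_for_epoch(update_freq, epoch):
    """Return the accumulation factor for a 1-indexed epoch number."""
    return update_freq[epoch - 1] if epoch <= len(update_freq) else update_freq[-1]

assert [freq_for_epoch([4, 2, 1], e) for e in (1, 2, 3, 4, 5)] == [4, 2, 1, 1, 1]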
def test_nmt(args, trainer, task, epoch_itr):
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    valid_subsets = ['valid']
    max_update = args.max_update or math.inf
    num_samples = 0
    for samples in progress:
        for i, sample in enumerate(samples):
            total_loss = trainer.valid_step(sample)
            num_samples += 1
            # num_updates = trainer.get_num_updates()
    return num_samples / total_loss
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
def _estimate_diagonal_fisher(args, trainer, epoch_itr, n_steps):
    """Estimate the diagonal empirical Fisher information matrix."""
    # Iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=True)
    itr = iterators.GroupedIterator(itr, 1, bottomless=True)
    progress = progress_bar.build_progress_bar(
        args, itr, 0, no_progress_bar='simple',
    )
    progress.log_interval = n_steps // 10

    # Initialize the Fisher
    FIM = {name: th.zeros_like(p) for name, p in trainer.model.named_parameters()}

    # Iterate
    for i, samples in enumerate(islice(progress, n_steps)):
        # Forward backward
        trainer.train_step(samples, update_params=False, clip_grad=False)
        # Accumulate squared gradients
        for name, p in trainer.model.named_parameters():
            FIM[name].add_(p.grad.detach() ** 2)
        # Log progress
        progress.log({"step": i})

    # Normalize
    FIM = {name: F / n_steps for name, F in FIM.items()}
    return FIM
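# `bottomless=True` is a fork-specific GroupedIterator kwarg, not an upstream
# fairseq argument. From its use here (and via `next(aux_itr)` in the
# auxiliary-data trainers below), the intended behavior appears to be an
# iterator that restarts the underlying data instead of raising StopIteration.
# A minimal sketch under that assumption; `make_iterable` is a hypothetical
# zero-argument factory so the data can be re-iterated once exhausted:
def bottomless_grouped(make_iterable, chunk_size):
    """Yield chunks of `chunk_size` forever, restarting the data when drained."""
    while True:
        chunk = []
        for x in make_iterable():
            chunk.append(x)
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
        if chunk:  # flush a short final chunk before restarting
            yield chunk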
def train(args, trainer, task, epoch_itr, max_update=math.inf):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, 'tpu', False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )
    progress.log_args(args, tag='train')

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(',')
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        end_of_epoch = not itr.has_next()
        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets, end_of_epoch)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)
    for samples in progress:
        if hasattr(trainer.criterion, 'set_num_updates'):
            trainer.criterion.set_num_updates(trainer.get_num_updates())

        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
def test_grouped_iterator_skip_remainder_batch(self):
    reference = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    itr1 = _get_epoch_batch_itr(reference, 3, False)
    grouped_itr1 = iterators.GroupedIterator(itr1, 2, True)
    self.assertEqual(len(grouped_itr1), 1)

    itr2 = _get_epoch_batch_itr(reference, 3, False)
    grouped_itr2 = iterators.GroupedIterator(itr2, 2, False)
    self.assertEqual(len(grouped_itr2), 2)

    itr3 = _get_epoch_batch_itr(reference, 3, True)
    grouped_itr3 = iterators.GroupedIterator(itr3, 2, True)
    self.assertEqual(len(grouped_itr3), 1)

    itr4 = _get_epoch_batch_itr(reference, 3, True)
    grouped_itr4 = iterators.GroupedIterator(itr4, 2, False)
    self.assertEqual(len(grouped_itr4), 1)

    itr5 = _get_epoch_batch_itr(reference, 5, True)
    grouped_itr5 = iterators.GroupedIterator(itr5, 2, True)
    self.assertEqual(len(grouped_itr5), 0)

    itr6 = _get_epoch_batch_itr(reference, 5, True)
    grouped_itr6 = iterators.GroupedIterator(itr6, 2, False)
    self.assertEqual(len(grouped_itr6), 1)
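# The six assertions above compose two rules: the batch iterator's own
# skip_remainder_batch flag (which, judging by the expected lengths, discards
# the final batch regardless of whether it is full), and GroupedIterator's
# flag, which turns ceil division into floor division over the number of
# batches. The arithmetic made explicit (a sketch consistent with the tests,
# not fairseq's code):
import math

def grouped_len(n_batches, chunk_size, skip_remainder_batch):
    if skip_remainder_batch:
        return n_batches // chunk_size  # drop any short final chunk
    return int(math.ceil(n_batches / chunk_size))

# 9 items, batch size 3 -> 3 batches; grouped by 2:
assert grouped_len(3, 2, True) == 1 and grouped_len(3, 2, False) == 2
# 9 items, batch size 3, drop last -> 2 batches; grouped by 2:
assert grouped_len(2, 2, True) == 1 and grouped_len(2, 2, False) == 1
# 9 items, batch size 5, drop last -> 1 batch; grouped by 2:
assert grouped_len(1, 2, True) == 0 and grouped_len(1, 2, False) == 1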
def downstream_train_pytorch(args, trainer, task, epoch_itr, train_prefix):
    """Fine-tune PyTorch classifier on downstream training set for one epoch."""
    task.split = 'train'
    num_updates = trainer.get_num_updates()

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )
    progress = maybe_wrap_neptune_logging(progress, args)

    # Task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:
            # Train for one step
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = get_ft_train_stats(agg.get_smoothed_values())
            progress.log(stats, tag=train_prefix, step=num_updates)

            if num_updates >= max_update:
                break

    # log end-of-epoch stats
    stats = get_ft_train_stats(agg.get_smoothed_values())
    try:
        progress.print(stats, tag=train_prefix, step=num_updates, log=False)
    except TypeError:
        # some progress bar backends do not accept the `log` kwarg
        progress.print(stats, tag=train_prefix, step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters(train_prefix)
def initialize_loader_for_epoch(args, epoch_itr):
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=False,
        shuffle=(epoch_itr.epoch >= args.curriculum))
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple')
    return progress
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    task.split = 'train'

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = maybe_wrap_neptune_logging(
        progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch, no_progress_bar='simple',
        ),
        args=args,
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = get_training_stats(agg.get_smoothed_values())
            progress.log(stats, tag='train', step=num_updates)

            if num_updates >= max_update:
                break

    # log end-of-epoch stats
    stats = get_training_stats(agg.get_smoothed_values())
    try:
        progress.print(stats, tag='train', step=num_updates, log=False)
    except TypeError:
        # some progress bar backends do not accept the `log` kwarg
        progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
def fisher(args, trainer, epoch_itr):
    if args.no_fisher:
        # Keep training code untouched, and make the Fisher values 1s.
        for n, p in trainer.model.named_parameters():
            trainer.fisher[n] = torch.ones(p.shape, device=p.device)
        for n, _ in trainer.model.named_parameters():
            trainer.fisher[n] = torch.autograd.Variable(trainer.fisher[n],
                                                        requires_grad=False)
        return

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    for n, p in trainer.model.named_parameters():
        trainer.fisher[n] = 0 * p.data
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        trainer.fisher_step(samples)
    for n, _ in trainer.model.named_parameters():
        trainer.fisher[n] = trainer.fisher[n] / len(progress)
        trainer.fisher[n] = torch.autograd.Variable(trainer.fisher[n],
                                                    requires_grad=False)

    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
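# Both Fisher routines here (this one and _estimate_diagonal_fisher above)
# compute the diagonal empirical Fisher: F_j ~= (1/N) * sum_n g_{n,j}^2, an
# average of elementwise squared gradients over N batches. A dependency-light
# sketch of that accumulation, decoupled from the trainer (the names below are
# illustrative, not from the snippets):
import torch

def diagonal_fisher(named_params, grads_per_batch):
    """`grads_per_batch` is a list of {name: grad tensor} dicts, one per batch."""
    fim = {n: torch.zeros_like(p) for n, p in named_params.items()}
    for grads in grads_per_batch:
        for n in fim:
            fim[n] += grads[n] ** 2  # accumulate squared per-batch gradients
    n_batches = len(grads_per_batch)
    return {n: f / n_batches for n, f in fim.items()}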
def start_epoch(self) -> bool:
    if not (self.trainer.get_lr() > self.args.min_lr
            and self.epoch_itr.epoch < self.max_epoch
            and self.trainer.get_num_updates() < self.max_update):
        self._done = True
        return False

    args = self.args
    update_freq = args.update_freq[self.epoch_itr.epoch - 1] \
        if self.epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = self.epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(self.epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)

    # meters in the epoch
    self.extra_meters = collections.defaultdict(lambda: AverageMeter())

    # enumerate
    self.itr = enumerate(itr, start=self.epoch_itr.iterations_in_epoch)
    return True
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = 1

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=False,  # TODO: changed
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        progress.log(stats, tag='train', step=stats['num_updates'])

    stats = get_training_stats(trainer)
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train_nmt(args, trainer, task, epoch_itr):
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    valid_subsets = ['valid']
    max_update = args.max_update or math.inf
    for samples in progress:
        with fmetrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

        # log mid-epoch stats
        # stats = get_training_stats('train_inner')
        # progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

    return log_output
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        try:
            log_output = trainer.train_step(samples)
        except FloatingPointError as e:
            if "Minimum loss scale reached" in str(e):
                print(f'Check samples: len={len(samples)}')
                for ik, s in enumerate(samples):
                    if s is None:
                        print(f'[{ik}]: None')
                    else:
                        for k, v in s.items():
                            if isinstance(v, torch.Tensor):
                                print(f'[{ik}][{k}]: {v.size()}')
            raise e
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr, summary_writer=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    distributed_utils.barrier(args, "train_%d" % trainer.get_num_updates())
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        stats['progress'] = round(
            i / num_batches * args.distributed_world_size * args.update_freq[-1], 3)
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
            distributed_utils.barrier(
                args, "train_val_%d" % trainer.get_num_updates())

        if num_updates % args.log_interval == 0:
            summary_writer.log_stats('train', stats, num_updates)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)
            except ResetTrainerException:
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None
                logger.info("reset the trainer at {}".format(
                    trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    if type(task) is tasks.factored_translation.FactoredTranslationTask:  # factored
        if args.factors_to_freeze is not None:
            factors_to_freeze = list({
                x for lang_pair in [args.factors_to_freeze]
                for x in lang_pair.split(',')
            })
            if epoch_itr.epoch == args.freeze_factors_epoch:
                for factor in factors_to_freeze:
                    print('Freezing', factor)
                    for param in trainer.get_model().encoder.encoders[factor].parameters():
                        param.requires_grad = False

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr, generator=None,
          filtered_maxpos_indices=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    # data selection: reset epoch iter to filter out unselected data
    filter_data = epoch_itr.epoch % args.select_by_dds_epoch == 0
    if filter_data and args.select_by_dds_epoch > 0:
        epoch_itr, _ = trainer.get_filtered_train_iterator(
            epoch_itr.epoch, filtered_maxpos_indices=filtered_maxpos_indices)

    # if args.update_language_sampling > 0 and args.select_by_dds_epoch < 0 and (not args.data_actor_step_update):
    #     num_reset = len(epoch_itr.frozen_batches) // (args.update_language_sampling*args.update_freq[0]+1)
    #     datasize = args.update_language_sampling*args.update_freq[0]+1
    #     if num_reset * datasize < len(epoch_itr.frozen_batches):
    #         num_reset += 1
    # else:
    #     num_reset = 1
    #     datasize = -1
    # for reset_idx in range(num_reset):
    #     print("resetting at step", reset_idx)

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
        offset=0,
        datasize=-1,
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # print(samples)
        # if args.extra_data_actor == 'ave_emb':
        #     update_actor = (i % args.extra_update_language_sampling == 0)
        # elif args.data_actor_step_update:
        #     update_actor = (i % args.update_language_sampling == 0)
        # elif args.data_actor == 'lan' and args.data_actor_step_update:
        #     update_actor = (i % args.update_language_sampling == 0)
        # else:
        #     update_actor = False

        # update sampling distribution
        # if args.update_language_sampling > 0 and i % args.update_language_sampling == 0 and args.data_actor != 'ave_emb' and not args.data_actor_step_update:
        #     if args.data_actor_multilin:
        #         trainer.update_language_sampler_multilin(args, epoch=epoch_itr.epoch)
        #     else:
        #         trainer.update_language_sampler(args)
        if (epoch_itr.epoch > args.select_by_dds_epoch
                and args.select_by_dds_epoch > 0):
            update_actor = False
        update_actor = False

        log_output = trainer.train_step(samples, update_actor=update_actor)
        if log_output is None:
            continue

        # update the data selector
        if (args.select_by_dds_epoch > 0 and args.update_data_selector > 0
                and i % args.update_data_selector == 0):
            trainer.update_data_selector(args)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets, generator)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    return epoch_itr
def train(args, trainer, task, epoch_itr, epoch_aux_itr, fim=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
    print(update_freq)

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # Auxiliary iterator
    aux_itr = epoch_aux_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    aux_itr = iterators.GroupedIterator(aux_itr, update_freq, bottomless=True)

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Record gradients from auxiliary data
        aux_samples = next(aux_itr)
        trainer.train_step(aux_samples, update_params=False)  # Fisher
        if hasattr(trainer.optimizer, "save_auxiliary"):
            trainer.optimizer.save_auxiliary()
        else:
            print("Warning, the optimizer is ignoring the auxiliary gradients")

        # Take a step on the primary task
        log_output = trainer.train_step(samples, apply_ewc=args.ewc > 0)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, None)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr, experiment=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch,
                                               no_progress_bar="simple")

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(",")
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss", "ntokens", "nsentences", "sample_size"]:
                continue  # these are already logged above
            if "loss" in k or k == "accuracy":
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag="train", step=stats["num_updates"])
        if experiment:
            experiment.log_metrics(stats, step=stats["num_updates"],
                                   prefix="mid_epoch_train")

        # ignore the first mini-batch in words-per-second and
        # updates-per-second calculation
        if i == 0:
            trainer.get_meter("wps").reset()
            trainer.get_meter("ups").reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag="train", step=stats["num_updates"])
    if experiment:
        experiment.log_metrics(stats, prefix="end_of_epoch_train",
                               step=stats["num_updates"])

    # reset training meters
    for k in [
            "train_loss", "train_nll_loss", "wps", "ups", "wpb", "bsz",
            "gnorm", "clip",
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr, force_refine_step=None):
    """Train the model for one epoch."""

    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(args, "progressive") and args.progressive:
        task.dataset("train").set_random_refine_step(
            args.refinetot, force_refine_step=force_refine_step)
    last_samples = None
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if samples is None or len(samples) == 0:
            sys.stderr.write("Empty sample detected\n")
            sys.stderr.flush()
            samples = last_samples
        else:
            last_samples = samples

        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets,
                                    force_refine_step=force_refine_step)
            # if distributed_utils.is_master(args):
            #     print("saving:", trainer.get_num_updates())
            #     nsml.save(str(trainer.get_num_updates()))
            if not hasattr(checkpoint_utils.save_checkpoint, 'best') or is_better(
                    valid_losses[0], checkpoint_utils.save_checkpoint.best):
                if distributed_utils.is_master(args):
                    print("saving checkpoint ...")
                    sys.stdout.flush()
                    if HAS_NSML:
                        nsml.save("best")
                    else:
                        torch.save({"model": trainer.get_model().state_dict()},
                                   "/tmp/best.pt")
                        if HAS_WANDB:
                            wandb.save("/tmp/best.pt")
                    sys.stdout.flush()
                checkpoint_utils.save_checkpoint.best = valid_losses[0]
            if (args.decoder_wise_training
                    and update_num_to_refine_step(num_updates) != force_refine_step):
                if HAS_NSML:
                    nsml.load("best")
                else:
                    # Retrieve the model
                    if distributed_utils.is_master(args):
                        state = torch.load("/tmp/best.pt", map_location="cpu")
                        trainer.model.load_state_dict(state["model"])
                    # Sync
                    assert isinstance(trainer.model, parallel.DistributedDataParallel)
                    if isinstance(trainer.model, parallel.DistributedDataParallel):
                        trainer.model._sync_params()
                checkpoint_utils.save_checkpoint.best = 0.
                force_refine_step = update_num_to_refine_step(num_updates)
                trainer.criterion.pool.clear()
                print("| Start refinement step:", force_refine_step)

        if num_updates >= max_update:
            break

        if hasattr(args, "progressive") and args.progressive:
            task.dataset("train").set_random_refine_step(
                args.refinetot, force_refine_step=force_refine_step)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr, epoch_aux_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # Auxiliary iterator
    aux_itr = epoch_aux_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    aux_itr = iterators.GroupedIterator(aux_itr, update_freq,
                                        restart_when_done=True)

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Record gradients from auxiliary data
        aux_samples = next(aux_itr)
        trainer.train_step(aux_samples, update_params=False)
        if hasattr(trainer.optimizer, "save_constraints"):
            trainer.optimizer.save_constraints()

        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr, model, experiment_path,
          total_samples=None, last_epoch_num=0, restore=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    num_heads = args.decoder_attention_heads
    head_dim = args.decoder_embed_dim // num_heads
    if experiment_path is not None:
        with open(experiment_path, 'r') as f:
            swaps = json.load(f)
        mhr(model, swaps, head_dim, num_heads, epoch_itr.epoch)

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    conf = {
        "encoder": [{"self_attn": []} for i in range(args.encoder_layers)],
        "decoder": [{"self_attn": [], "enc_attn": []}
                    for i in range(args.decoder_layers)],
    }
    attentions = {
        "decoder": [{"self_attn": []} for i in range(args.decoder_layers)]
    }
    batch_regression = 1.0 - (total_samples / (160239 * 50))
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
                "train_step-%d" % i):
            log_output = trainer.train_step(samples, batch_num=batch_regression)

        if log_output is None:  # OOM, overflow, ...
            continue

        total_samples += model.decoder.layers[0].self_attn.bsz
        # need to find more generic way to find total samples and epoch num.
        batch_regression = 1.0 - (total_samples / (160239 * 40))

        # Get Confidence for each Head.
        if args.head_confidence_method is not None:
            conf = get_batch_confs(model, conf, args)

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop, val_conf = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)
        if should_stop:
            break

    if args.head_confidence_method is not None:
        conf = convert_confs(conf, args)
        path = (args.save_dir.replace("checkpoints", "confs")
                + "-method={0}".format(args.head_confidence_method))
        try:
            os.mkdir(path, 0o775)
        except OSError:
            # directory already exists
            pass
        with open(path + "/epoch-{0}.pkl".format(epoch_itr.epoch), 'wb') as fd:
            pickle.dump(conf, fd, protocol=3)

    if args.dynamic_type is not None and args.head_confidence_method is not None:
        conf = val_conf
        restore['enc_self_attn'], last_epoch_num['enc_self_attn'] = dynamic_mhr(
            model, int(args.start_dynamic_mhr[0]), "encoder", "self_attn",
            restore['enc_self_attn'], int(args.dynamic_swap_frequency[0]),
            last_epoch_num['enc_self_attn'], epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[0]), conf[0], num_heads, head_dim,
            args.encoder_layers, local_only=False, d_type=args.dynamic_type[0],
            rest=int(args.dynamic_rest[0]),
            end_epoch=int(args.dynamic_end_epoch[0]))
        restore['dec_self_attn'], last_epoch_num['dec_self_attn'] = dynamic_mhr(
            model, int(args.start_dynamic_mhr[1]), "decoder", "self_attn",
            restore['dec_self_attn'], int(args.dynamic_swap_frequency[1]),
            last_epoch_num['dec_self_attn'], epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[1]), conf[1], num_heads, head_dim,
            args.encoder_layers, local_only=False, d_type=args.dynamic_type[1],
            rest=int(args.dynamic_rest[1]),
            end_epoch=int(args.dynamic_end_epoch[1]))
        restore['dec_enc_attn'], last_epoch_num['dec_enc_attn'] = dynamic_mhr(
            model, int(args.start_dynamic_mhr[2]), "decoder", "encoder_attn",
            restore['dec_enc_attn'], int(args.dynamic_swap_frequency[2]),
            last_epoch_num['dec_enc_attn'], epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[2]), conf[2], num_heads, head_dim,
            args.encoder_layers, local_only=False, d_type=args.dynamic_type[2],
            rest=int(args.dynamic_rest[2]),
            end_epoch=int(args.dynamic_end_epoch[2]))

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop, total_samples, restore, last_epoch_num
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (cfg.optimization.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(cfg.optimization.update_freq)
                   else cfg.optimization.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(cfg.distributed_training)
                            else None),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project
                       if distributed_utils.is_master(cfg.distributed_training)
                       else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(cfg.distributed_training)
                         else False),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
                "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            endeattn_norm = []
            selfattn_norm = []
            for m in model.modules():
                if hasattr(m, 'selfattn_norm') and m.selfattn_norm is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if hasattr(m, 'endeattn_norm') and m.endeattn_norm is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples, epoch_itr=epoch_itr)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue
            if '_cls' in k or '_reg' in k or '_num' in k or '_acc' in k:
                continue
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        # use a separate index so the batch counter `i` is not shadowed below
        for prop_i in range(args.num_props):
            loss_log_key = ('%d_cls' % prop_i if prop_i in args.cls_index
                            else '%d_reg' % prop_i)
            sample_num = log_output.get('%d_num' % prop_i, 0)
            extra_meters[loss_log_key].update(
                log_output.get(loss_log_key, 0), sample_num)
            stats[loss_log_key] = extra_meters[loss_log_key].avg
            if prop_i in args.cls_index:
                cls_acc_key = '%d_acc' % prop_i
                extra_meters[cls_acc_key].update(
                    log_output.get(cls_acc_key, 0), sample_num)
                stats[cls_acc_key] = extra_meters[cls_acc_key].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        # Write TensorBoard summaries.
        if num_updates % args.log_per_iter == 0:
            for k, v in stats.items():
                if sum([1 for x in ['loss', 'ppl', 'ac', 'reg', 'lr'] if x in k]) > 0:
                    trainer.summary_writer.scalar_summary(
                        'train/' + k, float(v), num_updates)

        if (args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()