def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
    """Log gradient statistics, reduce the logging outputs and return them.

    Args:
        logging_outputs: per-step logging outputs handed to the task's
            ``reduce_metrics``; the reduction is skipped when this is None.
        sample_size: value stored under the ``"sample_size"`` key of the
            returned mapping.
        grad_norm: optional gradient norm. When given, the "ups" and "gnorm"
            metrics are logged, plus "clip" if a positive clip norm is
            configured. Presumably a torch tensor, since ``new_tensor`` is
            called on it.

    Returns:
        The smoothed values of the local aggregator, with ``"sample_size"``
        added and the "ppl", "wps", "wpb" and "bsz" keys removed.
    """
    if grad_norm is not None:
        metrics.log_speed("ups", 1., priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)

        clip_threshold = self.args['optimization']['clip_norm']
        if clip_threshold > 0:
            # 100 when the norm exceeds the threshold, 0 otherwise.
            clip_indicator = torch.where(
                grad_norm > clip_threshold,
                grad_norm.new_tensor(100),
                grad_norm.new_tensor(0),
            )
            metrics.log_scalar("clip", clip_indicator, priority=500, round=1)

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())

        # support legacy interface
        smoothed = agg.get_smoothed_values()
        smoothed["sample_size"] = sample_size
        for stale_key in ("ppl", "wps", "wpb", "bsz"):
            if stale_key in smoothed:
                del smoothed[stale_key]
        return smoothed
def aggregate_logging_outputs(self, logging_outputs, criterion):
    """[deprecated] Aggregate logging outputs from data parallel training.

    Emits a deprecation warning, runs ``reduce_metrics`` inside a fresh
    metrics aggregation context and returns that context's smoothed values.
    New code should call ``reduce_metrics`` directly.
    """
    deprecation_message = (
        "The aggregate_logging_outputs API is deprecated. "
        "Please use the reduce_metrics API instead."
    )
    utils.deprecation_warning(deprecation_message)

    with metrics.aggregate() as agg:
        self.reduce_metrics(logging_outputs, criterion)
        smoothed = agg.get_smoothed_values()
        return smoothed
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses.

    Args:
        args: nested configuration mapping; the 'dataset', 'common',
            'distributed_training' and 'checkpoint' sections are read.
        trainer: provides ``get_model``, ``valid_step`` and
            ``get_num_updates``.
        task: provides ``dataset``, ``get_batch_iterator`` and
            ``max_positions``.
        epoch_itr: only its ``epoch`` attribute is used, for the progress bar.
        subsets: iterable of validation subset names.

    Returns:
        A list with one entry per subset: the value of the configured
        ``best_checkpoint_metric`` taken from that subset's validation stats.
    """
    fixed_seed = args['dataset']['fixed_validation_seed']
    if fixed_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(fixed_seed)

    losses = []
    for subset in subsets:
        dataset_cfg = args['dataset']

        # Initialize data iterator
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=dataset_cfg['max_tokens_valid'],
            max_sentences=dataset_cfg['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=dataset_cfg['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=dataset_cfg['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=dataset_cfg['num_workers'],
        )
        itr = batch_iterator.next_epoch_itr(shuffle=False)

        common_cfg = args['common']
        is_master = distributed_utils.is_master(args)
        progress = progress_bar.progress_bar(
            itr,
            log_format=common_cfg['log_format'],
            log_interval=common_cfg['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            # the tensorboard directory is only passed on the master process
            tensorboard_logdir=(
                common_cfg['tensorboard_logdir'] if is_master else None
            ),
            default_log_format=(
                'simple' if common_cfg['no_progress_bar'] else 'tqdm'
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        metric_name = args['checkpoint']['best_checkpoint_metric']
        losses.append(stats[metric_name])

    return losses
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: nested configuration mapping; the 'distributed_training',
            'dataset', 'optimization', 'common' and 'checkpoint' sections
            are read.
        trainer: provides ``train_step``, ``get_num_updates`` and
            ``get_model``.
        task: task object; ``begin_epoch`` is called once before iterating.
        epoch_itr: epoch iterator, advanced with ``next_epoch_itr``; its
            ``epoch`` and ``next_epoch_idx`` attributes are read.

    Side effects:
        Logs 'train_inner' stats every ``log_interval`` updates, optionally
        validates and saves a checkpoint every ``save_interval_updates``
        updates, and logs and resets the 'train' meters at the end.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args['distributed_training']['fix_batches_to_gpus'],
        shuffle=(epoch_itr.next_epoch_idx > args['dataset']['curriculum']),
    )

    # update_freq is a per-epoch list; epochs past its end reuse the last entry
    update_freq = (
        args['optimization']['update_freq'][epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args['optimization']['update_freq'])
        else args['optimization']['update_freq'][-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args['common']['tensorboard_logdir']
            if distributed_utils.is_master(args) else None
        ),
        default_log_format=(
            'tqdm' if not args['common']['no_progress_bar'] else 'simple'
        ),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args['dataset']['valid_subset'].split(',')
    max_update = args['optimization']['max_update'] or math.inf

    # Bind the update counter before the loop: if the iterator is empty or
    # every step is skipped, the end-of-epoch logging below would otherwise
    # hit an unbound local.
    num_updates = trainer.get_num_updates()

    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args['common']['log_interval'] == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset the mid-epoch ('train_inner') meters after each log
            metrics.reset_meters('train_inner')

        if (
            not args['dataset']['disable_validation']
            and args['checkpoint']['save_interval_updates'] > 0
            and num_updates % args['checkpoint']['save_interval_updates'] == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')