def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
    if grad_norm is not None:
        metrics.log_speed("ups", 1., priority=100, round=2)
        metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
        if self.args['optimization']['clip_norm'] > 0:
            metrics.log_scalar(
                "clip",
                torch.where(
                    grad_norm > self.args['optimization']['clip_norm'],
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                ),
                priority=500,
                round=1,
            )

    with metrics.aggregate() as agg:
        if logging_outputs is not None:
            self.task.reduce_metrics(logging_outputs, self.get_criterion())

        # support legacy interface
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
            if key_to_delete in logging_output:
                del logging_output[key_to_delete]
        return logging_output
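# A minimal standalone sketch (not part of the trainer) of how the "clip"
# scalar above behaves: it logs 100 when the gradient norm exceeds clip_norm
# and 0 otherwise, so its smoothed average reads as the percentage of updates
# that hit the clipping threshold. The values below are made up for
# illustration.
#
#     import torch
#
#     clip_norm = 1.0
#     grad_norms = [torch.tensor(g) for g in (0.5, 1.7, 0.9, 2.3)]
#     flags = [
#         torch.where(g > clip_norm, g.new_tensor(100.), g.new_tensor(0.))
#         for g in grad_norms
#     ]
#     print(sum(f.item() for f in flags) / len(flags))  # 50.0 -> half clipped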
def aggregate_logging_outputs(self, logging_outputs, criterion):
    """[deprecated] Aggregate logging outputs from data parallel training."""
    utils.deprecation_warning(
        "The aggregate_logging_outputs API is deprecated. "
        "Please use the reduce_metrics API instead."
    )
    with metrics.aggregate() as agg:
        self.reduce_metrics(logging_outputs, criterion)
        return agg.get_smoothed_values()
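# Migration sketch for callers of the deprecated API above (hypothetical
# caller code; `task`, `criterion`, and `logging_outputs` are assumed to be
# in scope):
#
#     # old, deprecated:
#     stats = task.aggregate_logging_outputs(logging_outputs, criterion)
#
#     # new, equivalent:
#     with metrics.aggregate() as agg:
#         task.reduce_metrics(logging_outputs, criterion)
#     stats = agg.get_smoothed_values()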
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        set_seed.set_torch_seed(args['dataset']['fixed_validation_seed'])

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args['common']['tensorboard_logdir']
                if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args['checkpoint']['best_checkpoint_metric']])
    return valid_losses
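# Usage sketch (hypothetical driver code, mirroring what train() below does
# with the return value): valid_losses[i] holds the best_checkpoint_metric
# value for subsets[i], and the first subset drives checkpoint selection.
#
#     valid_subsets = args['dataset']['valid_subset'].split(',')
#     valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
#     checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])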
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args['distributed_training']['fix_batches_to_gpus'],
        # shuffle=(epoch_itr.next_epoch_idx > args['dataset']['curriculum']),
        shuffle=False,
    )
    # pick this epoch's update_freq (epochs are 1-indexed); reuse the last
    # entry once the schedule is exhausted
    update_freq = (
        args['optimization']['update_freq'][epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args['optimization']['update_freq'])
        else args['optimization']['update_freq'][-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args['common']['tensorboard_logdir']
            if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args['dataset']['valid_subset'].split(',')
    max_update = args['optimization']['max_update'] or math.inf
    num_updates = 0  # init as 0, for zero-shot learning
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args['common']['log_interval'] == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch meters; end-of-epoch stats are preserved
            # in the root 'train' aggregator
            metrics.reset_meters('train_inner')

        if (not args['dataset']['disable_validation']
                and args['checkpoint']['save_interval_updates'] > 0
                and num_updates % args['checkpoint']['save_interval_updates'] == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
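# Standalone sketch of the per-epoch update_freq lookup used above: epochs
# are 1-indexed, and once the schedule runs out the last entry is reused.
#
#     def pick_update_freq(update_freq, epoch):
#         if epoch <= len(update_freq):
#             return update_freq[epoch - 1]
#         return update_freq[-1]
#
#     assert pick_update_freq([4, 2, 1], 1) == 4
#     assert pick_update_freq([4, 2, 1], 2) == 2
#     assert pick_update_freq([4, 2, 1], 10) == 1  # clamped to the last entry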
def validate(args, trainer, task, epoch_itr, valid_subsets, dev_subsets, dev_refs):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args['dataset']['fixed_validation_seed'])

    for subset in valid_subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args['common']['tensorboard_logdir']
                if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        # calculate accuracy
        match = stats.pop('match')
        total = stats.pop('total')
        valid_acc = match / total
        progress.print(
            {
                'accuracy': f'{round(100. * valid_acc, 2)}%',
                'bleu': stats['bleu'],
                'loss': stats['loss'],
            },
            tag=subset,
            step=trainer.get_num_updates(),
        )

    # for subset in dev_subsets:
    #     hypotheses, references = {}, dev_refs
    #
    #     # Initialize data iterator
    #     itr = task.get_batch_iterator(
    #         dataset=task.dataset(subset),
    #         max_tokens=args['dataset']['max_tokens_valid'],
    #         max_sentences=args['dataset']['max_sentences_valid'],
    #         max_positions=utils.resolve_max_positions(
    #             task.max_positions(),
    #             trainer.get_model().max_positions(),
    #         ),
    #         ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
    #         required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
    #         seed=args['common']['seed'],
    #         num_shards=args['distributed_training']['distributed_world_size'],
    #         shard_id=args['distributed_training']['distributed_rank'],
    #         num_workers=args['dataset']['num_workers'],
    #     ).next_epoch_itr(shuffle=False)
    #     progress = progress_bar.progress_bar(
    #         itr,
    #         log_format=args['common']['log_format'],
    #         log_interval=args['common']['log_interval'],
    #         epoch=epoch_itr.epoch,
    #         prefix=f"valid on '{subset}' subset",
    #         tensorboard_logdir=(
    #             args['common']['tensorboard_logdir'] if distributed_utils.is_master(args) else None
    #         ),
    #         default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
    #     )
    #
    #     # create a new root metrics aggregator so validation metrics
    #     # don't pollute other aggregators (e.g., train meters)
    #     with metrics.aggregate(new_root=True) as agg:
    #         for sample in progress:
    #             with torch.no_grad():
    #                 trainer.model.eval()
    #                 trainer.criterion.eval()
    #                 sample = trainer._prepare_sample(sample)
    #                 hyps, _, _, ids = trainer.task.step_out(sample, trainer.model)
    #                 for idx, hypo in zip(ids, hyps):
    #                     hypotheses[idx] = hypo
    #
    #     from third_party.pycocoevalcap.bleu.google_bleu import compute_bleu
    #     assert set(hypotheses.keys()) == set(references.keys())
    #     bleus = [
    #         compute_bleu([references[idx]], [hypotheses[idx]], smooth=True)[0]
    #         for idx in hypotheses.keys()
    #     ]
    #     dev_bleu = round(100. * sum(bleus) / len(bleus), 2)
    #     # log validation stats
    #     stats = agg.get_smoothed_values()
    #     stats['bleu'] = dev_bleu
    #     stats = get_dev_stats(args, trainer, stats)
    #     progress.print(stats, tag=subset, step=trainer.get_num_updates())
    # return valid_acc, dev_bleu
    return valid_acc, None
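# Standalone sketch of the accuracy formatting above, with made-up numbers
# (in a real run `match` and `total` come from the smoothed meters):
#
#     match, total = 421., 1000.
#     valid_acc = match / total
#     print(f'{round(100. * valid_acc, 2)}%')  # -> 42.1%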