def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    do_eval_tune_loss: bool,
    do_save: bool,
    do_eval_bleu: bool,
) -> Tuple[Dict[str, Any], bool, Optional[list]]:
    # Clear any remaining metrics from previous steps. This should already
    # have been done before, but just in case - to make sure we catch
    # any case where extra_state does not get populated correctly.
    extra_state = clear_per_step_extra_state(extra_state)

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.
    stop_due_to_tune_loss = False
    if do_eval_tune_loss:
        extra_state, stop_due_to_tune_loss = eval_tune_loss(
            args=args,
            trainer=trainer,
            task=task,
            subset=args.valid_subset,
            extra_state=extra_state,
        )

    # Only save checkpoints and eval tune BLEU on the master - all other
    # processes will just get the results from the master.
    master_extra_state = None
    master_stop_training = None
    translation_samples = None
    if distributed_utils.is_master(args):
        stop_due_to_tune_bleu = False
        if do_save:
            extra_state = save_checkpoint(
                trainer=trainer, args=args, extra_state=extra_state
            )
        if do_eval_bleu and not do_save:
            raise ValueError(
                "do_save should always be true when do_eval_bleu is true "
                "since a new BLEU eval can only be done when there's a new "
                "checkpoint."
            )
        if do_eval_bleu:
            extra_state, stop_due_to_tune_bleu, translation_samples = evaluate_bleu(
                args=args, task=task, extra_state=extra_state
            )
        master_extra_state = extra_state
        master_stop_training = stop_due_to_tune_loss or stop_due_to_tune_bleu

    # We don't all_gather the translation_samples since the sample sentences
    # could be pretty long, and only the master uses it anyway.
    extra_state, stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_extra_state, master_stop_training]
    )

    # Basic sanity checks that extra_state is populated correctly.
    assert not (
        do_eval_tune_loss
        and (
            extra_state["tune_eval"]["loss"] is None
            or extra_state["tune_eval"]["perplexity"] is None
        )
    )
    assert not (do_eval_bleu and extra_state["tune_bleu"]["current"] is None)

    return extra_state, stop_training, translation_samples
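
# For reference, a minimal sketch of what clear_per_step_extra_state() above
# might look like. The reset keys are inferred from the sanity asserts in
# save_and_eval(); this is an assumption, not the project's actual helper.
from typing import Any, Dict


def clear_per_step_extra_state(extra_state: Dict[str, Any]) -> Dict[str, Any]:
    # Null out per-step eval metrics so that stale values from a previous
    # step can never satisfy the asserts at the end of save_and_eval().
    extra_state["tune_eval"]["loss"] = None
    extra_state["tune_eval"]["perplexity"] = None
    extra_state["tune_bleu"]["current"] = None
    return extra_state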
def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    checkpoint_manager: Optional[checkpoint.CheckpointManager],
    end_of_epoch=False,
) -> Tuple[Dict[str, Any], bool, Optional[List]]:
    # Checks for time limit stopping criterion even when we're not doing
    # eval/saving checkpoints.
    max_update = args.max_update or math.inf
    stop_due_to_max_update = trainer.get_num_updates() > max_update
    stop_due_to_time_limit = is_training_over_time_limit(
        extra_state, args.stop_time_hr
    )
    if not end_of_epoch and (
        args.save_interval_updates <= 0
        or (extra_state["num_iterations"] % args.save_interval_updates != 0)
    ):
        return extra_state, stop_due_to_time_limit, None

    # Update training time before saving the checkpoint.
    time_now: float = time.time()
    extra_state["previous_training_time"] += time_now - extra_state["start_time"]
    extra_state["start_time"] = time_now

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.

    # Tune loss
    extra_state, stop_due_to_tune_loss = eval_tune_loss(
        args=args,
        trainer=trainer,
        task=task,
        subset=args.valid_subset,
        extra_state=extra_state,
    )

    is_master: bool = distributed_utils.is_master(args)
    if is_master:
        assert checkpoint_manager is not None, (
            f"Master worker (rank {args.distributed_rank}) should "
            f"have a checkpoint_manager defined."
        )
    else:
        assert checkpoint_manager is None, (
            f"Non-master worker (rank {args.distributed_rank}) should not "
            f"have a checkpoint_manager defined."
        )

    # Trick to prepare the task for evaluation, e.g. in a latent variable
    # model we need to set eval_key in RoundRobinZipDataset.
    if hasattr(task, "prepare_for_eval") and callable(task.prepare_for_eval):
        task.prepare_for_eval()

    # Only save checkpoints and eval tune BLEU on the master - all other
    # processes will just get the results from the master.
    translation_samples: Optional[List] = None
    if is_master:
        averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
            new_params=trainer.get_model().state_dict()
        )
        # TODO: fix after masked lm work completes
        if "save_only" not in args or not args.save_only:
            (
                extra_state,
                stop_due_to_tune_bleu,
                new_best_averaged_checkpoint,
                translation_samples,
            ) = evaluate_bleu(
                args=args,
                task=task,
                extra_state=extra_state,
                trainer=trainer,
                averaged_params=averaged_params,
            )
        else:
            new_best_averaged_checkpoint = True
            stop_due_to_tune_bleu = False
        # checkpoint_manager takes ownership of averaged_params.
        extra_state = checkpoint_manager.save(
            args=args,
            trainer=trainer,
            extra_state=extra_state,
            new_averaged_params=averaged_params,
        )
        if new_best_averaged_checkpoint:
            checkpoint_manager.save_best_averaged_checkpoint(
                args=args, trainer=trainer, extra_state=extra_state
            )

    if hasattr(task, "prepare_for_train") and callable(task.prepare_for_train):
        task.prepare_for_train()

    # extra_state["tune_bleu"] needs to be synced between master and workers
    # since we only do BLEU eval on the master, but then need that info for
    # determining when to do lr_shrink on all workers.
    master_tune_bleu = None
    master_stop_training = None
    if is_master:
        master_tune_bleu = extra_state["tune_bleu"]
        master_stop_training = (
            stop_due_to_time_limit
            or stop_due_to_tune_loss
            or stop_due_to_tune_bleu
            or stop_due_to_max_update
        )
    tune_bleu, stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_tune_bleu, master_stop_training]
    )
    extra_state["tune_bleu"] = tune_bleu

    # TODO: fix after masked lm work completes
    if "save_only" not in args or not args.save_only:
        # Basic sanity checks that extra_state is populated correctly.
        assert (
            extra_state["tune_eval"]["loss"] is not None
            and extra_state["tune_eval"]["perplexity"] is not None
            and extra_state["tune_bleu"]["current"] is not None
        )

    return extra_state, stop_training, translation_samples
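
# The tune_bleu sync above goes through
# pytorch_translate_utils.all_gather_from_master(), which broadcasts a list
# of picklable Python objects from the master rank to every worker. Below is
# a minimal sketch of that contract, assuming torch.distributed is already
# initialized and the master is rank 0; the real helper may instead pickle
# into tensors and use all_gather, but the observable behavior is the same.
from typing import Any, List

import torch.distributed as dist


def all_gather_from_master(args, data: List[Any]) -> List[Any]:
    if not dist.is_initialized():
        # Single-process training: nothing to sync.
        return data
    # Non-master ranks pass placeholders of the same length; after the
    # broadcast, every rank holds the master's values in the same order.
    object_list = data if args.distributed_rank == 0 else [None] * len(data)
    dist.broadcast_object_list(object_list, src=0)
    return object_list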
def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    checkpoint_manager: Optional[checkpoint.CheckpointManager],
    end_of_epoch=False,
) -> Tuple[Dict[str, Any], bool]:
    # Checks for time limit stopping criterion even when we're not doing
    # eval/saving checkpoints.
    max_update = args.max_update or math.inf
    stop_due_to_max_update = trainer.get_num_updates() > max_update
    stop_due_to_time_limit = is_training_over_time_limit(
        extra_state, args.stop_time_hr
    )
    if not end_of_epoch and (
        args.save_interval_updates <= 0
        or (extra_state["num_iterations"] % args.save_interval_updates != 0)
    ):
        return extra_state, stop_due_to_time_limit

    # Update training time before saving the checkpoint.
    time_now: float = time.time()
    extra_state["previous_training_time"] += time_now - extra_state["start_time"]
    extra_state["start_time"] = time_now

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.

    # Tune loss
    extra_state, stop_due_to_tune_loss = eval_tune_loss(
        args=args,
        trainer=trainer,
        task=task,
        subset=args.valid_subset,
        extra_state=extra_state,
    )

    is_master: bool = distributed_utils.is_master(args)
    if is_master:
        assert checkpoint_manager is not None, (
            f"Master worker (rank {args.distributed_rank}) should "
            f"have a checkpoint_manager defined."
        )
    else:
        assert checkpoint_manager is None, (
            f"Non-master worker (rank {args.distributed_rank}) should not "
            f"have a checkpoint_manager defined."
        )

    if is_master:
        averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
            new_params=trainer.get_model().state_dict()
        )
        new_best_averaged_checkpoint = (
            extra_state["tune_eval"]["num_since_best"] == 0
        )
        # checkpoint_manager takes ownership of averaged_params.
        extra_state = checkpoint_manager.save(
            args=args,
            trainer=trainer,
            extra_state=extra_state,
            new_averaged_params=averaged_params,
        )
        if new_best_averaged_checkpoint:
            checkpoint_manager.save_best_averaged_checkpoint(
                args=args, trainer=trainer, extra_state=extra_state
            )

    master_stop_training = None
    if is_master:
        master_stop_training = (
            stop_due_to_time_limit or stop_due_to_tune_loss or stop_due_to_max_update
        )
    stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_stop_training]
    )[0]

    # TODO: fix after masked lm work completes
    if "save_only" not in args or not args.save_only:
        # Basic sanity checks that extra_state is populated correctly.
        assert (
            extra_state["tune_eval"]["loss"] is not None
            and extra_state["tune_eval"]["perplexity"] is not None
        )

    return extra_state, stop_training
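
# All of the versions above gate on is_training_over_time_limit(). Below is
# a minimal sketch consistent with the previous_training_time / start_time
# bookkeeping in save_and_eval(); the cumulative-time logic, and treating a
# negative stop_time_hr as "no limit", are assumptions rather than the
# project's actual helper.
import time
from typing import Any, Dict


def is_training_over_time_limit(
    extra_state: Dict[str, Any], stop_time_hr: float
) -> bool:
    if stop_time_hr is None or stop_time_hr < 0:
        return False
    # Total training time = time carried over from previous runs plus the
    # time elapsed since this run (re)started.
    elapsed_hr = (
        extra_state["previous_training_time"]
        + (time.time() - extra_state["start_time"])
    ) / 3600.0
    return elapsed_hr > stop_time_hr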
def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    checkpoint_manager: Optional[checkpoint.CheckpointManager],
    end_of_epoch=False,
) -> Tuple[Dict[str, Any], bool, Optional[List]]:
    # Checks for time limit stopping criterion even when we're not doing
    # eval/saving checkpoints.
    stop_due_to_time_limit = is_training_over_time_limit(
        extra_state, args.stop_time_hr
    )
    if not end_of_epoch and (
        args.save_interval_updates <= 0
        or (extra_state["num_iterations"] % args.save_interval_updates != 0)
    ):
        return extra_state, stop_due_to_time_limit, None

    # Update training time before saving the checkpoint.
    time_now: float = time.time()
    extra_state["previous_training_time"] += time_now - extra_state["start_time"]
    extra_state["start_time"] = time_now

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.

    # Tune loss
    extra_state, stop_due_to_tune_loss = eval_tune_loss(
        args=args,
        trainer=trainer,
        task=task,
        subset=args.valid_subset,
        extra_state=extra_state,
    )

    is_master: bool = distributed_utils.is_master(args)
    if is_master:
        assert checkpoint_manager is not None, (
            f"Master worker (rank {args.distributed_rank}) should "
            f"have a checkpoint_manager defined."
        )
    else:
        assert checkpoint_manager is None, (
            f"Non-master worker (rank {args.distributed_rank}) should not "
            f"have a checkpoint_manager defined."
        )

    # Only save checkpoints and eval tune BLEU on the master - all other
    # processes will just get the results from the master.
    translation_samples: Optional[List] = None
    if is_master:
        averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
            new_params=trainer.get_model().state_dict()
        )
        (
            extra_state,
            stop_due_to_tune_bleu,
            new_best_averaged_checkpoint,
            translation_samples,
        ) = evaluate_bleu(
            args=args,
            task=task,
            extra_state=extra_state,
            trainer=trainer,
            averaged_params=averaged_params,
        )
        # checkpoint_manager takes ownership of averaged_params.
        extra_state = checkpoint_manager.save(
            args=args,
            trainer=trainer,
            extra_state=extra_state,
            new_averaged_params=averaged_params,
        )
        if new_best_averaged_checkpoint:
            checkpoint_manager.save_best_averaged_checkpoint(
                args=args, trainer=trainer, extra_state=extra_state
            )

    # We don't all_gather the translation_samples since the sample sentences
    # could be pretty long, and only the master uses it anyway.
    master_extra_state = None
    master_stop_training = None
    if is_master:
        master_extra_state = extra_state
        master_stop_training = (
            stop_due_to_time_limit or stop_due_to_tune_loss or stop_due_to_tune_bleu
        )
    extra_state, stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_extra_state, master_stop_training]
    )

    # Basic sanity checks that extra_state is populated correctly.
    assert (
        extra_state["tune_eval"]["loss"] is not None
        and extra_state["tune_eval"]["perplexity"] is not None
        and extra_state["tune_bleu"]["current"] is not None
    )

    return extra_state, stop_training, translation_samples
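
# A hypothetical call site showing how the three-value signature above is
# driven from a training loop. The driver below (train_loop, epoch_batches)
# is an illustrative sketch, not the project's actual trainer code; note
# that checkpoint_manager must be None on non-master ranks, per the asserts
# in save_and_eval().
def train_loop(args, trainer, task, extra_state, checkpoint_manager, epoch_batches):
    stop_training = False
    while not stop_training:
        for samples in epoch_batches():  # yields one epoch's worth of batches
            trainer.train_step(samples)
            extra_state["num_iterations"] += 1
            # Mid-epoch: only saves/evals every --save-interval-updates steps.
            extra_state, stop_training, _ = save_and_eval(
                args=args,
                trainer=trainer,
                task=task,
                extra_state=extra_state,
                checkpoint_manager=checkpoint_manager,
            )
            if stop_training:
                break
        if not stop_training:
            # End of epoch: always saves/evals, regardless of the interval.
            extra_state, stop_training, _ = save_and_eval(
                args=args,
                trainer=trainer,
                task=task,
                extra_state=extra_state,
                checkpoint_manager=checkpoint_manager,
                end_of_epoch=True,
            )
    return extra_state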