def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    """Run validation and/or write a checkpoint if one is due, and report
    whether training should stop.

    A stop is requested when the trainer's update count has reached
    ``cfg.optimization.max_update`` (a falsy value means "no limit"), when
    the trainer's cumulative training time (divided by 3600) exceeds a
    positive ``cfg.optimization.stop_time_hours``, or when
    ``should_stop_early`` says so for the first validation loss.

    Args:
        cfg: configuration; the ``optimization``, ``checkpoint`` and
            ``dataset`` groups are read here.
        trainer: supplies the update count and cumulative training time;
            forwarded to ``validate`` and ``save_checkpoint``.
        task: forwarded to ``validate``.
        epoch_itr: epoch iterator; its ``epoch`` attribute is compared
            against the epoch-based intervals.
        valid_subsets: forwarded to ``validate``.
        end_of_epoch: whether this call happens at the end of an epoch.

    Returns:
        ``(valid_losses, should_stop)``. ``valid_losses`` is ``[None]`` when
        validation was not run.
    """
    num_updates = trainer.get_num_updates()
    max_update = cfg.optimization.max_update or math.inf

    def on_update_interval(interval):
        # True on every ``interval``-th update; a non-positive interval
        # disables the check (and guards the modulo).
        return interval > 0 and num_updates > 0 and num_updates % interval == 0

    # Stopping conditions known before validation; a loss-based one is
    # folded in further down.
    hit_update_limit = num_updates >= max_update
    if hit_update_limit:
        logger.info(
            f"Stopping training due to "
            f"num_updates: {num_updates} >= max_update: {max_update}"
        )

    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
    hit_time_limit = (
        cfg.optimization.stop_time_hours > 0
        and training_time_hours > cfg.optimization.stop_time_hours
    )
    if hit_time_limit:
        logger.info(
            f"Stopping training due to "
            f"cumulative_training_time: {training_time_hours} > "
            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
        )

    should_stop = True if (hit_update_limit or hit_time_limit) else False

    # Should a checkpoint be written?
    if end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0:
        do_save = True
    elif should_stop:
        do_save = True
    else:
        do_save = (
            on_update_interval(cfg.checkpoint.save_interval_updates)
            and num_updates >= cfg.dataset.validate_after_updates
        )

    # Should validation run?
    if end_of_epoch:
        validation_due = epoch_itr.epoch % cfg.dataset.validate_interval == 0
    else:
        # validate during mid-epoch saves
        validation_due = do_save
    validation_due = (
        validation_due
        or should_stop
        or on_update_interval(cfg.dataset.validate_interval_updates)
    )
    do_validate = validation_due and not cfg.dataset.disable_validation

    # When validating while keeping the N > 0 best checkpoints, a save is
    # needed regardless of the save intervals.
    if do_validate and not do_save:
        do_save = cfg.checkpoint.keep_best_checkpoints > 0

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

    # Called even when validation was skipped (the loss is then None).
    should_stop |= should_stop_early(cfg, valid_losses[0])

    # Save checkpoint
    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(
            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
        )

    return valid_losses, should_stop
def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    """Run validation and/or write a checkpoint if one is due, and report
    whether training should stop.

    Args:
        cfg: configuration; the ``optimization``, ``checkpoint`` and
            ``dataset`` groups are read here.
        trainer: supplies the update count and (lazily) the cumulative
            training time; forwarded to ``validate`` and ``save_checkpoint``.
        task: forwarded to ``validate``.
        epoch_itr: epoch iterator; its ``epoch`` attribute is compared
            against the epoch-based intervals.
        valid_subsets: forwarded to ``validate``.
        end_of_epoch: whether this call happens at the end of an epoch.

    Returns:
        ``(valid_losses, should_stop)``. ``valid_losses`` is ``[None]`` when
        validation was not run. ``should_stop`` is truthy when
        ``should_stop_early`` says so for the first validation loss, when
        the update count has reached ``cfg.optimization.max_update`` (a
        falsy value means "no limit"), or when the cumulative training time
        divided by 3600 exceeds a positive
        ``cfg.optimization.stop_time_hours``.
    """
    updates_done = trainer.get_num_updates()
    update_limit = cfg.optimization.max_update or math.inf
    hit_update_limit = updates_done >= update_limit

    def on_update_interval(interval):
        # True on every ``interval``-th update; a non-positive interval
        # disables the check (and guards the modulo).
        return (
            interval > 0 and updates_done > 0 and updates_done % interval == 0
        )

    # Should a checkpoint be written?
    if end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0:
        do_save = True
    elif hit_update_limit:
        do_save = True
    else:
        do_save = (
            on_update_interval(cfg.checkpoint.save_interval_updates)
            and updates_done >= cfg.dataset.validate_after_updates
        )

    # Should validation run?
    if end_of_epoch:
        validation_due = epoch_itr.epoch % cfg.dataset.validate_interval == 0
    else:
        # validate during mid-epoch saves
        validation_due = do_save
    validation_due = (
        validation_due
        or hit_update_limit
        or on_update_interval(cfg.dataset.validate_interval_updates)
    )
    do_validate = validation_due and not cfg.dataset.disable_validation

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

    # Stopping conditions. ``should_stop_early`` is always called (with a
    # None loss when validation was skipped); the training time is only
    # queried when the earlier conditions are false and a positive time
    # budget is configured.
    seconds_per_hour = 60 * 60
    should_stop = (
        should_stop_early(cfg, valid_losses[0])
        or hit_update_limit
        or (
            cfg.optimization.stop_time_hours > 0
            and trainer.cumulative_training_time() / seconds_per_hour
            > cfg.optimization.stop_time_hours
        )
    )

    # Save checkpoint
    if do_save or should_stop:
        logger.info("begin save checkpoint")
        checkpoint_utils.save_checkpoint(
            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
        )

    return valid_losses, should_stop