示例#1
0
def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    """Decide whether to validate and/or checkpoint at this point in training.

    Returns a tuple ``(valid_losses, should_stop)`` where ``valid_losses[0]``
    is ``None`` when validation was skipped, and ``should_stop`` signals the
    caller to terminate training.
    """
    num_updates = trainer.get_num_updates()
    max_update = cfg.optimization.max_update or math.inf

    # Stopping conditions (and an additional one based on validation loss later
    # on)
    should_stop = False
    if num_updates >= max_update:
        should_stop = True
        logger.info(f"Stopping training due to "
                    f"num_updates: {num_updates} >= max_update: {max_update}")

    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
    time_limit = cfg.optimization.stop_time_hours
    if time_limit > 0 and training_time_hours > time_limit:
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"cumulative_training_time: {training_time_hours} > "
            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)")

    # Checkpoint either at epoch boundaries, on a stop signal, or every
    # save_interval_updates steps (once past the validation warm-up).
    epoch_save_due = (end_of_epoch
                      and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
    update_save_due = (cfg.checkpoint.save_interval_updates > 0
                       and num_updates > 0
                       and num_updates % cfg.checkpoint.save_interval_updates == 0
                       and num_updates >= cfg.dataset.validate_after_updates)
    do_save = epoch_save_due or should_stop or update_save_due

    # Validate during mid-epoch saves, at epoch boundaries, on stop, or every
    # validate_interval_updates steps — unless validation is disabled entirely.
    epoch_validate_due = (end_of_epoch
                          and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
    update_validate_due = (cfg.dataset.validate_interval_updates > 0
                           and num_updates > 0
                           and num_updates % cfg.dataset.validate_interval_updates == 0)
    do_validate = (not cfg.dataset.disable_validation
                   and ((not end_of_epoch and do_save)
                        or epoch_validate_due
                        or should_stop
                        or update_validate_due))

    # if there is a need to validate and we should keep the N>0 best checkpoints then "do_save" should be "on" anyway
    if do_validate and cfg.checkpoint.keep_best_checkpoints > 0:
        do_save = True

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

    # `|` (not `or`): should_stop_early must run unconditionally — it tracks
    # patience state across calls.
    should_stop = should_stop | should_stop_early(cfg, valid_losses[0])

    # Save checkpoint
    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr,
                                         valid_losses[0])

    return valid_losses, should_stop
示例#2
0
def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    """Run validation and/or save a checkpoint according to the schedule in
    ``cfg``, and report whether training should stop.

    Returns ``(valid_losses, should_stop)``; ``valid_losses[0]`` is ``None``
    when validation was skipped this round.
    """
    num_updates = trainer.get_num_updates()
    max_update = cfg.optimization.max_update or math.inf

    # Save at epoch boundaries, when max_update is reached, or every
    # save_interval_updates steps once past the validation warm-up.
    epoch_save_due = (
        end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0
    )
    update_save_due = (
        cfg.checkpoint.save_interval_updates > 0
        and num_updates > 0
        and num_updates % cfg.checkpoint.save_interval_updates == 0
        and num_updates >= cfg.dataset.validate_after_updates
    )
    do_save = epoch_save_due or num_updates >= max_update or update_save_due

    # Validate during mid-epoch saves, at epoch boundaries, at max_update, or
    # every validate_interval_updates steps — unless validation is disabled.
    epoch_validate_due = (
        end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0
    )
    update_validate_due = (
        cfg.dataset.validate_interval_updates > 0
        and num_updates > 0
        and num_updates % cfg.dataset.validate_interval_updates == 0
    )
    do_validate = not cfg.dataset.disable_validation and (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or epoch_validate_due
        or num_updates >= max_update
        or update_validate_due
    )

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

    # Stopping conditions: patience-based early stop (must be called first —
    # it updates internal patience state), update budget, or wall-clock budget.
    stop_early = should_stop_early(cfg, valid_losses[0])
    should_stop = (
        stop_early
        or num_updates >= max_update
        or (
            cfg.optimization.stop_time_hours > 0
            and trainer.cumulative_training_time() / (60 * 60)
            > cfg.optimization.stop_time_hours
        )
    )

    # Save checkpoint
    if do_save or should_stop:
        logger.info("begin save checkpoint")
        checkpoint_utils.save_checkpoint(
            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
        )

    return valid_losses, should_stop