# Example 1
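
# Assumed imports for the examples below (module paths follow the TensorFlow
# Model Garden image-classification layout; adjust them to your checkout):
from typing import Any, Mapping, Optional

from absl import logging
import tensorflow as tf

from official.utils.misc import distribution_utils
from official.vision.image_classification import callbacks as custom_callbacks
from official.vision.image_classification import optimizer_factory
from official.vision.image_classification.configs import base_configs
from official.vision.image_classification.resnet import common

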
def resume_from_checkpoint(model: tf.keras.Model,
                           model_dir: str,
                           train_steps: int,
                           params=None,
                           flags=None) -> int:
    """Resumes from the latest checkpoint, if possible.

    Loads the model weights and optimizer settings from a checkpoint.
    This function should be used in case of preemption recovery.

    Args:
      model: The model whose weights should be restored.
      model_dir: The directory where model weights were saved.
      train_steps: The number of training steps per epoch.
      params: Optional experiment config; used to rebuild the learning rate
        schedule when `flags.freeze_lr` is set.
      flags: Optional flags object carrying `init_chkpt`, `freeze_lr`, and
        `SWITCH_FROM`. If None, only the latest checkpoint in `model_dir` is
        considered.

    Returns:
      The epoch of the latest checkpoint, or 0 if not restoring.
    """
    logging.info('Load from checkpoint is enabled.')
    latest_checkpoint = tf.train.latest_checkpoint(model_dir)
    logging.info('latest_checkpoint: %s', latest_checkpoint)
    if latest_checkpoint:
        logging.info(
            'Checkpoint file %s found and restoring from '
            'checkpoint', latest_checkpoint)
        model.load_weights(latest_checkpoint)
    elif flags and flags.init_chkpt:
        logging.info('Load init checkpoint from: %s', flags.init_chkpt)
        model.load_weights(flags.init_chkpt)
        # The init checkpoint path is expected to end in a 4-digit epoch
        # number; fast-forward the optimizer's step counter to that epoch.
        model.optimizer.iterations.assign(
            int(train_steps * int(flags.init_chkpt[-4:])))
    else:
        logging.info('No checkpoint detected.')
        return 0

    if flags and flags.freeze_lr:
        # Rebuild the learning rate schedule using the `SWITCH_FROM` batch
        # size instead of the current one.
        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=int(flags.SWITCH_FROM),
            train_steps=train_steps)
        model.optimizer.lr = learning_rate

    initial_epoch = model.optimizer.iterations // train_steps
    # Sanity check: when warm-starting from an init checkpoint, the restored
    # step counter should map back to the epoch encoded in its filename.
    if not latest_checkpoint and flags and flags.init_chkpt:
        assert initial_epoch == int(
            flags.init_chkpt[-4:]), (model.optimizer.iterations, train_steps,
                                     initial_epoch, flags.init_chkpt[-4:])
    logging.info('Completed loading from checkpoint.')
    logging.info('Resuming from epoch %d', initial_epoch)
    return int(initial_epoch)
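
# A minimal usage sketch for `resume_from_checkpoint` (the path and step
# count here are hypothetical, not part of the original code):
#
#   initial_epoch = resume_from_checkpoint(model=model,
#                                          model_dir='/tmp/model_dir',
#                                          train_steps=1000)
#   model.fit(train_dataset, initial_epoch=initial_epoch, epochs=90)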
# Example 2
    def test_learning_rate_without_decay_or_warmups(self):
        params = base_configs.LearningRateConfig(name='exponential',
                                                 initial_lr=0.01,
                                                 decay_rate=0.01,
                                                 decay_epochs=None,
                                                 warmup_epochs=None,
                                                 scale_by_batch_size=0.01,
                                                 examples_per_epoch=1,
                                                 boundaries=[0],
                                                 multipliers=[0, 1])
        batch_size = 1
        train_steps = 1

        lr = optimizer_factory.build_learning_rate(params=params,
                                                   batch_size=batch_size,
                                                   train_steps=train_steps)
        self.assertIsInstance(
            lr, tf.keras.optimizers.schedules.LearningRateSchedule)
# Example 3
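# `test_learning_rate_with_decay_and_warmup` takes an `lr_decay_type`
# argument, so it presumably runs under a parameterized test decorator that
# this excerpt omits; a minimal sketch, assuming absl's parameterized library
# and hypothetical case names:
#
#   @parameterized.named_parameters(('exponential', 'exponential'),
#                                   ('stepwise', 'stepwise'))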
    def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
        """Basic smoke test for syntax."""
        params = base_configs.LearningRateConfig(name=lr_decay_type,
                                                 initial_lr=0.01,
                                                 decay_rate=0.01,
                                                 decay_epochs=1,
                                                 warmup_epochs=1,
                                                 scale_by_batch_size=0.01,
                                                 examples_per_epoch=1,
                                                 boundaries=[0],
                                                 multipliers=[0, 1])
        batch_size = 1
        train_steps = 1

        lr = optimizer_factory.build_learning_rate(params=params,
                                                   batch_size=batch_size,
                                                   train_steps=train_steps)
        self.assertIsInstance(
            lr, tf.keras.optimizers.schedules.LearningRateSchedule)
def train_and_eval(
        params: base_configs.ExperimentConfig,
        strategy_override: Optional[tf.distribute.Strategy]
) -> Mapping[str, Any]:
    """Runs the train and eval path using compile/fit."""
    logging.info('Running train and eval.')

    # Note: for TPUs, strategy and scope should be created before the dataset
    strategy = strategy_override or distribution_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    strategy_scope = distribution_utils.get_strategy_scope(strategy)

    logging.info('Detected %d devices.',
                 strategy.num_replicas_in_sync if strategy else 1)

    label_smoothing = params.model.loss.label_smoothing
    # Label smoothing requires dense (one-hot) labels rather than sparse ones.
    one_hot = bool(label_smoothing and label_smoothing > 0)

    builders = _get_dataset_builders(params, strategy, one_hot)
    datasets = [
        builder.build(strategy) if builder else None for builder in builders
    ]

    # Unpack datasets and builders based on train/val/test splits
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    train_epochs = params.train.epochs
    train_steps = params.train.steps or train_builder.num_steps
    validation_steps = params.evaluation.steps or validation_builder.num_steps

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    with strategy_scope:
        model_params = params.model.model_params.as_dict()
        model = get_models()[params.model.name](**model_params)
        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict())

        metrics_map = _get_metrics(one_hot)
        metrics = [metrics_map[metric] for metric in params.train.metrics]

        if one_hot:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=params.model.loss.label_smoothing)
        else:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(
            optimizer=optimizer,
            loss=loss_obj,
            metrics=metrics,
            # Renamed to `steps_per_execution` in newer tf.keras releases.
            experimental_steps_per_execution=params.train.steps_per_loop)

        initial_epoch = 0
        if params.train.resume_checkpoint:
            initial_epoch = resume_from_checkpoint(model=model,
                                                   model_dir=params.model_dir,
                                                   train_steps=train_steps)

        callbacks = custom_callbacks.get_callbacks(
            model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
            include_tensorboard=params.train.callbacks.enable_tensorboard,
            time_history=params.train.callbacks.enable_time_history,
            track_lr=params.train.tensorboard.track_lr,
            write_model_weights=params.train.tensorboard.write_model_weights,
            initial_step=initial_epoch * train_steps,
            batch_size=train_builder.global_batch_size,
            log_steps=params.train.time_history.log_steps,
            model_dir=params.model_dir)

    serialize_config(params=params, model_dir=params.model_dir)

    if params.evaluation.skip_eval:
        validation_kwargs = {}
    else:
        validation_kwargs = {
            'validation_data': validation_dataset,
            'validation_steps': validation_steps,
            'validation_freq': params.evaluation.epochs_between_evals,
        }

    history = model.fit(train_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        initial_epoch=initial_epoch,
                        callbacks=callbacks,
                        verbose=2,
                        **validation_kwargs)

    validation_output = None
    if not params.evaluation.skip_eval:
        validation_output = model.evaluate(validation_dataset,
                                           steps=validation_steps,
                                           verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    stats = common.build_stats(history, validation_output, callbacks)
    return stats
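
# A minimal sketch of driving `train_and_eval` (the config values and
# model_dir are hypothetical, not part of the original code):
#
#   params = base_configs.ExperimentConfig(model_dir='/tmp/run')
#   stats = train_and_eval(params=params, strategy_override=None)
#   logging.info('Run stats:\n%s', stats)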