Beispiel #1
0
    def test_learning_rate_without_decay_or_warmups(self):
        """Smoke test: schedule is built with decay/warmup epochs unset.

        Uses the 'exponential' schedule with ``decay_epochs`` and
        ``warmup_epochs`` both ``None`` and without passing ``train_epochs``,
        and only checks the type of the object the factory returns.
        """
        params = base_configs.LearningRateConfig(
            name='exponential',
            initial_lr=0.01,
            decay_rate=0.01,
            decay_epochs=None,
            warmup_epochs=None,
            scale_by_batch_size=0.01,
            examples_per_epoch=1,
            boundaries=[0],
            multipliers=[0, 1],
        )
        batch_size = 1
        train_steps = 1

        lr = optimizer_factory.build_learning_rate(
            params=params, batch_size=batch_size, train_steps=train_steps)

        # assertIsInstance reports the offending type on failure, whereas
        # assertTrue(issubclass(type(lr), ...)) only says "False is not true".
        self.assertIsInstance(
            lr, tf.keras.optimizers.schedules.LearningRateSchedule)
Beispiel #2
0
    def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
        """Basic smoke test for syntax.

        Builds the schedule named by ``lr_decay_type`` with both
        ``decay_epochs`` and ``warmup_epochs`` set, and only checks the type
        of the object the factory returns.

        Args:
          lr_decay_type: Value passed as the ``name`` of the
            ``LearningRateConfig``.
        """
        params = base_configs.LearningRateConfig(
            name=lr_decay_type,
            initial_lr=0.01,
            decay_rate=0.01,
            decay_epochs=1,
            warmup_epochs=1,
            scale_by_batch_size=0.01,
            examples_per_epoch=1,
            boundaries=[0],
            multipliers=[0, 1],
        )
        batch_size = 1
        train_epochs = 1
        train_steps = 1

        lr = optimizer_factory.build_learning_rate(
            params=params,
            batch_size=batch_size,
            train_epochs=train_epochs,
            train_steps=train_steps)

        # assertIsInstance reports the offending type on failure, whereas
        # assertTrue(issubclass(type(lr), ...)) only says "False is not true".
        self.assertIsInstance(
            lr, tf.keras.optimizers.schedules.LearningRateSchedule)
def train_and_eval(
        params: base_configs.ExperimentConfig,
        strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
    """Runs the train and eval path using compile/fit.

    Builds the datasets, model, learning rate and optimizer described by
    `params`, compiles the model, trains it with `model.fit` and, unless
    `params.evaluation.skip_eval` is set, runs a final `model.evaluate` on
    the validation dataset.

    Args:
      params: Experiment configuration. The `runtime`, `model`, `train`,
        `evaluation` and `model_dir` fields are read here.
      strategy_override: Distribution strategy to use. When falsy (e.g.
        `None`), a strategy is created from `params.runtime` instead.

    Returns:
      The result of `common.build_stats` called with the fit history, the
      final evaluation output (`None` when evaluation is skipped) and the
      callbacks.

    Raises:
      ValueError: If evaluation is enabled but no validation dataset builder
        is available.
    """
    logging.info('Running train and eval.')

    distribute_utils.configure_cluster(params.runtime.worker_hosts,
                                       params.runtime.task_index)

    # Note: for TPUs, strategy and scope should be created before the dataset
    strategy = strategy_override or distribute_utils.get_distribution_strategy(
        distribution_strategy=params.runtime.distribution_strategy,
        all_reduce_alg=params.runtime.all_reduce_alg,
        num_gpus=params.runtime.num_gpus,
        tpu_address=params.runtime.tpu)

    strategy_scope = distribute_utils.get_strategy_scope(strategy)

    logging.info('Detected %d devices.',
                 strategy.num_replicas_in_sync if strategy else 1)

    label_smoothing = params.model.loss.label_smoothing
    # `x and x > 0` yields None/0 for unset smoothing; normalize to a real
    # bool so downstream code always receives True/False.
    one_hot = bool(label_smoothing and label_smoothing > 0)

    builders = _get_dataset_builders(params, strategy, one_hot)
    # A builder may be missing (None); its dataset is then None as well.
    datasets = [
        builder.build(strategy) if builder else None for builder in builders
    ]

    # Unpack datasets and builders based on train/val/test splits
    train_builder, validation_builder = builders  # pylint: disable=unbalanced-tuple-unpacking
    train_dataset, validation_dataset = datasets

    train_epochs = params.train.epochs
    train_steps = params.train.steps or train_builder.num_steps

    skip_eval = params.evaluation.skip_eval
    if validation_builder is None and not skip_eval:
        # Fail before training rather than with an AttributeError on None
        # below, or only after the whole training run in the final evaluate.
        raise ValueError(
            'Evaluation is enabled (params.evaluation.skip_eval is not set) '
            'but no validation dataset builder is available.')

    # Explicit step count wins; otherwise fall back to the builder's own
    # count. The builder may legitimately be absent when eval is skipped.
    validation_steps = params.evaluation.steps
    if not validation_steps and validation_builder is not None:
        validation_steps = validation_builder.num_steps

    initialize(params, train_builder)

    logging.info('Global batch size: %d', train_builder.global_batch_size)

    with strategy_scope:
        model_params = params.model.model_params.as_dict()
        model = get_models()[params.model.name](**model_params)
        learning_rate = optimizer_factory.build_learning_rate(
            params=params.model.learning_rate,
            batch_size=train_builder.global_batch_size,
            train_epochs=train_epochs,
            train_steps=train_steps)
        optimizer = optimizer_factory.build_optimizer(
            optimizer_name=params.model.optimizer.name,
            base_learning_rate=learning_rate,
            params=params.model.optimizer.as_dict(),
            model=model)
        optimizer = performance.configure_optimizer(
            optimizer,
            use_float16=train_builder.dtype == 'float16',
            loss_scale=get_loss_scale(params))

        metrics_map = _get_metrics(one_hot)
        metrics = [metrics_map[metric] for metric in params.train.metrics]
        # With set_epoch_loop, a whole epoch runs per execution step.
        steps_per_loop = train_steps if params.train.set_epoch_loop else 1

        # Label smoothing is only passed to the categorical (one-hot) loss;
        # otherwise the sparse loss is used.
        if one_hot:
            loss_obj = tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=label_smoothing)
        else:
            loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
        model.compile(optimizer=optimizer,
                      loss=loss_obj,
                      metrics=metrics,
                      steps_per_execution=steps_per_loop)

        initial_epoch = 0
        if params.train.resume_checkpoint:
            initial_epoch = resume_from_checkpoint(model=model,
                                                   model_dir=params.model_dir,
                                                   train_steps=train_steps)

        callbacks = custom_callbacks.get_callbacks(
            model_checkpoint=params.train.callbacks.
            enable_checkpoint_and_export,
            include_tensorboard=params.train.callbacks.enable_tensorboard,
            time_history=params.train.callbacks.enable_time_history,
            track_lr=params.train.tensorboard.track_lr,
            write_model_weights=params.train.tensorboard.write_model_weights,
            initial_step=initial_epoch * train_steps,
            batch_size=train_builder.global_batch_size,
            log_steps=params.train.time_history.log_steps,
            model_dir=params.model_dir,
            backup_and_restore=params.train.callbacks.enable_backup_and_restore
        )

    serialize_config(params=params, model_dir=params.model_dir)

    if skip_eval:
        validation_kwargs = {}
    else:
        validation_kwargs = {
            'validation_data': validation_dataset,
            'validation_steps': validation_steps,
            'validation_freq': params.evaluation.epochs_between_evals,
        }

    history = model.fit(train_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        initial_epoch=initial_epoch,
                        callbacks=callbacks,
                        verbose=2,
                        **validation_kwargs)

    validation_output = None
    if not skip_eval:
        validation_output = model.evaluate(validation_dataset,
                                           steps=validation_steps,
                                           verbose=2)

    # TODO(dankondratyuk): eval and save final test accuracy
    stats = common.build_stats(history, validation_output, callbacks)
    return stats