Ejemplo n.º 1
0
 def test_get_loss_scale(self, loss_scale, dtype, expected):
   """Checks the value returned by `classifier_trainer.get_loss_scale`.

   Args:
     loss_scale: Value placed in `RuntimeConfig.loss_scale`.
     dtype: Value placed in the training `DatasetConfig.dtype`.
     expected: Value the returned loss scale is compared against.
   """
   runtime_config = base_configs.RuntimeConfig(loss_scale=loss_scale)
   dataset_config = dataset_factory.DatasetConfig(dtype=dtype)
   experiment = base_configs.ExperimentConfig(
       runtime=runtime_config, train_dataset=dataset_config)
   # `fp16_default` is pinned to 128 for every invocation of this test.
   actual = classifier_trainer.get_loss_scale(experiment, fp16_default=128)
   self.assertEqual(actual, expected)
Ejemplo n.º 2
0
class ResNetImagenetConfig(base_configs.ExperimentConfig):
    """Base configuration to train resnet-50 on ImageNet.

    Attributes:
      export: An `ExportConfig` instance.
      runtime: A `RuntimeConfig` instance.
      train_dataset: A `DatasetConfig` instance for the 'train' split.
      validation_dataset: A `DatasetConfig` instance for the 'validation'
        split.
      train: A `TrainConfig` instance.
      evaluation: An `EvalConfig` instance.
      model: A `ModelConfig` instance.
    """
    # NOTE(review): every default below is evaluated once, when the class body
    # runs, so each one is a single object stored on the class.
    export: base_configs.ExportConfig = base_configs.ExportConfig()
    runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()

    # Both splits use identical preprocessing flags; only `split` differs.
    train_dataset: dataset_factory.DatasetConfig = (
        dataset_factory.ImageNetConfig(
            split='train',
            one_hot=False,
            mean_subtract=True,
            standardize=True,
        )
    )
    validation_dataset: dataset_factory.DatasetConfig = (
        dataset_factory.ImageNetConfig(
            split='validation',
            one_hot=False,
            mean_subtract=True,
            standardize=True,
        )
    )

    train: base_configs.TrainConfig = base_configs.TrainConfig(
        resume_checkpoint=True,
        epochs=90,
        steps=None,
        callbacks=base_configs.CallbacksConfig(
            enable_checkpoint_and_export=True,
            enable_tensorboard=True,
        ),
        metrics=['accuracy', 'top_5'],
        time_history=base_configs.TimeHistoryConfig(log_steps=100),
        tensorboard=base_configs.TensorboardConfig(
            track_lr=True,
            write_model_weights=False,
        ),
        set_epoch_loop=False,
    )
    evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
        epochs_between_evals=1,
        steps=None,
    )
    model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
Ejemplo n.º 3
0
class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
    """Base configuration to train efficientnet-b0 on ImageNet.

    Attributes:
      export: An `ExportConfig` instance.
      runtime: A `RuntimeConfig` instance.
      train_dataset: A `DatasetConfig` instance for the 'train' split.
      validation_dataset: A `DatasetConfig` instance for the 'validation'
        split.
      train: A `TrainConfig` instance.
      evaluation: An `EvalConfig` instance.
      model: A `ModelConfig` instance.
    """
    # NOTE(review): every default below is evaluated once, when the class body
    # runs, so each one is a single object stored on the class.
    export: base_configs.ExportConfig = base_configs.ExportConfig()
    runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()

    # Only `split` is set here; all other dataset options keep whatever
    # defaults `ImageNetConfig` defines.
    train_dataset: dataset_factory.DatasetConfig = (
        dataset_factory.ImageNetConfig(split='train')
    )
    validation_dataset: dataset_factory.DatasetConfig = (
        dataset_factory.ImageNetConfig(split='validation')
    )

    train: base_configs.TrainConfig = base_configs.TrainConfig(
        resume_checkpoint=True,
        epochs=500,
        steps=None,
        callbacks=base_configs.CallbacksConfig(
            enable_checkpoint_and_export=True,
            enable_tensorboard=True,
        ),
        metrics=['accuracy', 'top_5'],
        time_history=base_configs.TimeHistoryConfig(log_steps=100),
        tensorboard=base_configs.TensorboardConfig(
            track_lr=True,
            write_model_weights=False,
        ),
        set_epoch_loop=False,
    )
    evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
        epochs_between_evals=1,
        steps=None,
    )
    model: base_configs.ModelConfig = (
        efficientnet_config.EfficientNetModelConfig()
    )
Ejemplo n.º 4
0
 def test_initialize(self, dtype):
   """Smoke test: `classifier_trainer.initialize` accepts this config.

   Nothing is asserted; the test passes as long as the call does not raise.

   Args:
     dtype: Value placed in the training `DatasetConfig.dtype`.
   """
   runtime_options = dict(
       enable_eager=False,
       enable_xla=False,
       gpu_threads_enabled=True,
       per_gpu_thread_count=1,
       gpu_thread_mode='gpu_private',
       num_gpus=1,
       dataset_num_private_threads=1,
   )
   runtime_config = base_configs.RuntimeConfig(**runtime_options)
   dataset_config = dataset_factory.DatasetConfig(dtype=dtype)
   loss_config = base_configs.LossConfig(loss_scale='dynamic')
   model_config = base_configs.ModelConfig(loss=loss_config)
   experiment = base_configs.ExperimentConfig(
       runtime=runtime_config,
       train_dataset=dataset_config,
       model=model_config,
   )
   classifier_trainer.initialize(experiment)
Ejemplo n.º 5
0
  def test_initialize(self, dtype):
    """Smoke test: `classifier_trainer.initialize` accepts a stub builder.

    Nothing is asserted; the test passes as long as the call does not raise.

    Args:
      dtype: Value placed in the training `DatasetConfig.dtype` and on the
        stub dataset builder.
    """
    runtime_options = dict(
        run_eagerly=False,
        enable_xla=False,
        per_gpu_thread_count=1,
        gpu_thread_mode='gpu_private',
        num_gpus=1,
        dataset_num_private_threads=1,
    )
    runtime_config = base_configs.RuntimeConfig(**runtime_options)
    dataset_config = dataset_factory.DatasetConfig(dtype=dtype)
    model_config = base_configs.ModelConfig()
    experiment = base_configs.ExperimentConfig(
        runtime=runtime_config,
        train_dataset=dataset_config,
        model=model_config,
    )

    class EmptyClass:
      """Attribute bag used to fake a dataset builder."""

    # The stub exposes only a `dtype` attribute and an attribute-less `config`.
    stub_config = EmptyClass()
    fake_ds_builder = EmptyClass()
    fake_ds_builder.dtype = dtype
    fake_ds_builder.config = stub_config
    classifier_trainer.initialize(experiment, fake_ds_builder)