Example no. 1
0
class ResNetModelConfig(base_configs.ModelConfig):
    """Configuration for the ResNet model."""
    # Model identifier used by the model factory.
    name: str = 'ResNet'
    # Number of output classes (ImageNet-1k default).
    num_classes: int = 1000
    # Keyword arguments forwarded to the ResNet builder. NOTE(review): the
    # factory returns a plain dict although the annotation is
    # base_config.Config (singular — a different module from base_configs);
    # presumably the config framework coerces it. Confirm upstream.
    model_params: base_config.Config = dataclasses.field(
        default_factory=lambda: {
            'num_classes': 1000,
            'batch_size': None,
            'use_l2_regularizer': True,
            'rescale_inputs': False,
        })
    # Sparse integer labels -> sparse categorical cross-entropy.
    loss: base_configs.LossConfig = base_configs.LossConfig(
        name='sparse_categorical_crossentropy')
    # SGD with momentum. NOTE(review): decay/epsilon look unused by plain
    # momentum SGD — verify against the optimizer factory.
    optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
        name='momentum',
        decay=0.9,
        epsilon=0.001,
        momentum=0.9,
        moving_average_decay=None)
    # Stepwise schedule: base LR 0.1 scaled by batch_size/256, with a
    # 5-epoch warmup and drops at epochs 30/60/80. examples_per_epoch is
    # the ImageNet-1k train split size. NOTE(review): multipliers are
    # already divided by 256 while scale_by_batch_size is also 1/256 —
    # confirm this is not a double scaling in the schedule implementation.
    learning_rate: base_configs.LearningRateConfig = (
        base_configs.LearningRateConfig(
            name='stepwise',
            initial_lr=0.1,
            examples_per_epoch=1281167,
            boundaries=[30, 60, 80],
            warmup_epochs=5,
            scale_by_batch_size=1. / 256.,
            multipliers=[0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]))
 def test_get_loss_scale(self, loss_scale, dtype, expected):
   """Checks that get_loss_scale resolves the configured loss scale."""
   loss_config = base_configs.LossConfig(loss_scale=loss_scale)
   model_config = base_configs.ModelConfig(loss=loss_config)
   dataset_config = dataset_factory.DatasetConfig(dtype=dtype)
   experiment = base_configs.ExperimentConfig(
       model=model_config,
       train_dataset=dataset_config)
   resolved = classifier_trainer.get_loss_scale(experiment, fp16_default=128)
   self.assertEqual(resolved, expected)
Example no. 3
0
class ResNetModelConfig(base_configs.ModelConfig):
  """Configuration for the ResNet model."""
  # Model identifier used by the model factory.
  name: str = 'ResNet'
  # Number of output classes (ImageNet-1k default).
  num_classes: int = 1000
  # Keyword arguments forwarded to the ResNet builder. NOTE(review): the
  # factory returns a plain dict although the annotation is
  # base_config.Config (singular — a different module from base_configs);
  # presumably the config framework coerces it. Confirm upstream.
  model_params: base_config.Config = dataclasses.field(
      default_factory=lambda: {
          'num_classes': 1000,
          'batch_size': None,
          'use_l2_regularizer': True,
          'rescale_inputs': False,
      })
  # Sparse integer labels -> sparse categorical cross-entropy.
  loss: base_configs.LossConfig = base_configs.LossConfig(
      name='sparse_categorical_crossentropy')
  # SGD with momentum. NOTE(review): decay/epsilon look unused by plain
  # momentum SGD — verify against the optimizer factory.
  optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
      name='momentum',
      decay=0.9,
      epsilon=0.001,
      momentum=0.9,
      moving_average_decay=None)
  # Piecewise-constant schedule with linear warmup; the boundary epochs and
  # multipliers come from module-level _RESNET_LR_* constants defined
  # outside this chunk. examples_per_epoch is the ImageNet-1k train split
  # size.
  learning_rate: base_configs.LearningRateConfig = (
      base_configs.LearningRateConfig(
          name='piecewise_constant_with_warmup',
          examples_per_epoch=1281167,
          warmup_epochs=_RESNET_LR_WARMUP_EPOCHS,
          boundaries=_RESNET_LR_BOUNDARIES,
          multipliers=_RESNET_LR_MULTIPLIERS))
class EfficientNetModelConfig(base_configs.ModelConfig):
  """Configuration for the EfficientNet model.

  This configuration will default to settings used for training efficientnet-b0
  on a v3-8 TPU on ImageNet.

  Attributes:
    name: The name of the model. Defaults to 'EfficientNet'.
    num_classes: The number of classes in the model.
    model_params: A dictionary that represents the parameters of the
      EfficientNet model. These will be passed in to the "from_name" function.
    loss: The configuration for loss. Defaults to a categorical cross entropy
      implementation.
    optimizer: The configuration for optimizations. Defaults to an RMSProp
      configuration.
    learning_rate: The configuration for learning rate. Defaults to an
      exponential configuration.
  """
  name: str = 'EfficientNet'
  num_classes: int = 1000
  # 'overrides' tweaks the named variant's built-in defaults; empty
  # 'model_weights_path' means train from scratch. NOTE(review): the factory
  # returns a plain dict despite the base_config.Config annotation —
  # presumably coerced by the config framework.
  model_params: base_config.Config = dataclasses.field(
      default_factory=lambda: {
          'model_name': 'efficientnet-b0',
          'model_weights_path': '',
          'weights_format': 'saved_model',
          'overrides': {
              'batch_norm': 'default',
              'rescale_input': True,
              'num_classes': 1000,
              'activation': 'swish',
              'dtype': 'float32',
          }
      })
  # Dense (one-hot) labels with label smoothing of 0.1.
  loss: base_configs.LossConfig = base_configs.LossConfig(
      name='categorical_crossentropy', label_smoothing=0.1)
  # RMSProp with momentum; moving_average_decay disabled by default.
  optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
      name='rmsprop',
      decay=0.9,
      epsilon=0.001,
      momentum=0.9,
      moving_average_decay=None)
  # Staircase exponential decay (x0.97 every 2.4 epochs) with a 5-epoch
  # warmup; base LR 0.008 is scaled by batch_size/128.
  learning_rate: base_configs.LearningRateConfig = base_configs.LearningRateConfig(  # pylint: disable=line-too-long
      name='exponential',
      initial_lr=0.008,
      decay_epochs=2.4,
      decay_rate=0.97,
      warmup_epochs=5,
      scale_by_batch_size=1. / 128.,
      staircase=True)
 def test_initialize(self, dtype):
   """Smoke-tests classifier_trainer.initialize with a GPU runtime config."""
   runtime = base_configs.RuntimeConfig(
       enable_eager=False,
       enable_xla=False,
       gpu_threads_enabled=True,
       per_gpu_thread_count=1,
       gpu_thread_mode='gpu_private',
       num_gpus=1,
       dataset_num_private_threads=1,
   )
   experiment = base_configs.ExperimentConfig(
       runtime=runtime,
       train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
       model=base_configs.ModelConfig(
           loss=base_configs.LossConfig(loss_scale='dynamic')),
   )
   classifier_trainer.initialize(experiment)
  def test_initialize(self, dtype):
    """Smoke-tests classifier_trainer.initialize with a fake dataset builder."""
    runtime = base_configs.RuntimeConfig(
        run_eagerly=False,
        enable_xla=False,
        gpu_threads_enabled=True,
        per_gpu_thread_count=1,
        gpu_thread_mode='gpu_private',
        num_gpus=1,
        dataset_num_private_threads=1,
    )
    experiment = base_configs.ExperimentConfig(
        runtime=runtime,
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
        model=base_configs.ModelConfig(
            loss=base_configs.LossConfig(loss_scale='dynamic')),
    )

    # Minimal stand-in for a dataset builder: only the attributes this test
    # sets below are read by initialize().
    class _Namespace:
      pass

    builder = _Namespace()
    builder.dtype = dtype
    builder.config = _Namespace()
    builder.config.data_format = None
    classifier_trainer.initialize(experiment, builder)