Example #1
@dataclasses.dataclass
class ResNetModelConfig(base_configs.ModelConfig):
    """Configuration for the ResNet model."""
    name: str = 'ResNet'
    num_classes: int = 1000
    model_params: base_config.Config = dataclasses.field(
        default_factory=lambda: {
            'num_classes': 1000,
            'batch_size': None,
            'use_l2_regularizer': True,
            'rescale_inputs': False,
        })
    loss: base_configs.LossConfig = base_configs.LossConfig(
        name='sparse_categorical_crossentropy')
    optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
        name='momentum',
        decay=0.9,
        epsilon=0.001,
        momentum=0.9,
        moving_average_decay=None)
    learning_rate: base_configs.LearningRateConfig = (
        base_configs.LearningRateConfig(
            name='stepwise',
            initial_lr=0.1,
            examples_per_epoch=1281167,
            boundaries=[30, 60, 80],
            warmup_epochs=5,
            scale_by_batch_size=1. / 256.,
            multipliers=[0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]))
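The `scale_by_batch_size=1. / 256.` field encodes the linear learning-rate scaling rule: the base rate grows in proportion to the global batch size. A minimal sketch of that arithmetic (the `scaled_initial_lr` helper is hypothetical; the actual scaling happens inside the project's optimizer factory):

def scaled_initial_lr(initial_lr: float, batch_size: int,
                      scale_by_batch_size: float) -> float:
    """Hypothetical helper: linear learning-rate scaling.

    Mirrors the common rule lr = initial_lr * batch_size * scale_factor.
    """
    return initial_lr * batch_size * scale_by_batch_size

# With initial_lr=0.1 and scale_by_batch_size=1/256, a global batch of
# 1024 gives 0.1 * 1024 / 256 = 0.4.
assert abs(scaled_initial_lr(0.1, 1024, 1.0 / 256.0) - 0.4) < 1e-9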
Example #2
@dataclasses.dataclass
class ResNetModelConfig(base_configs.ModelConfig):
  """Configuration for the ResNet model."""
  name: str = 'ResNet'
  num_classes: int = 1000
  model_params: base_config.Config = dataclasses.field(
      default_factory=lambda: {
          'num_classes': 1000,
          'batch_size': None,
          'use_l2_regularizer': True,
          'rescale_inputs': False,
      })
  loss: base_configs.LossConfig = base_configs.LossConfig(
      name='sparse_categorical_crossentropy')
  optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
      name='momentum',
      decay=0.9,
      epsilon=0.001,
      momentum=0.9,
      moving_average_decay=None)
  learning_rate: base_configs.LearningRateConfig = (
      base_configs.LearningRateConfig(
          name='piecewise_constant_with_warmup',
          examples_per_epoch=1281167,
          warmup_epochs=_RESNET_LR_WARMUP_EPOCHS,
          boundaries=_RESNET_LR_BOUNDARIES,
          multipliers=_RESNET_LR_MULTIPLIERS))
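Example #2 replaces the inline constants with module-level `_RESNET_LR_*` values and selects a 'piecewise_constant_with_warmup' schedule. A toy sketch of how a piecewise-constant schedule maps epoch boundaries and multipliers to a per-step rate (illustrative only, not the factory's actual implementation):

def piecewise_constant_lr(step, base_lr, steps_per_epoch,
                          boundaries, multipliers):
    """Toy piecewise-constant schedule: `boundaries` are epoch indices,
    `multipliers` rescale base_lr once each boundary is passed."""
    lr = base_lr
    for boundary, mult in zip(boundaries, multipliers):
        if step >= boundary * steps_per_epoch:
            lr = base_lr * mult
    return lr

# boundaries=[30, 60, 80], multipliers=[0.1, 0.01, 0.001]:
# epoch 35 -> base_lr * 0.1, epoch 65 -> base_lr * 0.01, and so on.
print(piecewise_constant_lr(35 * 100, 0.4, 100, [30, 60, 80],
                            [0.1, 0.01, 0.001]))  # ~0.04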
Example #3
@dataclasses.dataclass
class EfficientNetModelConfig(base_configs.ModelConfig):
  """Configuration for the EfficientNet model.

  This configuration will default to settings used for training efficientnet-b0
  on a v3-8 TPU on ImageNet.

  Attributes:
    name: The name of the model. Defaults to 'EfficientNet'.
    num_classes: The number of classes in the model.
    model_params: A dictionary of parameters for the EfficientNet model,
      passed to the "from_name" function.
    loss: The loss configuration. Defaults to categorical cross entropy
      with label smoothing.
    optimizer: The optimizer configuration. Defaults to RMSProp.
    learning_rate: The learning rate configuration. Defaults to an
      exponential decay schedule.
  """
  name: str = 'EfficientNet'
  num_classes: int = 1000
  model_params: base_config.Config = dataclasses.field(
      default_factory=lambda: {
          'model_name': 'efficientnet-b0',
          'model_weights_path': '',
          'weights_format': 'saved_model',
          'overrides': {
              'batch_norm': 'default',
              'rescale_input': True,
              'num_classes': 1000,
              'activation': 'swish',
              'dtype': 'float32',
          }
      })
  loss: base_configs.LossConfig = base_configs.LossConfig(
      name='categorical_crossentropy', label_smoothing=0.1)
  optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
      name='rmsprop',
      decay=0.9,
      epsilon=0.001,
      momentum=0.9,
      moving_average_decay=None)
  learning_rate: base_configs.LearningRateConfig = base_configs.LearningRateConfig(  # pylint: disable=line-too-long
      name='exponential',
      initial_lr=0.008,
      decay_epochs=2.4,
      decay_rate=0.97,
      warmup_epochs=5,
      scale_by_batch_size=1. / 128.,
      staircase=True)
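Since these config classes are dataclasses, they can be instantiated and tweaked directly. A quick usage sketch (assuming the class is importable; field names as defined above):

config = EfficientNetModelConfig()

# model_params uses default_factory, so each instance gets its own dict
# and per-run overrides do not leak between configs.
config.model_params['model_name'] = 'efficientnet-b1'
config.model_params['overrides']['dtype'] = 'bfloat16'

print(config.name, config.num_classes,
      config.model_params['model_name'])  # EfficientNet 1000 efficientnet-b1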
Example #4
    def test_learning_rate_without_decay_or_warmups(self):
        """build_learning_rate returns a schedule even with decay and warmup disabled."""
        params = base_configs.LearningRateConfig(name='exponential',
                                                 initial_lr=0.01,
                                                 decay_rate=0.01,
                                                 decay_epochs=None,
                                                 warmup_epochs=None,
                                                 scale_by_batch_size=0.01,
                                                 examples_per_epoch=1,
                                                 boundaries=[0],
                                                 multipliers=[0, 1])
        batch_size = 1
        train_steps = 1

        lr = optimizer_factory.build_learning_rate(params=params,
                                                   batch_size=batch_size,
                                                   train_steps=train_steps)
        self.assertTrue(
            issubclass(type(lr),
                       tf.keras.optimizers.schedules.LearningRateSchedule))
Example #5
    # In the original suite this method is parameterized over lr_decay_type
    # (e.g. 'exponential' or 'piecewise_constant_with_warmup').
    def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
        """Basic smoke test for syntax."""
        params = base_configs.LearningRateConfig(name=lr_decay_type,
                                                 initial_lr=0.01,
                                                 decay_rate=0.01,
                                                 decay_epochs=1,
                                                 warmup_epochs=1,
                                                 scale_by_batch_size=0.01,
                                                 examples_per_epoch=1,
                                                 boundaries=[0],
                                                 multipliers=[0, 1])
        batch_size = 1
        train_steps = 1

        lr = optimizer_factory.build_learning_rate(params=params,
                                                   batch_size=batch_size,
                                                   train_steps=train_steps)
        self.assertTrue(
            issubclass(type(lr),
                       tf.keras.optimizers.schedules.LearningRateSchedule))
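Outside of a test, the same factory call can be used directly; the returned object is a standard `tf.keras.optimizers.schedules.LearningRateSchedule` and is callable on a step index. A minimal sketch, assuming `base_configs` and `optimizer_factory` are importable from the same package the tests exercise, using parameter values taken from the tests above:

params = base_configs.LearningRateConfig(
    name='exponential',
    initial_lr=0.01,
    decay_rate=0.01,
    decay_epochs=1,
    warmup_epochs=None,
    scale_by_batch_size=0.01,
    examples_per_epoch=1,
    boundaries=[0],
    multipliers=[0, 1])

lr = optimizer_factory.build_learning_rate(
    params=params, batch_size=1, train_steps=1)

# LearningRateSchedule instances map a step index to a learning rate.
print(float(lr(0)))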