Exemplo n.º 1
0
@dataclasses.dataclass
class ResNetModelConfig(base_configs.ModelConfig):
    """Configuration for the ResNet model.

    Attributes:
      name: The name of the model. Defaults to 'ResNet'.
      num_classes: The number of classes in the model. Defaults to 1000.
      model_params: A dictionary of parameters for the ResNet model
        ('num_classes', 'batch_size', 'use_l2_regularizer', 'rescale_inputs').
      loss: The configuration for loss. Defaults to sparse categorical cross
        entropy.
      optimizer: The configuration for optimization. Defaults to a momentum
        optimizer configuration (momentum=0.9).
      learning_rate: The configuration for learning rate. Defaults to a
        stepwise schedule (initial_lr=0.1, boundaries at 30/60/80, 5 warmup
        epochs, scaled by batch size / 256).
    """
    name: str = 'ResNet'
    num_classes: int = 1000
    # NOTE(review): model_params['num_classes'] is set independently of the
    # `num_classes` field above; keep the two in sync when overriding either.
    model_params: base_config.Config = dataclasses.field(
        # pylint: disable=g-long-lambda
        default_factory=lambda: {
            'num_classes': 1000,
            'batch_size': None,
            'use_l2_regularizer': True,
            'rescale_inputs': False,
        })
    # Nested config defaults are built with `default_factory` so that every
    # instance gets its own object. A plain instance default would be shared
    # (and mutated) across all configs, and `dataclasses` on Python 3.11+
    # rejects unhashable defaults, which eq-enabled dataclass instances are.
    loss: base_configs.LossConfig = dataclasses.field(
        default_factory=lambda: base_configs.LossConfig(
            name='sparse_categorical_crossentropy'))
    optimizer: base_configs.OptimizerConfig = dataclasses.field(
        default_factory=lambda: base_configs.OptimizerConfig(
            name='momentum',
            decay=0.9,
            epsilon=0.001,
            momentum=0.9,
            moving_average_decay=None))
    learning_rate: base_configs.LearningRateConfig = dataclasses.field(
        default_factory=lambda: base_configs.LearningRateConfig(
            name='stepwise',
            initial_lr=0.1,
            examples_per_epoch=1281167,
            boundaries=[30, 60, 80],
            warmup_epochs=5,
            scale_by_batch_size=1. / 256.,
            multipliers=[0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]))
    # pylint: enable=g-long-lambda
Exemplo n.º 2
0
@dataclasses.dataclass
class EfficientNetModelConfig(base_configs.ModelConfig):
    """Configuration for the EfficientNet model.

    This configuration will default to settings used for training
    efficientnet-b0 on a v3-8 TPU on ImageNet.

    Attributes:
      name: The name of the model. Defaults to 'EfficientNet'.
      num_classes: The number of classes in the model. Defaults to 1000.
      model_params: A dictionary that represents the parameters of the
        EfficientNet model. These will be passed in to the "from_name"
        function.
      loss: The configuration for loss. Defaults to a categorical cross
        entropy implementation with label smoothing of 0.1.
      optimizer: The configuration for optimizations. Defaults to an RMSProp
        configuration.
      learning_rate: The configuration for learning rate. Defaults to an
        exponential configuration.
    """
    name: str = 'EfficientNet'
    num_classes: int = 1000
    # NOTE(review): model_params['overrides']['num_classes'] is set
    # independently of the `num_classes` field above; keep them in sync.
    model_params: base_config.Config = dataclasses.field(
        # pylint: disable=g-long-lambda
        default_factory=lambda: {
            'model_name': 'efficientnet-b0',
            'model_weights_path': '',
            'weights_format': 'saved_model',
            'overrides': {
                'batch_norm': 'default',
                'rescale_input': True,
                'num_classes': 1000,
                'activation': 'swish',
                'dtype': 'float32',
            }
        })
    # Nested config defaults are built with `default_factory` so that every
    # instance gets its own object. A plain instance default would be shared
    # (and mutated) across all configs, and `dataclasses` on Python 3.11+
    # rejects unhashable defaults, which eq-enabled dataclass instances are.
    loss: base_configs.LossConfig = dataclasses.field(
        default_factory=lambda: base_configs.LossConfig(
            name='categorical_crossentropy', label_smoothing=0.1))
    optimizer: base_configs.OptimizerConfig = dataclasses.field(
        default_factory=lambda: base_configs.OptimizerConfig(
            name='rmsprop',
            decay=0.9,
            epsilon=0.001,
            momentum=0.9,
            moving_average_decay=None))
    learning_rate: base_configs.LearningRateConfig = dataclasses.field(
        default_factory=lambda: base_configs.LearningRateConfig(
            name='exponential',
            initial_lr=0.008,
            decay_epochs=2.4,
            decay_rate=0.97,
            warmup_epochs=5,
            scale_by_batch_size=1. / 128.,
            staircase=True))
    # pylint: enable=g-long-lambda