예제 #1
0
class VGGImagenetConfig(base_configs.ExperimentConfig):
    """Base configuration to train vgg-16 on ImageNet.

    Attributes:
        export: An `ExportConfig` instance.
        runtime: A `RuntimeConfig` instance.
        train_dataset: A `DatasetConfig` for the ImageNet 'train' split, with
            sparse (non one-hot) labels, mean subtraction and standardization.
        validation_dataset: A `DatasetConfig` for the ImageNet 'validation'
            split, using the same label/normalization settings as training.
        train: A `TrainConfig` instance: 90 epochs, resuming from checkpoint,
            checkpoint/export and TensorBoard callbacks enabled, tracking
            'accuracy' and 'top_5', logging timing every 100 steps.
        evaluation: An `EvalConfig` instance that evaluates after every epoch.
        model: A `ModelConfig` instance (default `VGGModelConfig`).
    """
    # NOTE(review): the defaults below are config *instances*. If this class is
    # a dataclass (no decorator is visible in this excerpt), Python 3.11+
    # rejects unhashable instance defaults; consider
    # `dataclasses.field(default_factory=...)` — confirm before changing.
    export: base_configs.ExportConfig = base_configs.ExportConfig()
    runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
    # Validation mirrors the training preprocessing flags so that both splits
    # are normalized identically.
    train_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
        split='train', one_hot=False, mean_subtract=True, standardize=True)
    validation_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
        split='validation',
        one_hot=False,
        mean_subtract=True,
        standardize=True)
    train: base_configs.TrainConfig = base_configs.TrainConfig(
        resume_checkpoint=True,
        epochs=90,
        steps=None,
        callbacks=base_configs.CallbacksConfig(
            enable_checkpoint_and_export=True, enable_tensorboard=True),
        metrics=['accuracy', 'top_5'],
        time_history=base_configs.TimeHistoryConfig(log_steps=100),
        tensorboard=base_configs.TensorBoardConfig(track_lr=True,
                                                   write_model_weights=False),
        set_epoch_loop=False)
    evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
        epochs_between_evals=1, steps=None)
    model: base_configs.ModelConfig = vgg_config.VGGModelConfig()
예제 #2
0
class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
    """Base configuration to train efficientnet-b0 on ImageNet.

    Attributes:
        export: An `ExportConfig` instance.
        runtime: A `RuntimeConfig` instance.
        train_dataset: A `DatasetConfig` for the ImageNet 'train' split
            (all other `ImageNetConfig` options left at their defaults).
        validation_dataset: A `DatasetConfig` for the ImageNet 'validation'
            split (all other `ImageNetConfig` options left at their defaults).
        train: A `TrainConfig` instance: 500 epochs, resuming from checkpoint,
            checkpoint/export and TensorBoard callbacks enabled, tracking
            'accuracy' and 'top_5', logging timing every 100 steps.
        evaluation: An `EvalConfig` instance that evaluates after every epoch.
        model: A `ModelConfig` instance (default `EfficientNetModelConfig`).
    """
    # NOTE(review): the defaults below are config *instances*. If this class is
    # a dataclass (no decorator is visible in this excerpt), Python 3.11+
    # rejects unhashable instance defaults; consider
    # `dataclasses.field(default_factory=...)` — confirm before changing.
    export: base_configs.ExportConfig = base_configs.ExportConfig()
    runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
    train_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
        split='train')
    validation_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
        split='validation')
    train: base_configs.TrainConfig = base_configs.TrainConfig(
        resume_checkpoint=True,
        epochs=500,
        steps=None,
        callbacks=base_configs.CallbacksConfig(
            enable_checkpoint_and_export=True, enable_tensorboard=True),
        metrics=['accuracy', 'top_5'],
        time_history=base_configs.TimeHistoryConfig(log_steps=100),
        tensorboard=base_configs.TensorBoardConfig(track_lr=True,
                                                   write_model_weights=False),
        set_epoch_loop=False)
    evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
        epochs_between_evals=1, steps=None)
    model: base_configs.ModelConfig = efficientnet_config.EfficientNetModelConfig()