class ResNetImagenetConfig(base_configs.ExperimentConfig):
  """Experiment defaults for training ResNet-50 on ImageNet.

  Both ImageNet dataset configs are built with `one_hot=False`,
  `mean_subtract=True` and `standardize=True`. Training is configured for
  90 epochs with `epochs_between_evals=1`.

  Attributes:
    export: Default `base_configs.ExportConfig`.
    runtime: Default `base_configs.RuntimeConfig`.
    train_dataset: ImageNet config for the 'train' split.
    validation_dataset: ImageNet config for the 'validation' split.
    train: Training-loop settings (checkpoint resume, callbacks, metrics,
      time-history logging and TensorBoard options).
    evaluation: Evaluation settings.
    model: Default `resnet_config.ResNetModelConfig`.
  """
  # NOTE(review): every default below is one object created once, when the
  # class body runs. If this class is a dataclass, Python 3.11+ rejects
  # unhashable defaults and `dataclasses.field(default_factory=...)` would be
  # required. Confirm against the decorator and base class outside this view.
  export: base_configs.ExportConfig = base_configs.ExportConfig()
  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
  train_dataset: dataset_factory.DatasetConfig = (
      dataset_factory.ImageNetConfig(
          split='train',
          one_hot=False,
          mean_subtract=True,
          standardize=True,
      )
  )
  validation_dataset: dataset_factory.DatasetConfig = (
      dataset_factory.ImageNetConfig(
          split='validation',
          one_hot=False,
          mean_subtract=True,
          standardize=True,
      )
  )
  train: base_configs.TrainConfig = base_configs.TrainConfig(
      resume_checkpoint=True,
      epochs=90,
      steps=None,
      callbacks=base_configs.CallbacksConfig(
          enable_checkpoint_and_export=True,
          enable_tensorboard=True,
      ),
      metrics=['accuracy', 'top_5'],
      time_history=base_configs.TimeHistoryConfig(log_steps=100),
      tensorboard=base_configs.TensorboardConfig(
          track_lr=True,
          write_model_weights=False,
      ),
      set_epoch_loop=False,
  )
  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
      epochs_between_evals=1,
      steps=None,
  )
  model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
  """Base configuration to train efficientnet-b0 on ImageNet.

  Training is configured for 500 epochs with `epochs_between_evals=1`. The
  ImageNet dataset configs only set `split`; every other dataset option keeps
  its `dataset_factory.ImageNetConfig` default.

  Attributes:
    export: An `ExportConfig` instance.
    runtime: A `RuntimeConfig` instance.
    train_dataset: A `DatasetConfig` instance for the 'train' split.
    validation_dataset: A `DatasetConfig` instance for the 'validation'
      split.
    train: A `TrainConfig` instance.
    evaluation: An `EvalConfig` instance.
    model: A `ModelConfig` instance.
  """
  # NOTE(review): every default below is one object created once, when the
  # class body runs. If this class is a dataclass, Python 3.11+ rejects
  # unhashable defaults and `dataclasses.field(default_factory=...)` would be
  # required. Confirm against the decorator and base class outside this view.
  export: base_configs.ExportConfig = base_configs.ExportConfig()
  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
  train_dataset: dataset_factory.DatasetConfig = (
      dataset_factory.ImageNetConfig(split='train')
  )
  validation_dataset: dataset_factory.DatasetConfig = (
      dataset_factory.ImageNetConfig(split='validation')
  )
  train: base_configs.TrainConfig = base_configs.TrainConfig(
      resume_checkpoint=True,
      epochs=500,
      steps=None,
      callbacks=base_configs.CallbacksConfig(
          enable_checkpoint_and_export=True,
          enable_tensorboard=True,
      ),
      metrics=['accuracy', 'top_5'],
      time_history=base_configs.TimeHistoryConfig(log_steps=100),
      tensorboard=base_configs.TensorboardConfig(
          track_lr=True,
          write_model_weights=False,
      ),
      set_epoch_loop=False,
  )
  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
      epochs_between_evals=1,
      steps=None,
  )
  model: base_configs.ModelConfig = (
      efficientnet_config.EfficientNetModelConfig()
  )