class SqueezeNetModelConfig(base_configs.ModelConfig):
    """Configuration for the SqueezeNet model.

    Every nested config default (model params, loss, optimizer, learning rate)
    is produced by a ``default_factory``. Each instance therefore gets its own
    fresh objects instead of sharing one class-level default. This also avoids
    the ``ValueError`` that ``dataclasses`` raises on Python 3.11+ for
    unhashable (mutable) default values.
    """

    name: str = 'SqueezeNet'
    # NOTE(review): num_classes is also set inside model_params below; keep the
    # two values in sync. Confirm which one the model builder actually reads.
    num_classes: int = 1000
    # Keyword parameters forwarded to the model; built fresh per instance.
    model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {
        'num_classes': 1000,
        'batch_size': None,
    })
    # Loss configuration; defaults to sparse categorical cross-entropy.
    loss: base_configs.LossConfig = dataclasses.field(
        default_factory=lambda: base_configs.LossConfig(
            name='sparse_categorical_crossentropy'))
    # Momentum optimizer settings, with moving-average decay disabled (None).
    optimizer: base_configs.OptimizerConfig = dataclasses.field(
        default_factory=lambda: base_configs.OptimizerConfig(
            name='momentum',
            decay=0.9,
            epsilon=0.001,
            momentum=0.9,
            moving_average_decay=None))
    # Piecewise-constant schedule with warmup. The warmup length, boundaries
    # and multipliers come from module-level _LR_* constants.
    # examples_per_epoch=1281167 is presumably the ImageNet-1k training-set
    # size; confirm against the dataset config.
    learning_rate: base_configs.LearningRateConfig = dataclasses.field(
        default_factory=lambda: base_configs.LearningRateConfig(
            name='piecewise_constant_with_warmup',
            examples_per_epoch=1281167,
            warmup_epochs=_LR_WARMUP_EPOCHS,
            boundaries=_LR_BOUNDARIES,
            multipliers=_LR_MULTIPLIERS))
def test_initialize(self, dtype):
    """Smoke-test ``classifier_trainer.initialize`` for the given ``dtype``.

    Builds an experiment config with:
    - a single-GPU, non-eager, non-XLA runtime;
    - a training dataset of the requested ``dtype``;
    - a model whose loss uses ``loss_scale='dynamic'``.

    It pairs that config with a bare stand-in dataset builder exposing only
    ``dtype`` and ``config``. The test passes if ``initialize`` returns
    without raising; there are no explicit assertions.
    """
    runtime_cfg = base_configs.RuntimeConfig(
        run_eagerly=False,
        enable_xla=False,
        gpu_threads_enabled=True,
        per_gpu_thread_count=1,
        gpu_thread_mode='gpu_private',
        num_gpus=1,
        dataset_num_private_threads=1,
    )
    train_cfg = dataset_factory.DatasetConfig(dtype=dtype)
    model_cfg = base_configs.ModelConfig(
        loss=base_configs.LossConfig(loss_scale='dynamic'))
    experiment_cfg = base_configs.ExperimentConfig(
        runtime=runtime_cfg,
        train_dataset=train_cfg,
        model=model_cfg,
    )

    # Minimal attribute bag standing in for a real dataset builder; only the
    # attributes set below are provided.
    class EmptyClass:
        pass

    stub_builder = EmptyClass()
    stub_builder.dtype = dtype
    stub_builder.config = EmptyClass()

    classifier_trainer.initialize(experiment_cfg, stub_builder)