def test_get_loss_scale(self, loss_scale, dtype, expected):
    """Checks the loss scale that get_loss_scale derives from a config.

    Args:
      loss_scale: value stored in `RuntimeConfig.loss_scale`.
      dtype: dtype of the training dataset config.
      expected: loss scale that `classifier_trainer.get_loss_scale` should
        return for this combination (called with `fp16_default=128`).
    """
    runtime_cfg = base_configs.RuntimeConfig(loss_scale=loss_scale)
    train_cfg = dataset_factory.DatasetConfig(dtype=dtype)
    experiment = base_configs.ExperimentConfig(
        runtime=runtime_cfg,
        train_dataset=train_cfg,
    )
    self.assertEqual(
        classifier_trainer.get_loss_scale(experiment, fp16_default=128),
        expected,
    )
def test_serialize_config(self):
    """Tests functionality for serializing data.

    Serializes a default ExperimentConfig into a temporary model directory
    and checks that a `params.yaml` file was written there.
    """
    config = base_configs.ExperimentConfig()
    model_dir = self.get_temp_dir()
    # Register removal before anything can fail, so the directory is deleted
    # even when serialization raises or the assertion below fails (previously
    # rmtree only ran on the success path).
    self.addCleanup(tf.io.gfile.rmtree, model_dir)
    classifier_trainer.serialize_config(params=config, model_dir=model_dir)
    saved_params_path = os.path.join(model_dir, 'params.yaml')
    self.assertTrue(os.path.exists(saved_params_path))
def test_get_model_size(self, model, model_name, expected):
    """Checks the image size reported for a given model configuration.

    Args:
      model: value for `ExperimentConfig.model_name`.
      model_name: value placed under the 'model_name' key of
        `ModelConfig.model_params`.
      expected: image size that `get_image_size_from_model` should return.
    """
    model_cfg = base_configs.ModelConfig(
        model_params={'model_name': model_name})
    experiment = base_configs.ExperimentConfig(
        model_name=model,
        model=model_cfg,
    )
    image_size = classifier_trainer.get_image_size_from_model(experiment)
    self.assertEqual(image_size, expected)
def test_initialize(self, dtype):
    """Calls classifier_trainer.initialize with a single-GPU runtime config.

    No assertions are made; the test only checks that initialize(config)
    does not raise.
    """
    # NOTE(review): another `def test_initialize(self, dtype)` is defined
    # immediately after this one. If both live in the same class, the later
    # definition rebinds the name and this version is never collected or run.
    # It also differs from that later version: it sets `enable_eager` /
    # `gpu_threads_enabled` and `LossConfig(loss_scale='dynamic')`, and it
    # calls `initialize(config)` without a dataset-builder argument. It looks
    # like a stale copy of an older API. TODO confirm and delete, or rename
    # it once it matches the current config/initialize signatures.
    config = base_configs.ExperimentConfig(
        runtime=base_configs.RuntimeConfig(
            enable_eager=False,
            enable_xla=False,
            gpu_threads_enabled=True,
            per_gpu_thread_count=1,
            gpu_thread_mode='gpu_private',
            num_gpus=1,
            dataset_num_private_threads=1,
        ),
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
        model=base_configs.ModelConfig(
            loss=base_configs.LossConfig(loss_scale='dynamic')),
    )
    classifier_trainer.initialize(config)
def test_initialize(self, dtype):
    """Runs classifier_trainer.initialize against a stub dataset builder.

    The builder is a bare attribute holder exposing only `dtype` and an
    empty `config` object. Nothing is asserted beyond initialize not
    raising.
    """
    runtime_cfg = base_configs.RuntimeConfig(
        run_eagerly=False,
        enable_xla=False,
        per_gpu_thread_count=1,
        gpu_thread_mode='gpu_private',
        num_gpus=1,
        dataset_num_private_threads=1,
    )
    config = base_configs.ExperimentConfig(
        runtime=runtime_cfg,
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
        model=base_configs.ModelConfig(),
    )

    class _Stub:
        pass

    builder = _Stub()
    builder.config = _Stub()
    builder.dtype = dtype
    classifier_trainer.initialize(config, builder)