def test_get_loss_scale(self, loss_scale, dtype, expected):
  """Checks `get_loss_scale` for one runtime loss-scale / dtype pairing.

  Args:
    loss_scale: Value placed in `RuntimeConfig.loss_scale`.
    dtype: Value placed in the train `DatasetConfig.dtype`.
    expected: Value `classifier_trainer.get_loss_scale` must return when
      called with `fp16_default=128`.
  """
  runtime_cfg = base_configs.RuntimeConfig(loss_scale=loss_scale)
  train_cfg = dataset_factory.DatasetConfig(dtype=dtype)
  experiment = base_configs.ExperimentConfig(
      runtime=runtime_cfg, train_dataset=train_cfg)
  self.assertEqual(
      classifier_trainer.get_loss_scale(experiment, fp16_default=128),
      expected)
 def test_serialize_config(self):
   """Tests that `serialize_config` writes `params.yaml` into the model dir.

   The temporary model directory is removed through `addCleanup`, so it is
   deleted even when serialization raises or the existence assertion fails.
   A trailing `rmtree` call would be skipped in those cases.
   """
   config = base_configs.ExperimentConfig()
   model_dir = self.get_temp_dir()
   # Register cleanup before doing any work so a failure below cannot
   # leave the directory behind.
   self.addCleanup(tf.io.gfile.rmtree, model_dir)
   classifier_trainer.serialize_config(params=config, model_dir=model_dir)
   saved_params_path = os.path.join(model_dir, 'params.yaml')
   self.assertTrue(os.path.exists(saved_params_path))
Example #3
0
 def test_get_model_size(self, model, model_name, expected):
   """Checks the image size derived from an experiment's model config.

   Args:
     model: Value for `ExperimentConfig.model_name`.
     model_name: Value stored under the 'model_name' key of
       `ModelConfig.model_params`.
     expected: Value `classifier_trainer.get_image_size_from_model` must
       return for that config.
   """
   params = {'model_name': model_name}
   experiment = base_configs.ExperimentConfig(
       model_name=model,
       model=base_configs.ModelConfig(model_params=params))
   self.assertEqual(
       classifier_trainer.get_image_size_from_model(experiment), expected)
 def test_initialize(self, dtype):
   """Smoke-tests `classifier_trainer.initialize` with a GPU-thread runtime.

   Builds a config as follows:
   - `enable_eager` and `enable_xla` are off.
   - GPU threads are enabled in 'gpu_private' mode with one thread per GPU.
   - One GPU and one private dataset thread are configured.
   - The train dataset uses the given dtype.
   - The model's `LossConfig` has `loss_scale='dynamic'`.

   The test makes no assertions; it passes as long as `initialize` does not
   raise.

   Args:
     dtype: Value placed in the train `DatasetConfig.dtype`.
   """
   runtime_cfg = base_configs.RuntimeConfig(
       enable_eager=False,
       enable_xla=False,
       gpu_threads_enabled=True,
       per_gpu_thread_count=1,
       gpu_thread_mode='gpu_private',
       num_gpus=1,
       dataset_num_private_threads=1)
   train_cfg = dataset_factory.DatasetConfig(dtype=dtype)
   model_cfg = base_configs.ModelConfig(
       loss=base_configs.LossConfig(loss_scale='dynamic'))
   classifier_trainer.initialize(
       base_configs.ExperimentConfig(
           runtime=runtime_cfg, train_dataset=train_cfg, model=model_cfg))
Example #5
0
  def test_initialize(self, dtype):
    """Smoke-tests `classifier_trainer.initialize` with a stub dataset builder.

    The builder stand-in carries only two attributes: a `dtype` attribute set
    to `dtype`, and an empty `config` object. No assertions are made; the test
    passes if `initialize` does not raise.

    Args:
      dtype: Used both for the train `DatasetConfig` and the stub builder.
    """
    runtime_cfg = base_configs.RuntimeConfig(
        run_eagerly=False,
        enable_xla=False,
        per_gpu_thread_count=1,
        gpu_thread_mode='gpu_private',
        num_gpus=1,
        dataset_num_private_threads=1)
    experiment = base_configs.ExperimentConfig(
        runtime=runtime_cfg,
        train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
        model=base_configs.ModelConfig())

    class _Stub:
      """Attribute-only placeholder object."""

    builder = _Stub()
    builder.dtype = dtype
    builder.config = _Stub()
    classifier_trainer.initialize(experiment, builder)