예제 #1
0
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    """Checks which optimizer type the trainer builds for a runtime config.

    The test expects a `LossScaleOptimizer` exactly when `mixed_precision_dtype`
    is 'float16' and `loss_scale` is not None; for every other combination it
    expects a plain `tf.keras.optimizers.SGD`. It then runs one `train` call and
    checks that the returned metrics contain 'training_loss'.

    Args:
      mixed_precision_dtype: Mixed-precision dtype name forwarded to
        `cfg.RuntimeConfig` (e.g. 'float16'; the full set of values comes from
        the test parameterization, which is not visible here).
      loss_scale: Loss-scale setting forwarded to `cfg.RuntimeConfig`; `None`
        means no loss scale was requested.
    """
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
        trainer=cfg.TrainerConfig(
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    trainer = self.create_test_trainer(config)

    # Only the float16 + explicit-loss-scale combination is expected to wrap
    # the optimizer; any other dtype, or a None loss scale, keeps plain SGD.
    expects_loss_scaling = (
        mixed_precision_dtype == 'float16' and loss_scale is not None)
    if expects_loss_scaling:
      self.assertIsInstance(
          trainer.optimizer,
          tf.keras.mixed_precision.experimental.LossScaleOptimizer)
    else:
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)
예제 #2
0
 def test_export_best_ckpt(self, distribution):
   """Checks that a best-checkpoint exporter leaves `info.json` on disk.

   Builds a `trainer_lib.Trainer` wired to the exporter returned by
   `train_lib.maybe_create_best_ckpt_exporter`, runs one `train` call and one
   `evaluate` call, then asserts that `<model_dir>/best_ckpt/info.json` exists.

   NOTE(review): `distribution` is accepted but never referenced in this body,
   so every parameterization executes the same code; confirm whether a
   `with distribution.scope():` block was intended.
   """
   optimizer_settings = {
       'optimizer': {'type': 'sgd'},
       'learning_rate': {'type': 'constant'},
   }
   trainer_config = cfg.TrainerConfig(
       best_checkpoint_export_subdir='best_ckpt',
       best_checkpoint_eval_metric='acc',
       optimizer_config=cfg.OptimizationConfig(optimizer_settings))
   config = cfg.ExperimentConfig(trainer=trainer_config)

   export_root = self.get_temp_dir()
   task_under_test = mock_task.MockTask(config.task, logging_dir=export_root)
   exporter = train_lib.maybe_create_best_ckpt_exporter(config, export_root)
   trainer = trainer_lib.Trainer(
       config,
       task_under_test,
       model=task_under_test.build_model(),
       checkpoint_exporter=exporter)

   trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
   trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))

   info_path = os.path.join(export_root, 'best_ckpt', 'info.json')
   self.assertTrue(tf.io.gfile.exists(info_path))
예제 #3
0
 def setUp(self):
   """Stores an SGD / constant-learning-rate experiment config on `self._config`."""
   super().setUp()
   optimization = cfg.OptimizationConfig({
       'optimizer': {'type': 'sgd'},
       'learning_rate': {'type': 'constant'},
   })
   self._config = cfg.ExperimentConfig(
       trainer=cfg.TrainerConfig(optimizer_config=optimization))