def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    """Checks which optimizer class the trainer ends up with, then trains.

    An experiment config using an 'sgd' optimizer and a 'constant' learning
    rate is built with the given runtime precision settings. The trainer's
    optimizer is expected to be a `LossScaleOptimizer` only when the dtype is
    'float16' and a loss scale is set; in every other case it is expected to
    be plain `SGD`.

    Args:
      mixed_precision_dtype: Value forwarded to
        `cfg.RuntimeConfig(mixed_precision_dtype=...)`.
      loss_scale: Value forwarded to `cfg.RuntimeConfig(loss_scale=...)`;
        `None` means no loss-scale wrapper is expected.
    """
    runtime_config = cfg.RuntimeConfig(
        mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale)
    optimization_config = cfg.OptimizationConfig({
        'optimizer': {'type': 'sgd'},
        'learning_rate': {'type': 'constant'},
        'use_experimental_api': {'type': False},
    })
    config = cfg.ExperimentConfig(
        runtime=runtime_config,
        trainer=cfg.TrainerConfig(optimizer_config=optimization_config))
    trainer = self.create_test_trainer(config)

    # Loss scaling is only expected for float16 with an explicit loss scale.
    expects_loss_scaling = (
        mixed_precision_dtype == 'float16' and loss_scale is not None)
    if expects_loss_scaling:
        expected_cls = tf.keras.mixed_precision.LossScaleOptimizer
    else:
        expected_cls = tf.keras.optimizers.SGD
    self.assertIsInstance(trainer.optimizer, expected_cls)

    num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
    metrics = trainer.train(num_steps)
    self.assertIn('training_loss', metrics)
def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
    """Builds the optimizer for a given stage.

    Stage 0 uses the 'sgd' optimizer type; every other stage uses 'adamw'.
    Both are paired with a 'constant' learning rate.

    Args:
      stage_id: Index of the stage the optimizer is built for.

    Returns:
      The optimizer produced by `optimization.OptimizerFactory` for the
      selected optimizer type.
    """
    if stage_id == 0:
        optimizer_type = 'sgd'
    else:
        optimizer_type = 'adamw'

    params = {
        'optimizer': {'type': optimizer_type},
        'learning_rate': {'type': 'constant'},
    }
    factory = optimization.OptimizerFactory(cfg.OptimizationConfig(params))
    learning_rate = factory.build_learning_rate()
    return factory.build_optimizer(learning_rate)
def setUp(self):
    """Stores a default experiment config on `self._config`.

    The config uses an 'sgd' optimizer with a 'constant' learning rate.
    """
    super().setUp()
    optimizer_settings = {
        'optimizer': {'type': 'sgd'},
        'learning_rate': {'type': 'constant'},
    }
    trainer_config = cfg.TrainerConfig(
        optimizer_config=cfg.OptimizationConfig(optimizer_settings))
    self._config = cfg.ExperimentConfig(trainer=trainer_config)
def test_export_best_ckpt(self, distribution):
    """Checks that 'best_ckpt/info.json' exists after a train + evaluate.

    The trainer is configured with `best_checkpoint_export_subdir='best_ckpt'`
    and `best_checkpoint_eval_metric='acc'`, then `train` and `evaluate` are
    each called with an int32 tensor of value 1. Afterwards the file
    `<model_dir>/best_ckpt/info.json` must exist.

    Args:
      distribution: Test parameter; not referenced in the body of this test.
    """
    optimization_config = cfg.OptimizationConfig({
        'optimizer': {'type': 'sgd'},
        'learning_rate': {'type': 'constant'},
    })
    trainer_config = cfg.TrainerConfig(
        best_checkpoint_export_subdir='best_ckpt',
        best_checkpoint_eval_metric='acc',
        optimizer_config=optimization_config)
    config = cfg.ExperimentConfig(trainer=trainer_config)

    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))

    info_path = os.path.join(model_dir, 'best_ckpt', 'info.json')
    self.assertTrue(tf.io.gfile.exists(info_path))