def testGetterSetterInterface(self):
  """Storing directly into the optimizer registry is visible via the getter."""

  def dummy_optimizer(x, y):
    return 3

  name = "blah"
  registry.Registries.optimizers[name] = dummy_optimizer

  via_getter = registry.optimizer(name)
  via_mapping = registry.Registries.optimizers[name]

  # Both access paths must return the very function that was stored, and
  # therefore must agree with each other.
  self.assertEqual(via_getter, dummy_optimizer)
  self.assertEqual(via_mapping, dummy_optimizer)
  self.assertEqual(via_mapping, via_getter)
def testRegistration(self):
  """Decorator registration works both bare and with an explicit name."""

  # Bare decorator: registered under a key derived from the function itself.
  @registry.register_optimizer
  def my_optimizer(learning_rate, hparams):
    return 3

  # Called decorator: registered under the explicitly supplied key.
  @registry.register_optimizer("my_other_optimizer")
  def another_optimizer(learning_rate, hparams):
    return 5

  expected_by_name = [
      ("my_optimizer", my_optimizer),
      ("my_other_optimizer", another_optimizer),
  ]
  for registered_name, expected_fn in expected_by_name:
    self.assertEqual(registry.optimizer(registered_name), expected_fn)
def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disable=super-init-not-called
  """Builds the registered optimizer, optionally wrapped for loss scaling.

  Args:
    optimizer_name: name under which the optimizer factory is registered in
      `registry`; the factory is called as `factory(lr, hparams)`.
    lr: learning rate passed to the optimizer factory.
    hparams: hyperparameters. This method reads `optimizer_adam_beta1`,
      `optimizer_adam_beta2` and `optimizer_adam_epsilon` (MLPerf logging
      only), `mixed_precision_optimizer_loss_scaler`,
      `mixed_precision_optimizer_init_loss_scale` and `optimizer_zero_grads`.
    use_tpu: not referenced by this method; kept for interface compatibility.

  Raises:
    ValueError: if mixed precision is enabled and a loss scaler other than
      "exponential" is requested.
  """
  tf.logging.info("Using optimizer %s", optimizer_name)

  # MLPerf compliance logging. The Adam hyperparameters are logged regardless
  # of which optimizer was selected.
  mlperf_log.transformer_print(key=mlperf_log.OPT_NAME,
                               value=optimizer_name,
                               hparams=hparams)
  mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                               value=hparams.optimizer_adam_beta1,
                               hparams=hparams)
  mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                               value=hparams.optimizer_adam_beta2,
                               hparams=hparams)
  mlperf_log.transformer_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
                               value=hparams.optimizer_adam_epsilon,
                               hparams=hparams)

  self._opt = registry.optimizer(optimizer_name)(lr, hparams)

  if _mixed_precision_is_enabled(hparams):
    if not hparams.mixed_precision_optimizer_loss_scaler:
      tf.logging.warning(
          "Using mixed precision without a loss scaler will "
          "likely cause numerical errors.")
    elif hparams.mixed_precision_optimizer_loss_scaler != "exponential":
      raise ValueError("Mixed precision training only supports the "
                       "exponential loss scaler")
    else:
      init_loss_scale = hparams.mixed_precision_optimizer_init_loss_scale
      # Log a single formatted message. Passing a tuple of fragments here
      # would print the tuple's repr rather than a readable sentence.
      tf.logging.info(
          "Using Exponential Update Loss Scaler with init loss scale of %s",
          init_loss_scale)
      manager = contrib.mixed_precision().ExponentialUpdateLossScaleManager(
          init_loss_scale=init_loss_scale,
          incr_every_n_steps=2000,
          decr_every_n_nan_or_inf=2,
          incr_ratio=2,
          decr_ratio=0.5)
      # Wrap the base optimizer so gradients go through the loss-scale manager.
      self._opt = contrib.mixed_precision().LossScaleOptimizer(
          self._opt, manager)

  self._zero_grads = hparams.optimizer_zero_grads
def testUnknownOptimizer(self):
  """Looking up an unregistered name raises KeyError mentioning it."""
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(KeyError, "never registered"):
    registry.optimizer("not_registered_optimizer")