def from_config(cls, config, custom_objects=None):
  """Builds an instance of `cls` from a serialized config.

  Args:
    config: Dict of constructor keyword arguments whose 'optimizer' and
      'loss_scale' entries hold serialized values. The caller's dict is not
      modified; a shallow copy is edited instead.
    custom_objects: Forwarded unchanged to both deserialize calls.

  Returns:
    `cls(**kwargs)`, where kwargs is a copy of `config` with 'optimizer' and
    'loss_scale' replaced by their deserialized objects.
  """
  kwargs = config.copy()  # Edit a copy so the caller's config is untouched.
  # The optimizer is deserialized first, then the loss scale.
  for key, deserialize in (('optimizer', optimizers.deserialize),
                           ('loss_scale', keras_loss_scale_module.deserialize)):
    kwargs[key] = deserialize(kwargs[key], custom_objects=custom_objects)
  return cls(**kwargs)
def test_serialization(self):
  """A DynamicLossScale keeps its settings across serialize/deserialize."""
  original = loss_scale_module.DynamicLossScale(
      initial_loss_scale=1, increment_period=2, multiplier=3)
  restored = loss_scale_module.deserialize(
      loss_scale_module.serialize(original))
  # Variables created by the deserialized object must be initialized before
  # its current scale can be evaluated.
  self.evaluate(variables.global_variables_initializer())
  self.assertEqual(self.evaluate(restored()), 1)
  self.assertEqual(restored.increment_period, 2)
  self.assertEqual(restored.multiplier, 3)
def test_serialization(self):
  """Round-trips a DynamicLossScale through its config and checks fields."""
  source_scale = loss_scale_module.DynamicLossScale(
      initial_loss_scale=1, increment_period=2, multiplier=3)
  serialized = loss_scale_module.serialize(source_scale)
  rebuilt = loss_scale_module.deserialize(serialized)
  # Initialize variables before reading the rebuilt object's scale value.
  self.evaluate(variables.global_variables_initializer())
  current_scale = self.evaluate(rebuilt())
  self.assertEqual(current_scale, 1)
  self.assertEqual(rebuilt.increment_period, 2)
  self.assertEqual(rebuilt.multiplier, 3)
def from_config(cls, config, custom_objects=None):
  """Creates an instance from a config dict.

  Serialized (dict) values under 'optimizer' and 'loss_scale' are
  deserialized with `custom_objects` before being passed to the constructor.
  Entries that are absent or already non-dict objects are passed through
  unchanged. Previously only 'loss_scale' was deserialized, so a serialized
  optimizer reached the constructor as a raw dict.

  Args:
    config: Mapping of constructor keyword arguments. It is not modified.
    custom_objects: Forwarded unchanged to both deserialize calls.

  Returns:
    `cls(**kwargs)`, where kwargs is a copy of `config` with any dict-valued
    'optimizer' / 'loss_scale' entries replaced by deserialized objects.
  """
  config = dict(config)  # Make a copy, since we mutate config.
  if isinstance(config.get('optimizer'), dict):
    config['optimizer'] = optimizers.deserialize(
        config['optimizer'], custom_objects=custom_objects)
  if isinstance(config.get('loss_scale'), dict):
    config['loss_scale'] = keras_loss_scale_module.deserialize(
        config['loss_scale'], custom_objects=custom_objects)
  return cls(**config)
def test_serialization(self):
  """A fixed loss scale built via `get` survives serialize/deserialize."""
  fixed = loss_scale_module.get(123)
  restored = loss_scale_module.deserialize(loss_scale_module.serialize(fixed))
  self.assertEqual(self.evaluate(restored()), 123.)
def __init__(self, optimizer, loss_scale):
  """Wraps `optimizer`, translating a legacy `loss_scale` argument.

  Each accepted form of `loss_scale` is mapped onto keyword arguments of the
  parent constructor. A deprecation warning showing an equivalent call is
  logged with `tf_logging.warn` first.

  Args:
    optimizer: Forwarded unchanged to the parent constructor.
    loss_scale: One of:
      - a dict, which is first passed through
        `keras_loss_scale_module.deserialize`;
      - an int or float, used as a fixed initial scale;
      - a `FixedLossScale`;
      - the string 'dynamic', which uses the parent's defaults;
      - a `DynamicLossScale` whose `multiplier` is 2. Its
        `initial_loss_scale` and `increment_period` are forwarded only when
        they differ from `_DEFAULT_INITIAL_SCALE` / `_DEFAULT_GROWTH_STEPS`.

  Raises:
    ValueError: If a `DynamicLossScale` has a multiplier other than 2, or if
      `loss_scale` is none of the accepted forms.
    TypeError: If `loss_scale` is any other `LossScale` instance.
  """
  deprecation_notice = (
      'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
      'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
      'instead. ')

  def use_fixed_scale(value):
    # Shared by the int/float and FixedLossScale inputs: both map onto a
    # non-dynamic parent optimizer whose initial scale is `value`.
    tf_logging.warn(
        deprecation_notice + 'For example\n'
        ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer('
        'opt, dynamic=False, initial_scale={})'.format(value))
    super(LossScaleOptimizerV1, self).__init__(
        optimizer, dynamic=False, initial_scale=value)

  if isinstance(loss_scale, dict):
    loss_scale = keras_loss_scale_module.deserialize(loss_scale)

  if isinstance(loss_scale, (int, float)):
    use_fixed_scale(loss_scale)
    return

  if isinstance(loss_scale, loss_scale_module.FixedLossScale):
    use_fixed_scale(loss_scale._loss_scale_value)  # pylint: disable=protected-access
    return

  if loss_scale == 'dynamic':
    tf_logging.warn(
        deprecation_notice + 'For example\n'
        ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer('
        'opt)')
    super(LossScaleOptimizerV1, self).__init__(optimizer)
    return

  if isinstance(loss_scale, loss_scale_module.DynamicLossScale):
    # Only settings that differ from the parent's defaults are forwarded and
    # shown in the example call.
    overrides = {}
    example_args = []
    initial = loss_scale.initial_loss_scale
    if initial != _DEFAULT_INITIAL_SCALE:
      overrides['initial_scale'] = initial
      example_args.append(', initial_scale=%s' % initial)
    growth_steps = loss_scale.increment_period
    if growth_steps != _DEFAULT_GROWTH_STEPS:
      overrides['dynamic_growth_steps'] = growth_steps
      example_args.append(', dynamic_growth_steps=%s' % growth_steps)
    if loss_scale.multiplier != 2:
      raise ValueError('When passing a DynamicLossScale to "loss_scale", '
                       'DynamicLossScale.multiplier must be 2. Got: %s'
                       % (loss_scale,))
    tf_logging.warn(
        deprecation_notice +
        'Note that the non-experimental LossScaleOptimizer does not take a '
        'DynamicLossScale but instead takes the dynamic configuration '
        'directly in the constructor. For example:\n'
        ' opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer('
        'opt{})\n'.format(''.join(example_args)))
    super(LossScaleOptimizerV1, self).__init__(optimizer, **overrides)
    return

  if isinstance(loss_scale, loss_scale_module.LossScale):
    raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
                    'DynamicLossScale is no longer supported. Got: {}'
                    .format(loss_scale))

  raise ValueError('Invalid value passed to loss_scale. loss_scale '
                   'must be the string "dynamic" (recommended), an int, '
                   'a float, a FixedLossScale, or a DynamicLossScale. Got '
                   'value: {}'.format(loss_scale))
def test_serialization(self):
  """Serializing then deserializing a fixed scale preserves its value."""
  expected_value = 123.
  scale_config = loss_scale_module.serialize(loss_scale_module.get(123))
  round_tripped = loss_scale_module.deserialize(scale_config)
  self.assertEqual(self.evaluate(round_tripped()), expected_value)