def from_config(cls, config, custom_objects=None):
  config = config.copy()  # Make a copy, since we mutate config.
  # Rebuild the wrapped optimizer and the loss scale from their
  # serialized (dict) forms before calling the constructor.
  config['optimizer'] = optimizers.deserialize(
      config['optimizer'], custom_objects=custom_objects)
  config['loss_scale'] = keras_loss_scale_module.deserialize(
      config['loss_scale'], custom_objects=custom_objects)
  return cls(**config)
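For context, this method is what a round trip through the public API ends up calling. A minimal sketch, assuming TF 2.4+ where tf.keras.mixed_precision.LossScaleOptimizer is available (the exact config keys differ between the experimental and non-experimental classes):

import tensorflow as tf

opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
config = opt.get_config()  # Serializes the wrapper's state to a dict.
restored = tf.keras.mixed_precision.LossScaleOptimizer.from_config(config)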
Example #2
 def test_serialization(self):
   loss_scale = loss_scale_module.DynamicLossScale(
       initial_loss_scale=1, increment_period=2, multiplier=3)
   config = loss_scale_module.serialize(loss_scale)
   loss_scale = loss_scale_module.deserialize(config)
   self.evaluate(variables.global_variables_initializer())
   self.assertEqual(self.evaluate(loss_scale()), 1)
   self.assertEqual(loss_scale.increment_period, 2)
   self.assertEqual(loss_scale.multiplier, 3)
Example #3
def from_config(cls, config, custom_objects=None):
    # 'loss_scale' may already be a live LossScale object; only
    # deserialize it when it is still in serialized (dict) form.
    if 'loss_scale' in config and isinstance(config['loss_scale'], dict):
        config = config.copy()
        config['loss_scale'] = keras_loss_scale_module.deserialize(
            config['loss_scale'], custom_objects=custom_objects)
    return cls(**config)
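To see why the guard matters, consider the two shapes 'loss_scale' can arrive in (hypothetical config dicts; the serialized keys shown follow the usual Keras class_name/config format and are an assumption):

# Serialized form: the isinstance check triggers and deserialization runs.
config = {'optimizer': opt,
          'loss_scale': {'class_name': 'FixedLossScale',
                         'config': {'loss_scale_value': 512.0}}}  # assumed keys

# Already-constructed form: the guard is skipped; the object passes through.
config = {'optimizer': opt,
          'loss_scale': loss_scale_module.FixedLossScale(512)}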
Example #4
 def test_serialization(self):
     loss_scale = loss_scale_module.get(123)
     config = loss_scale_module.serialize(loss_scale)
     loss_scale = loss_scale_module.deserialize(config)
     self.assertEqual(self.evaluate(loss_scale()), 123.)
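loss_scale_module.get is a convenience constructor; the behavior the test relies on, plus the other inputs it accepts (as I understand the deprecated tf.mixed_precision.experimental API; treat this as a sketch):

fixed = loss_scale_module.get(123)          # int/float -> FixedLossScale(123)
dynamic = loss_scale_module.get('dynamic')  # 'dynamic' -> DynamicLossScale()
same = loss_scale_module.get(fixed)         # LossScale instances pass through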
Example #5
  def __init__(self, optimizer, loss_scale):
    warn_msg_prefix = (
        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
        'instead. ')

    if isinstance(loss_scale, dict):
      loss_scale = keras_loss_scale_module.deserialize(loss_scale)

    if isinstance(loss_scale, (int, float)):
      tf_logging.warn(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(loss_scale))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=loss_scale)
    elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
      ls_val = loss_scale._loss_scale_value  # pylint: disable=protected-access
      tf_logging.warn(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(ls_val))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=ls_val)
    elif loss_scale == 'dynamic':
      tf_logging.warn(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt)')
      super(LossScaleOptimizerV1, self).__init__(optimizer)
    elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
      kwargs = {}
      extra_arguments = ''
      if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
        kwargs['initial_scale'] = loss_scale.initial_loss_scale
        extra_arguments += (', initial_scale=%s' %
                            loss_scale.initial_loss_scale)
      if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
        kwargs['dynamic_growth_steps'] = loss_scale.increment_period
        extra_arguments += (', dynamic_growth_steps=%s' %
                            loss_scale.increment_period)
      if loss_scale.multiplier != 2:
        raise ValueError('When passing a DynamicLossScale to "loss_scale", '
                         'DynamicLossScale.multiplier must be 2. Got: %s'
                         % (loss_scale,))
      tf_logging.warn(
          warn_msg_prefix +
          'Note that the non-experimental LossScaleOptimizer does not take a '
          'DynamicLossScale but instead takes the dynamic configuration '
          'directly in the constructor. For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt{})\n'.format(extra_arguments))
      super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs)
    elif isinstance(loss_scale, loss_scale_module.LossScale):
      raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
                      'DynamicLossScale is no longer supported. Got: {}'
                      .format(loss_scale))
    else:
      raise ValueError('Invalid value passed to loss_scale. loss_scale '
                       'must be the string "dynamic" (recommended), an int, '
                       'a float, a FixedLossScale, or a DynamicLossScale. Got '
                       'value: {}'.format(loss_scale))
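Read as a whole, the branches above define the migration from the deprecated experimental wrapper to the current one. A summary of the equivalences (constructor forms per the TF 2.x public API; a sketch, not exhaustive):

# Deprecated experimental call                      Current equivalent (TF 2.4+)
# LossScaleOptimizer(opt, 'dynamic')            ->  LossScaleOptimizer(opt)
# LossScaleOptimizer(opt, 512)                  ->  LossScaleOptimizer(opt, dynamic=False, initial_scale=512)
# LossScaleOptimizer(opt, FixedLossScale(512))  ->  LossScaleOptimizer(opt, dynamic=False, initial_scale=512)
# LossScaleOptimizer(opt, DynamicLossScale(
#     initial_loss_scale=s, increment_period=n)) -> LossScaleOptimizer(opt, initial_scale=s, dynamic_growth_steps=n)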