def from_config(cls, config, custom_objects=None):
    """Builds an instance of `cls` from a serialized optimizer config.

    Two config layouts are accepted:

    * Current layout (TF 2.4 and above), as generated by
      `LossScaleOptimizer.get_config`: flat `dynamic`, `initial_scale`,
      `dynamic_growth_steps` and `inner_optimizer` entries. These are folded
      into a single `loss_scale` object plus an `optimizer` entry, so the
      result can be passed to `LossScaleOptimizerV1.__init__`.
    * Legacy layout (TF 2.3 and below): a serialized `loss_scale` entry and an
      `optimizer` entry. The loss scale is deserialized and rejected if it is a
      `DynamicLossScale` whose multiplier is not 2.

    The caller's `config` is not modified; a shallow copy is edited instead.

    Args:
      config: Serialized optimizer config dict.
      custom_objects: Optional mapping forwarded to `optimizers.deserialize`
        when deserializing the wrapped optimizer.

    Returns:
      `cls(**converted_config)`.

    Raises:
      ValueError: If a legacy config holds a `DynamicLossScale` whose
        multiplier is not 2.
      KeyError: If a key required by the detected layout is missing.
    """
    converted = config.copy()  # shallow copy; the caller's dict stays intact

    if 'loss_scale' not in converted:
      # Current (TF 2.4+) layout: rebuild a LossScale object from the flat
      # fields written by LossScaleOptimizer.get_config.
      if converted['dynamic']:
        rebuilt_scale = loss_scale_module.DynamicLossScale(
            converted['initial_scale'], converted['dynamic_growth_steps'],
            multiplier=2)
      else:
        rebuilt_scale = loss_scale_module.FixedLossScale(
            converted['initial_scale'])
      converted['loss_scale'] = rebuilt_scale
      # The flat fields are not accepted by LossScaleOptimizerV1.__init__.
      for flat_key in ('dynamic', 'initial_scale', 'dynamic_growth_steps'):
        del converted[flat_key]
      converted['optimizer'] = optimizers.deserialize(
          converted.pop('inner_optimizer'), custom_objects=custom_objects)
      return cls(**converted)

    # Legacy (TF 2.3 and below) layout: `loss_scale` is itself serialized.
    legacy_scale = keras_loss_scale_module.deserialize(converted['loss_scale'])
    converted['loss_scale'] = legacy_scale
    if (isinstance(legacy_scale, loss_scale_module.DynamicLossScale) and
        legacy_scale.multiplier != 2):
      raise ValueError('Cannot deserialize LossScaleOptimizer with a '
                       'DynamicLossScale whose multiplier is not 2. Got '
                       'DynamicLossScale: %s' % (legacy_scale,))
    converted['optimizer'] = optimizers.deserialize(
        converted['optimizer'], custom_objects=custom_objects)
    return cls(**converted)
 def from_config(cls, config, custom_objects=None):
   """Creates an instance of `cls` from a serialized config.

   Configs written by TF 2.3 and below carry a serialized `loss_scale` and an
   `optimizer` entry. Those are first converted into the current layout
   (`dynamic`, `initial_scale`, `dynamic_growth_steps` for dynamic scales,
   and `inner_optimizer`). The inner optimizer is then deserialized and the
   resulting dict is passed to `cls` as keyword arguments.

   Args:
     config: Serialized config dict. Not modified; a shallow copy is edited.
     custom_objects: Optional mapping forwarded to `optimizers.deserialize`
       for the inner optimizer.

   Returns:
     `cls(**converted_config)`.

   Raises:
     ValueError: If a legacy config holds a `DynamicLossScale` whose
       multiplier is not 2, or a loss scale that is neither a
       `FixedLossScale` nor a `DynamicLossScale`.
   """
   config = config.copy()  # Make a copy, since we mutate config
   if 'loss_scale' in config:
     # If loss_scale is in config, we assume we are deserializing a
     # LossScaleOptimizer from TF 2.3 or below. We convert the config so it
     # can be deserialized in the current LossScaleOptimizer.
     loss_scale = keras_loss_scale_module.deserialize(
         config.pop('loss_scale'))
     if isinstance(loss_scale, loss_scale_module.FixedLossScale):
       # `dynamic_growth_steps` is left unset here, so whatever default `cls`
       # declares for it applies.
       config['dynamic'] = False
       config['initial_scale'] = loss_scale._loss_scale_value  # pylint: disable=protected-access
     elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
       config['dynamic'] = True
       config['initial_scale'] = loss_scale.initial_loss_scale
       config['dynamic_growth_steps'] = loss_scale.increment_period
       if loss_scale.multiplier != 2:
         raise ValueError('Cannot deserialize LossScaleOptimizer with a '
                          'DynamicLossScale whose multiplier is not 2. Got '
                          'DynamicLossScale: %s' % (loss_scale,))
     else:
       # Report the offending object, like the other errors in this module,
       # so users can tell which loss scale could not be converted.
       raise ValueError(
           'Serialized LossScaleOptimizers with a LossScale that is neither a '
           'FixedLossScale nor a DynamicLossScale can no longer be '
           'deserialized. Got: %s' % (loss_scale,))
     # Legacy configs stored the wrapped optimizer under `optimizer`.
     config['inner_optimizer'] = config.pop('optimizer')
   config['inner_optimizer'] = optimizers.deserialize(
       config['inner_optimizer'], custom_objects=custom_objects)
   return cls(**config)
示例#3
0
 def from_config(cls, config, custom_objects=None):
     """Creates an instance of `cls` from a config dict.

     When `config` has a `loss_scale` entry that is a dict (a serialized loss
     scale), that entry is deserialized with
     `keras_loss_scale_module.deserialize`, honouring `custom_objects`, on a
     shallow copy of `config`. Any other config is passed to `cls` unchanged.

     Args:
       config: Mapping of constructor keyword arguments.
       custom_objects: Optional mapping forwarded to the loss-scale
         deserializer.

     Returns:
       `cls(**config)`, with `loss_scale` deserialized when it was a dict.
     """
     if 'loss_scale' not in config:
         return cls(**config)
     serialized_scale = config['loss_scale']
     if not isinstance(serialized_scale, dict):
         return cls(**config)
     # Edit a copy so the caller's config is left untouched.
     kwargs = config.copy()
     kwargs['loss_scale'] = keras_loss_scale_module.deserialize(
         serialized_scale, custom_objects=custom_objects)
     return cls(**kwargs)
  def __init__(self, optimizer, loss_scale):
    """Initializes the deprecated experimental LossScaleOptimizer.

    Translates the legacy `loss_scale` argument into the keyword arguments of
    the parent class's `__init__` (`dynamic`, `initial_scale`,
    `dynamic_growth_steps`) and logs a deprecation warning showing the
    equivalent call on `tf.keras.mixed_precision.LossScaleOptimizer`.

    Args:
      optimizer: The optimizer to wrap; forwarded unchanged to the parent.
      loss_scale: A dict (deserialized first with
        `keras_loss_scale_module.deserialize`), an int or float, a
        `FixedLossScale`, the string "dynamic", or a `DynamicLossScale` whose
        multiplier is 2.

    Raises:
      ValueError: If a `DynamicLossScale` has a multiplier other than 2, or
        `loss_scale` is not one of the accepted kinds.
      TypeError: If `loss_scale` is a `LossScale` that is neither a
        `FixedLossScale` nor a `DynamicLossScale`.
    """
    deprecation_prefix = (
        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
        'instead. ')

    if isinstance(loss_scale, dict):
      loss_scale = keras_loss_scale_module.deserialize(loss_scale)

    # Plain numbers and FixedLossScale objects both map to a non-dynamic
    # optimizer with a constant initial scale.
    has_fixed_value = True
    if isinstance(loss_scale, (int, float)):
      fixed_value = loss_scale
    elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
      fixed_value = loss_scale._loss_scale_value  # pylint: disable=protected-access
    else:
      has_fixed_value = False

    if has_fixed_value:
      tf_logging.warn(
          deprecation_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(fixed_value))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=fixed_value)
      return

    if loss_scale == 'dynamic':
      tf_logging.warn(
          deprecation_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt)')
      super(LossScaleOptimizerV1, self).__init__(optimizer)
      return

    if isinstance(loss_scale, loss_scale_module.DynamicLossScale):
      # Forward only the settings that differ from the module-level defaults;
      # the same settings are echoed in the suggested replacement call.
      overrides = {}
      if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
        overrides['initial_scale'] = loss_scale.initial_loss_scale
      if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
        overrides['dynamic_growth_steps'] = loss_scale.increment_period
      if loss_scale.multiplier != 2:
        raise ValueError('When passing a DynamicLossScale to "loss_scale", '
                         'DynamicLossScale.multiplier must be 2. Got: %s'
                         % (loss_scale,))
      suggested_args = ''.join(
          ', %s=%s' % (arg_name, arg_value)
          for arg_name, arg_value in overrides.items())
      tf_logging.warn(
          deprecation_prefix +
          'Note that the non-experimental LossScaleOptimizer does not take a '
          'DynamicLossScale but instead takes the dynamic configuration '
          'directly in the constructor. For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt{})\n'.format(suggested_args))
      super(LossScaleOptimizerV1, self).__init__(optimizer, **overrides)
      return

    if isinstance(loss_scale, loss_scale_module.LossScale):
      raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
                      'DynamicLossScale is no longer supported. Got: {}'
                      .format(loss_scale))
    raise ValueError('Invalid value passed to loss_scale. loss_scale '
                     'must be the string "dynamic" (recommended), an int, '
                     'a float, a FixedLossScale, or a DynamicLossScale. Got '
                     'value: {}'.format(loss_scale))