Example #1
  def __init__(self, name, loss_scale=USE_DEFAULT):
    """Constructs the policy.

    The `name` argument determines the compute and variable dtype, and has no
    additional effect on the Policy. The compute and variable dtypes can only be
    specified through `name`, and cannot be specified directly.

    Args:
      name: A string. Can be one of the following values:
        * Any dtype name, such as 'float32' or 'float64'. Both the variable and
          compute dtypes will be that dtype.
        * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
          bfloat16, while the variable dtype is float32. With 'mixed_float16',
          a dynamic loss scale is used. These policies are used for mixed
          precision training.
        * 'infer' (deprecated): Infer the compute and variable dtype from the
          input dtype.
      loss_scale: A `tf.mixed_precision.experimental.LossScale`, or a value
        convertible to one such as "dynamic". Defaults to using no loss scaling
        unless `name` is "mixed_float16", in which case this defaults to
        "dynamic". Only `tf.keras.Model`s, not layers, use the loss scale, and
        it is only used during `Model.fit` or `Model.train_on_batch`.

    """
    if isinstance(name, dtypes.DType):
      raise TypeError("'name' must be a string, not a DType. "
                      "Instead, pass DType.name. Got: %s" % (name.name,))
    elif not isinstance(name, six.string_types):
      raise TypeError("'name' must be a string, but got: %s" % (name,))
    if name == 'infer_float32_vars':
      # For backwards compatibility. TODO(reedwm): Remove this.
      name = 'infer_with_float32_vars'
    if name == 'float32_with_float32_vars':
      # Doesn't affect correctness, but causes "float32" instead of
      # "float32_with_float32_vars" to be printed in __repr__.
      name = 'float32'
    self._name = name
    self._compute_dtype, self._variable_dtype = self._parse_name(name)

    if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
      warning = ("WARNING: The '%s' policy is deprecated and will be removed "
                 "in TensorFlow 2.1." % name)
      if name == 'infer_with_float32_vars':
        warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
                    "policy instead.")
      elif name == 'float16_with_float32_vars':
        warning += " Please use the 'mixed_float16' policy instead."
      elif name == 'bfloat16_with_float32_vars':
        warning += " Please use the 'mixed_bfloat16' policy instead."
      tf_logging.warn(warning)

    if loss_scale == USE_DEFAULT:
      loss_scale = 'dynamic' if name == 'mixed_float16' else None
    if loss_scale and self._compute_dtype not in (None, 'float16'):
      tf_logging.warn('Creating a Policy with a loss scale is only useful for '
                      'float16 policies. You passed loss_scale=%r for policy '
                      '%s. Consider not passing any loss_scale instead.' %
                      (loss_scale, name))
    self._loss_scale = loss_scale_module.get(loss_scale)
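A minimal usage sketch for the constructor above. The public path `tf.keras.mixed_precision.experimental.Policy` and the property names are assumptions about the TF 2.x version this snippet comes from, not part of the example itself:

# Sketch: constructing policies through the assumed public API.
import tensorflow as tf

# 'mixed_float16' computes in float16, keeps variables in float32, and
# defaults to a dynamic loss scale, as the docstring above describes.
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
assert policy.compute_dtype == 'float16'
assert policy.variable_dtype == 'float32'

# A plain dtype name uses that dtype for both compute and variables, with no
# loss scaling by default.
policy = tf.keras.mixed_precision.experimental.Policy('float64')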
Example #2
    def __init__(self, name, loss_scale=USE_DEFAULT):
        """Constructs the policy.

    The `name` argument determines the compute and variable dtype, and has no
    additional effect on the Policy. The compute and variable dtypes can only be
    specified through `name`, and cannot be specified directly.

    Args:
      name: A string. Can be one of the following values:
        * Any dtype name, such as 'float32' or 'float64'. Both the variable and
          compute dtypes will be that dtype.
        * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute
          dtype will be <dtype>, while the variable dtype is float32. This can
          be used for mixed precision, which uses float16 or bfloat16 for most
          computations, and float32 for variables, but it is recommended to use
          the 'mixed_float16' or 'mixed_bfloat16' policies instead.
        * 'mixed_float16' or 'mixed_bfloat16': Similar to
          'float16_with_float32_vars' or 'bfloat16_with_float32_vars'
          respectively. 'mixed_float16' is identical to
          'float16_with_float32_vars' except the loss_scale is dynamic by
          default. 'mixed_bfloat16' is currently identical to
          'bfloat16_with_float32_vars'. More changes may be added to these mixed
          policies in the future, to further differentiate them from
          [b]float16_with_float32_vars.
        * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
          computation dtype from the input dtype.
      loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
        one such as "dynamic". Defaults to using no loss scaling unless `name`
        is "mixed_float16", in which case this defaults to "dynamic". Only
        `tf.keras.Model`s, not layers, use the loss scale, and it is only used
        during `Model.fit` or `Model.train_on_batch`.

    """
        if isinstance(name, dtypes.DType):
            raise TypeError("'name' must be a string, not a DType. "
                            "Instead, pass DType.name. Got: %s" %
                            (name.name, ))
        elif not isinstance(name, six.string_types):
            raise TypeError("'name' must be a string, but got: %s" % (name, ))
        if name == 'infer_float32_vars':
            # For backwards compatibility. TODO(reedwm): Remove this.
            name = 'infer_with_float32_vars'
        if name == 'float32_with_float32_vars':
            # Doesn't affect correctness, but causes "float32" instead of
            # "float32_with_float32_vars" to be printed in __repr__.
            name = 'float32'
        self._name = name
        self._compute_dtype, self._variable_dtype = self._parse_name(name)

        if loss_scale == USE_DEFAULT:
            loss_scale = 'dynamic' if name == 'mixed_float16' else None
        if loss_scale and self._compute_dtype not in (None, 'float16'):
            tf_logging.warn(
                'Creating a Policy with a loss scale is only useful for '
                'float16 policies. You passed loss_scale=%r for policy '
                '%s. Consider not passing any loss_scale instead.' %
                (loss_scale, name))
        self._loss_scale = loss_scale_module.get(loss_scale)

    def __call__(opt, loss_scale):
        """Initializes a loss scaled optimizer.

        Args:
          opt: The Optimizer instance to wrap.
          loss_scale: The loss scale to scale the loss and gradients. This can
            either be an int/float to use a fixed loss scale, the string "dynamic"
            to use dynamic loss scaling, or an instance of a LossScale. The string
            "dynamic" equivalent to passing `DynamicLossScale()`, and passing an
            int/float is equivalent to passing a FixedLossScale with the given loss
            scale.
        Returns:
          Keras Optimizer with loss scaling
        """

        opt._loss_scale = loss_scale_module.get(loss_scale)

        for weight in loss_scale_module.get_loss_scale_weights(
                opt._loss_scale):
            # We cannot call `track_variable` in the LossScale class itself, because a
            # file outside of Keras cannot depend on a Keras file. Calling it here
            # instead is OK, because a variable only needs to be tracked if used with
            # a Keras class, and the only way to use LossScale with a Keras class is
            # through the LossScaleOptimizer.
            backend.track_variable(weight)

        opt._track_trackable(opt._loss_scale, 'loss_scale')

        class BaseOptimizer(object):
            _class = opt.__class__
            _classname = "%s.%s" % (opt.__module__, opt.__class__.__name__)
            _compute_gradients = opt._compute_gradients
            get_gradients = opt.get_gradients
            apply_gradients = opt.apply_gradients
            get_config = opt.get_config
            from_config = opt.from_config

        opt.loss_scale_base_opt = BaseOptimizer

        # Generate a fake class with name "LossScaleOptimizer"
        # Essential to avoid modifying the optimizer original class

        base_opt_class_dict = dict(opt.__class__.__dict__)
        base_opt_class_dict.update(dict(LossScaleOptimizer.__dict__))

        del base_opt_class_dict["__call__"]
        del base_opt_class_dict["__dict__"]
        del base_opt_class_dict["__weakref__"]

        opt.__class__ = type(LossScaleOptimizer.__name__,
                             (opt.loss_scale_base_opt._class, ),
                             base_opt_class_dict)

        return opt
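The `__call__` above wraps an optimizer in place by building a new class at runtime with `type()` and reassigning the instance's `__class__`. A self-contained sketch of that class-swap pattern, with illustrative names only (`Base` and `scaled_step` are not from the code above):

# Sketch: overlay a method onto an existing instance by swapping its class for
# a dynamically created subclass of its original class.
class Base(object):
    def step(self):
        return 'base step'

def scaled_step(self):
    # Delegate to the original implementation, then add wrapper behaviour.
    return 'scaled ' + Base.step(self)

opt = Base()
opt.__class__ = type('ScaledBase', (opt.__class__,), {'step': scaled_step})
print(opt.step())              # 'scaled base step'
print(isinstance(opt, Base))   # True: the original class is still a base class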
Example #4
  def __init__(self, opt, loss_scale):
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

    def __init__(self, opt, loss_scale):
        if not isinstance(opt, optimizer.Optimizer):
            raise ValueError(
                '"opt" must be an instance of Optimizer, but got: %s' %
                type(opt))
        if opt.doing_loss_scaling():
            raise RuntimeError('"opt" already configured for loss scaling via '
                               'environment variable. Please use only one of '
                               'TF_ENABLE_AUTO_MIXED_PRECISION or '
                               'enable_mixed_precision_graph_rewrite().')
        self._optimizer = opt

        use_locking = opt._use_locking  # pylint: disable=protected-access
        name = opt.get_name()
        super(MixedPrecisionLossScaleOptimizer,
              self).__init__(use_locking, name)

        self._loss_scale = loss_scale_module.get(loss_scale)
        self._track_trackable(self._optimizer, 'base_optimizer')
        self._track_trackable(self._loss_scale, 'loss_scale')

    def __init__(self, opt, loss_scale):
        """Initializes this loss scale optimizer.

    Args:
      opt: The Optimizer instance to wrap.
      loss_scale: The loss scale to scale the loss and gradients. This can
        either be an int/float to use a fixed loss scale, the string "dynamic"
        to use dynamic loss scaling, or an instance of a LossScale. The string
        "dynamic" equivalent to passing `DynamicLossScale()`, and passing an
        int/float is equivalent to passing a FixedLossScale with the given loss
        scale.
    """
        if not isinstance(opt, optimizer_v2.OptimizerV2):
            raise ValueError(
                '"opt" must be an instance of OptimizerV2, but got: %s' % opt)
        if hasattr(opt, 'clipnorm'):
            raise ValueError(
                'LossScaleOptimizer does not support wrapping '
                'optimizers with a clipnorm. Optimizer %s has clipnorm '
                '%s' % (opt, opt.clipnorm))

        if hasattr(opt, 'clipvalue'):
            raise ValueError('LossScaleOptimizer does not support wrapping '
                             'optimizers with a clipvalue. Optimizer %s has '
                             'clipvalue %s' % (opt, opt.clipvalue))

        self._optimizer = opt
        self._loss_scale = loss_scale_module.get(loss_scale)
        for weight in loss_scale_module.get_loss_scale_weights(
                self._loss_scale):
            # We cannot call `track_variable` in the LossScale class itself, because a
            # file outside of Keras cannot depend on a Keras file. Calling it here
            # instead is OK, because a variable only needs to be tracked if used with
            # a Keras class, and the only way to use LossScale with a Keras class is
            # through the LossScaleOptimizer.
            backend.track_variable(weight)
        self._track_trackable(self._optimizer, 'base_optimizer')
        self._track_trackable(self._loss_scale, 'loss_scale')

        # Needed because the superclass's __getattribute__ checks this.
        self._hyper = {}
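A usage sketch for the Keras constructor above. The public path `tf.keras.mixed_precision.experimental.LossScaleOptimizer` is an assumption about the TF 2.x version in use:

# Sketch: wrapping a Keras optimizer with loss scaling.
import tensorflow as tf

opt = tf.keras.optimizers.SGD(0.1)
# 'dynamic' is equivalent to passing a DynamicLossScale(); an int/float is
# equivalent to a FixedLossScale with that value.
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')

# Wrapping an optimizer constructed with clipnorm or clipvalue raises
# ValueError, as enforced by the constructor above.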
Example #7
  def test_serialization(self):
    loss_scale = loss_scale_module.get(123)
    config = loss_scale.get_config()
    loss_scale = loss_scale_module.FixedLossScale.from_config(config)
    self.assertEqual(self.evaluate(loss_scale()), 123.)
Example #8
  def test_get(self):
    scalar = loss_scale_module.get('dynamic')
    scalar2 = loss_scale_module.DynamicLossScale()
    self.assertEqual(scalar.initial_loss_scale, scalar2.initial_loss_scale)
    self.assertEqual(scalar.increment_period, scalar2.increment_period)
    self.assertEqual(scalar.multiplier, scalar2.multiplier)
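For reference, the defaults compared in this test can be written out explicitly; the values below are the usual `DynamicLossScale` defaults and are stated here as an assumption about the TF version:

explicit = loss_scale_module.DynamicLossScale(
    initial_loss_scale=2 ** 15, increment_period=2000, multiplier=2.)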
Example #11
def get(identifier):
    """Looks up a loss scale: dicts are deserialized as Keras configs;
    anything else is delegated to `loss_scale_module.get`."""
    if isinstance(identifier, dict):
        return deserialize(identifier)
    return loss_scale_module.get(identifier)
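Illustrative calls against `get` above. The `{'class_name': ..., 'config': ...}` dict shape is an assumption based on the usual Keras serialization convention expected by `deserialize`:

fixed = get(123)          # non-dict: delegated to loss_scale_module.get (fixed loss scale of 123)
dynamic = get('dynamic')  # non-dict: delegated (dynamic loss scale)
restored = get({'class_name': fixed.__class__.__name__,
                'config': fixed.get_config()})  # dict: handled by deserialize()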