Example 1
    def __init__(self, name):
        """Constructs the policy.

        The `name` argument determines the compute and variable dtype. The compute
        and variable dtypes can only be specified through `name`, and cannot be
        specified directly.

        `name` is also used by `tf.keras.Model.compile`. If `name` is
        `"mixed_float16"`, `tf.keras.Model.compile` will automatically wrap the
        optimizer with a LossScaleOptimizer if it is not already a
        LossScaleOptimizer.

        Args:
          name: A string. Can be one of the following values:
            * Any dtype name, such as 'float32' or 'float64'. Both the variable and
              compute dtypes will be that dtype.
            * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
              bfloat16, while the variable dtype is float32. With 'mixed_float16',
              `tf.keras.Model.compile` will wrap the optimizer with a
              `tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used
              for mixed precision training.
        """
        if isinstance(name, dtypes.DType):
            raise TypeError("'name' must be a string, not a DType. "
                            "Instead, pass DType.name. Got: %s" %
                            (name.name, ))
        elif not isinstance(name, six.string_types):
            raise TypeError("'name' must be a string, but got: %s" % (name, ))
        self._name = name
        self._compute_dtype, self._variable_dtype = self._parse_name(name)
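        # Mixed precision policies trigger a device compatibility check, which
        # logs a warning if the local devices are unlikely to benefit from them.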
        if name in ('mixed_float16', 'mixed_bfloat16'):
            device_compatibility_check.log_device_compatibility_check(name)
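
A minimal usage sketch of the constructor above, assuming the public TF 2.4+ API where this class is exposed as `tf.keras.mixed_precision.Policy`; the model and layer setup below is illustrative only, not part of the example code.

import tensorflow as tf

# 'mixed_float16' computes in float16 but keeps variables in float32.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
print(policy.compute_dtype)   # float16
print(policy.variable_dtype)  # float32

# Passing a DType instead of a string raises the TypeError enforced above.
try:
    tf.keras.mixed_precision.Policy(tf.float32)
except TypeError as err:
    print(err)

# With a 'mixed_float16' policy in effect, Model.compile wraps the optimizer
# in a LossScaleOptimizer unless it already is one.
tf.keras.mixed_precision.set_global_policy(policy)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,)),
    tf.keras.layers.Dense(1, dtype='float32'),
])
model.compile(optimizer='adam', loss='mse')
print(isinstance(model.optimizer, tf.keras.mixed_precision.LossScaleOptimizer))  # True
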
Example 2
    def __init__(self, name, loss_scale=USE_DEFAULT):
        """Constructs the policy.

        The `name` argument determines the compute and variable dtype and the
        default loss scale; it has no additional effect on the Policy. The compute
        and variable dtypes can only be specified through `name`, and cannot be
        specified directly.

        Args:
          name: A string. Can be one of the following values:
            * Any dtype name, such as 'float32' or 'float64'. Both the variable and
              compute dtypes will be that dtype.
            * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
              bfloat16, while the variable dtype is float32. With 'mixed_float16',
              a dynamic loss scale is used. These policies are used for mixed
              precision training.
            * 'infer' (deprecated): Infer the compute and variable dtype from the
              input dtype.
          loss_scale: A `tf.mixed_precision.experimental.LossScale`, an int (which
            uses a `FixedLossScale`), or the string "dynamic" (which uses a
            `DynamicLossScale`). Defaults to using no loss scaling unless `name` is
            "mixed_float16", in which case this defaults to "dynamic". Only
            `tf.keras.Model`s, not layers, use the loss scale, and it is only used
            during `Model.fit`, `Model.train_on_batch`, and other similar methods.
        """
        if isinstance(name, dtypes.DType):
            raise TypeError("'name' must be a string, not a DType. "
                            "Instead, pass DType.name. Got: %s" %
                            (name.name, ))
        elif not isinstance(name, six.string_types):
            raise TypeError("'name' must be a string, but got: %s" % (name, ))
        self._name = name
        self._compute_dtype, self._variable_dtype = self._parse_name(name)

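        # No loss_scale was passed: default to dynamic loss scaling for
        # 'mixed_float16' and to no loss scaling for every other policy.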
        if loss_scale == USE_DEFAULT:
            loss_scale = 'dynamic' if name == 'mixed_float16' else None
            self._using_default_loss_scale = True
        else:
            self._using_default_loss_scale = False
        if loss_scale and self._compute_dtype not in (None, 'float16'):
            tf_logging.warn(
                'Creating a Policy with a loss scale is only useful for '
                'float16 policies. You passed loss_scale=%r for policy '
                '%s. Consider not passing any loss_scale instead.' %
                (loss_scale, name))
        self._loss_scale = keras_loss_scale_module.get(loss_scale)

        if name in ('mixed_float16', 'mixed_bfloat16'):
            device_compatibility_check.log_device_compatibility_check(
                name, skip_local=True)
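
For the older, `loss_scale`-aware variant above, a similarly hedged sketch, assuming the experimental TF 2.1-2.3 API (`tf.keras.mixed_precision.experimental.Policy`, since removed); it only illustrates the loss-scale defaults described in the docstring.

import tensorflow as tf  # assumes TF 2.1-2.3, where the experimental API still exists

# With 'mixed_float16' and no explicit loss_scale, a dynamic loss scale is used.
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
print(policy.compute_dtype, policy.variable_dtype)  # float16 float32
print(policy.loss_scale)  # DynamicLossScale(...)

# An int selects a FixedLossScale with that value instead.
fixed = tf.keras.mixed_precision.experimental.Policy('mixed_float16', loss_scale=1024)
print(fixed.loss_scale)  # FixedLossScale(1024.0)

# Non-float16 policies default to no loss scaling; passing one anyway only
# triggers the warning logged in the constructor above.
fp32 = tf.keras.mixed_precision.experimental.Policy('float32')
print(fp32.loss_scale)  # None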