def __init__(self, name):
  """Constructs the policy.

  The `name` argument determines the compute and variable dtype. The compute
  and variable dtypes can only be specified through `name`, and cannot be
  specified directly.

  `name` is also used by `tf.keras.Model.compile`. If `name` is
  `"mixed_float16"`, `tf.keras.Model.compile` will automatically wrap the
  optimizer with a LossScaleOptimizer if it is not already a
  LossScaleOptimizer.

  Args:
    name: A string. Can be one of the following values:
      * Any dtype name, such as 'float32' or 'float64'. Both the variable and
        compute dtypes will be that dtype.
      * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
        bfloat16, while the variable dtype is float32. With 'mixed_float16',
        `tf.keras.Model.compile` will wrap the optimizer with a
        `tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used
        for mixed precision training.

  Raises:
    TypeError: If `name` is a `DType` instead of a string, or is otherwise
      not a string.
  """
  # Reject DType objects explicitly so the error can point users at
  # `DType.name`, which is the accepted spelling.
  if isinstance(name, dtypes.DType):
    raise TypeError("'name' must be a string, not a DType. "
                    "Instead, pass DType.name. Got: %s" % (name.name,))
  elif not isinstance(name, six.string_types):
    raise TypeError("'name' must be a string, but got: %s" % (name,))
  self._name = name
  self._compute_dtype, self._variable_dtype = self._parse_name(name)
  # Run the device compatibility check only for the two mixed-precision
  # policies. The original tuple misspelled 'mixed_bfloat16' as
  # 'mixed_bloat16', so the bfloat16 policy never reached this check.
  # What the check logs is decided inside device_compatibility_check.
  if name in ('mixed_float16', 'mixed_bfloat16'):
    device_compatibility_check.log_device_compatibility_check(name)
def __init__(self, name, loss_scale=USE_DEFAULT):
  """Constructs the policy.

  The `name` argument determines the compute and variable dtype, the default
  loss scale, and has no additional effect on the Policy. The compute and
  variable dtypes can only be specified through `name`, and cannot be
  specified directly.

  Args:
    name: A string. Can be one of the following values:
      * Any dtype name, such as 'float32' or 'float64'. Both the variable and
        compute dtypes will be that dtype.
      * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
        bfloat16, while the variable dtype is float32. With 'mixed_float16',
        a dynamic loss scale is used. These policies are used for mixed
        precision training.
      * 'infer' (deprecated): Infer the compute and variable dtype from the
        input dtype.
    loss_scale: A `tf.mixed_precision.experimental.LossScale`, an int (which
      uses a `FixedLossScale`), or the string "dynamic" (which uses a
      `DynamicLossScale`). Defaults to using no loss scaling unless `name` is
      "mixed_float16", in which case this defaults to "dynamic". Only
      `tf.keras.Model`s, not layers, use the loss scale, and it is only used
      during `Model.fit`, `Model.train_on_batch`, and other similar methods.

  Raises:
    TypeError: If `name` is a `DType` instead of a string, or is otherwise
      not a string.
  """
  # Reject DType objects explicitly so the error can point users at
  # `DType.name`, which is the accepted spelling.
  if isinstance(name, dtypes.DType):
    raise TypeError("'name' must be a string, not a DType. "
                    "Instead, pass DType.name. Got: %s" % (name.name,))
  elif not isinstance(name, six.string_types):
    raise TypeError("'name' must be a string, but got: %s" % (name,))
  self._name = name
  self._compute_dtype, self._variable_dtype = self._parse_name(name)

  # Resolve the default loss scale: 'dynamic' for mixed_float16, otherwise
  # none. Remember whether the caller relied on the default.
  if loss_scale == USE_DEFAULT:
    loss_scale = 'dynamic' if name == 'mixed_float16' else None
    self._using_default_loss_scale = True
  else:
    self._using_default_loss_scale = False
  # An explicit loss scale only matters for float16 compute (or for a policy
  # whose compute dtype is None), so warn about it but still honor it.
  if loss_scale and self._compute_dtype not in (None, 'float16'):
    tf_logging.warn(
        'Creating a Policy with a loss scale is only useful for '
        'float16 policies. You passed loss_scale=%r for policy '
        '%s. Consider not passing any loss_scale instead.' %
        (loss_scale, name))
  self._loss_scale = keras_loss_scale_module.get(loss_scale)

  # Run the device compatibility check only for the two mixed-precision
  # policies. The original tuple misspelled 'mixed_bfloat16' as
  # 'mixed_bloat16', so the bfloat16 policy never reached this check.
  if name in ('mixed_float16', 'mixed_bfloat16'):
    device_compatibility_check.log_device_compatibility_check(
        name, skip_local=True)