Example #1
  def test_get_config(self, output_dim, initializer, scale, trainable):
    rff_layer = kernel_layers.RandomFourierFeatures(
        output_dim,
        initializer,
        scale=scale,
        trainable=trainable,
        name='random_fourier_features',
    )
    expected_initializer = initializer
    if isinstance(initializer, init_ops.Initializer):
      expected_initializer = initializers.serialize(initializer)

    expected_dtype = (
        'float32' if base_layer_utils.v2_dtype_behavior_enabled() else None)
    expected_config = {
        'output_dim': output_dim,
        'kernel_initializer': expected_initializer,
        'scale': scale,
        'name': 'random_fourier_features',
        'trainable': trainable,
        'dtype': expected_dtype,
    }
    self.assertLen(expected_config, len(rff_layer.get_config()))
    self.assertSameElements(
        list(expected_config.items()), list(rff_layer.get_config().items()))
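The test above just checks that `get_config` round-trips the constructor arguments. A minimal sketch of the same round trip outside the test harness, assuming a TF 2.x release where the layer is exported as `tf.keras.layers.experimental.RandomFourierFeatures`:

import tensorflow as tf

layer = tf.keras.layers.experimental.RandomFourierFeatures(
    output_dim=64, scale=1.0, name='random_fourier_features')
config = layer.get_config()
print(sorted(config))  # 'dtype', 'kernel_initializer', 'name', 'output_dim', 'scale', 'trainable'
# from_config(get_config()) should rebuild an equivalent layer.
restored = tf.keras.layers.experimental.RandomFourierFeatures.from_config(config)
print(restored.output_dim)  # 64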
Example #2
def set_policy(policy):
    """Sets the global Policy.

  The global policy is the default policy used for layers, if no policy is
  passed to the layer constructor. If no global policy is set, layers will
  instead default to a Policy constructed from `tf.keras.backend.floatx()`.

  See `keras.mixed_precision.experimental.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy.
  """
    global _global_policy
    if not base_layer_utils.v2_dtype_behavior_enabled():
        raise ValueError(
            'The global policy can only be set in TensorFlow 2 or if '
            'V2 dtype behavior has been set. To enable V2 dtype '
            'behavior, call '
            '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"')
    if policy is not None and not isinstance(policy, Policy):
        policy = Policy(policy)
    is_mixed_policy = policy is not None and policy.should_cast_variables
    if is_mixed_policy:
        _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
    _global_policy = policy
    mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy
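A hedged usage sketch of the setter shown above. On the TF releases that carry this revision it lives under the experimental namespace (newer releases rename it to `set_global_policy`); under TF 1 without V2 dtype behavior it raises the ValueError seen in the body:

import tensorflow as tf

mp = tf.keras.mixed_precision.experimental
mp.set_policy('mixed_float16')
print(mp.global_policy().name)   # mixed_float16

# Layers created afterwards default to the mixed policy: float16 compute
# dtype, float32 variable dtype.
dense = tf.keras.layers.Dense(4)
print(dense.dtype)               # float32 (Layer.dtype reports the variable dtype)

mp.set_policy('float32')         # restore the default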
Example #3
def global_policy():
    """Returns the global dtype policy.

  The global policy is the default `tf.keras.mixed_precision.Policy` used for
  layers, if no policy is passed to the layer constructor. If no policy has been
  set with `keras.mixed_precision.set_global_policy`, this will return a policy
  constructed from `tf.keras.backend.floatx()` (floatx defaults to float32).

  >>> tf.keras.mixed_precision.global_policy()
  <Policy "float32">
  >>> tf.keras.layers.Dense(10).dtype_policy  # Defaults to the global policy
  <Policy "float32">

  If TensorFlow 2 behavior has been disabled with
  `tf.compat.v1.disable_v2_behavior()`, this will instead return a special
  "_infer" policy which infers the dtype from the dtype of the first input the
  first time the layer is called. This behavior matches the behavior that
  existed in TensorFlow 1.

  See `tf.keras.mixed_precision.Policy` for more information on policies.

  Returns:
    The global Policy.
  """
    if _global_policy is None:
        if base_layer_utils.v2_dtype_behavior_enabled():
            return Policy(backend.floatx())
        else:
            return Policy('_infer')
    return _global_policy
Example #4
def set_policy(policy):
    """Sets the global Policy.

  The global policy is the default policy used for layers, if no policy is
  passed to the layer constructor. If no global policy is set, layers will
  instead default to a Policy constructed from `tf.keras.backend.floatx()` in
  TensorFlow 2. In TensorFlow 1, layers default to an "infer" policy.

  See `keras.mixed_precision.experimental.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy.
  """
    global _global_policy
    _check_if_mixed_precision_graph_rewrite_is_enabled()
    if policy is not None and not isinstance(policy, Policy):
        policy = Policy(policy)
    if (policy and not base_layer_utils.v2_dtype_behavior_enabled()
            and policy.compute_dtype):
        raise ValueError(
            'The global policy can only be set to a non-infer policy in TensorFlow '
            '2')
    _global_policy = policy
    mixed_precision_global_state.using_default_mixed_precision_policy = (
        _global_policy is None)
Example #5
def set_policy(policy):
    """Sets the global Policy.

  The global policy is the default policy used for layers, if no policy is
  passed to the layer constructor. If no global policy is set, layers will
  instead default to the "infer" policy.

  See `keras.mixed_precision.experimental.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy.
  """
    global _global_policy
    _check_if_mixed_precision_graph_rewrite_is_enabled()
    if policy is not None and not isinstance(policy, Policy):
        policy = Policy(policy)
    if (policy and not base_layer_utils.v2_dtype_behavior_enabled()
            and policy.compute_dtype):
        raise ValueError(
            'When a global Policy is set to a non-infer policy, the V2 layer dtype '
            'behavior must be enabled. V2 layer dtype behavior will soon be turned '
            'on by default, so please wait.')
    _global_policy = policy
    mixed_precision_global_state.using_default_mixed_precision_policy = (
        _global_policy is None)
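Examples #4 and #5 are older revisions of the same setter; under TF 1 they refuse non-infer policies unless V2 dtype behavior has been enabled. A hedged sketch of that opt-in for a TF1-style program, using the compat call named in the error messages of Examples #2, #6 and #12 (ordering matters: disable V2 behavior first, then opt in to V2 dtypes):

import tensorflow as tf

tf.compat.v1.disable_v2_behavior()                    # TF1-style program
tf.compat.v1.keras.layers.enable_v2_dtype_behavior()  # opt in to V2 dtype behavior
# With V2 dtype behavior enabled, layers default to floatx() instead of the
# "infer" policy, and a non-infer global policy can be set without the
# ValueError raised above.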
Example #6
def set_policy(policy):
    """Sets the global dtype policy.

  The global policy is the default `tf.keras.mixed_precision.Policy` used for
  layers, if no policy is passed to the layer constructor.

  >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
  >>> tf.keras.mixed_precision.global_policy()
  <Policy "mixed_float16">
  >>> tf.keras.layers.Dense(10).dtype_policy
  <Policy "mixed_float16">
  >>> # Global policy is not used if a policy is directly passed to constructor
  >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
  <Policy "float64">
  >>> tf.keras.mixed_precision.set_global_policy('float32')

  If no global policy is set, layers will instead default to a Policy
  constructed from `tf.keras.backend.floatx()`.

  To use mixed precision, the global policy should be set to `'mixed_float16'`
  or `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and
  float32 variable dtype by default.

  Only floating point policies can be set as the global policy, such as
  `'float32'` and `'mixed_float16'`. Non-floating point policies such as
  `'int32'` and `'complex64'` cannot be set as the global policy because most
  layers do not support such policies.

  See `tf.keras.mixed_precision.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy. Can also
      be None, in which case the global policy will be constructed from
      `tf.keras.backend.floatx()`
  """
    global _global_policy
    if not base_layer_utils.v2_dtype_behavior_enabled():
        raise ValueError(
            'The global policy can only be set in TensorFlow 2 or if '
            'V2 dtype behavior has been set. To enable V2 dtype '
            'behavior, call '
            '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"')
    if policy is not None and not isinstance(policy, Policy):
        policy = Policy(policy)
    is_mixed_policy = (policy is not None
                       and policy.compute_dtype != policy.variable_dtype)
    if is_mixed_policy:
        _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
    if (policy is not None and policy.compute_dtype is not None
            and not dtypes.as_dtype(policy.compute_dtype).is_floating):
        raise ValueError(
            'set_policy can only be used to set the global policy to '
            'floating-point policies, such as "float32" and '
            '"mixed_float16", but got policy: %s' % (policy.name, ))
    _global_policy = policy
    mixed_precision_global_state.set_using_mixed_precision_policy(
        is_mixed_policy)
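A short sketch of the floating-point restriction enforced above, assuming a TF 2.x release where the public name is `set_global_policy`, as in the doctest:

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')  # floating point: accepted
print(tf.keras.mixed_precision.global_policy().name)          # mixed_bfloat16

try:
    tf.keras.mixed_precision.set_global_policy('int32')       # non-floating point: rejected
except ValueError as err:
    print(err)  # "set_policy can only be used to set the global policy to floating-point policies..."

tf.keras.mixed_precision.set_global_policy('float32')         # restore the default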
Example #7
 def test_policy_scope(self):
   if base_layer_utils.v2_dtype_behavior_enabled():
     default_policy = 'float32'
   else:
     default_policy = '_infer'
   with mp_policy.policy_scope('mixed_float16'):
     self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     with mp_policy.policy_scope('_infer'):
       self.assertEqual(mp_policy.global_policy().name, '_infer')
     self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
   self.assertEqual(mp_policy.global_policy().name, default_policy)
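`policy_scope` is an internal Keras utility exercised here: it temporarily swaps the global policy and restores the previous one on exit. The same save/swap/restore pattern can be sketched against the public setter with a hypothetical `temporary_policy` helper (assuming `set_global_policy` from a recent TF 2.x):

import tensorflow as tf
from contextlib import contextmanager

@contextmanager
def temporary_policy(name):
    # Hypothetical helper mirroring mp_policy.policy_scope: save, swap, restore.
    previous = tf.keras.mixed_precision.global_policy()
    tf.keras.mixed_precision.set_global_policy(name)
    try:
        yield
    finally:
        tf.keras.mixed_precision.set_global_policy(previous)

with temporary_policy('mixed_float16'):
    print(tf.keras.mixed_precision.global_policy().name)  # mixed_float16
print(tf.keras.mixed_precision.global_policy().name)      # back to the previous policy (float32 by default)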
Example #8
 def test_global_policy(self):
   if base_layer_utils.v2_dtype_behavior_enabled():
     default_policy = 'float32'
   else:
     default_policy = '_infer'
   self.assertEqual(mp_policy.global_policy().name, default_policy)
   try:
     mp_policy.set_policy('mixed_float16')
     self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     with ops.Graph().as_default():  # Policies are not associated with a graph
       self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     mp_policy.set_policy('_infer')
     self.assertEqual(mp_policy.global_policy().name, '_infer')
     policy = mp_policy.Policy('mixed_bfloat16')
     mp_policy.set_policy(policy)
     self.assertIs(mp_policy.global_policy(), policy)
   finally:
     mp_policy.set_policy(None)
Example #9
  def test_infer_with_float32_vars(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope(), policy.policy_scope('infer_float32_vars'):
      layer = AddLayer(assert_type=dtypes.float16)
      self.assertEqual(layer.dtype, dtypes.float32)
      y = layer(x)
      self.assertEqual(layer.v.dtype, dtypes.float32)
      self.assertEqual(y.dtype, dtypes.float16)
      self.assertEqual(layer.dtype, dtypes.float32)
      self.assertEqual(layer._dtype_policy._name, 'float16_with_float32_vars')
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(self.evaluate(y), 2.)

      if base_layer_utils.v2_dtype_behavior_enabled():
        # Layer should now cast inputs to float16
        x = constant_op.constant([1.], dtype=dtypes.float32)
        y = layer(x)
        self.assertEqual(y.dtype, dtypes.float16)
Example #10
def global_policy():
    """Returns the global Policy.

  The global policy is the default policy used for layers, if no policy is
  passed to the layer constructor. If no policy has been set with
  `keras.mixed_precision.experimental.set_policy`, this will return the "infer"
  policy.

  See `keras.mixed_precision.experimental.Policy` for more information.

  Returns:
    The global Policy.
  """
    if _global_policy is None:
        if base_layer_utils.v2_dtype_behavior_enabled():
            return Policy(backend.floatx())
        else:
            return Policy('infer')
    return _global_policy
Example #11
    def __init__(self,
                 input_dim,
                 output_dim,
                 embeddings_initializer='uniform',
                 embeddings_regularizer=None,
                 activity_regularizer=None,
                 embeddings_constraint=None,
                 mask_zero=False,
                 input_length=None,
                 **kwargs):
        if 'input_shape' not in kwargs:
            if input_length:
                kwargs['input_shape'] = (input_length, )
            else:
                kwargs['input_shape'] = (None, )
        if input_dim <= 0 or output_dim <= 0:
            raise ValueError(
                'Both `input_dim` and `output_dim` should be positive, '
                'found input_dim {} and output_dim {}'.format(
                    input_dim, output_dim))
        if (not base_layer_utils.v2_dtype_behavior_enabled()
                and 'dtype' not in kwargs):
            # In TF1, the dtype defaults to the input dtype which is typically int32,
            # so explicitly set it to floatx
            kwargs['dtype'] = K.floatx()
        # We set autocast to False, as we do not want to cast floating-point inputs
        # to self.dtype. In call(), we cast to int32, and casting to self.dtype
        # before casting to int32 might cause the int32 values to be different due
        # to a loss of precision.
        kwargs['autocast'] = False
        super(Embedding, self).__init__(**kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embeddings_initializer = initializers.get(embeddings_initializer)
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.embeddings_constraint = constraints.get(embeddings_constraint)
        self.mask_zero = mask_zero
        self.supports_masking = mask_zero
        self.input_length = input_length
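The `autocast=False` choice above matters under mixed precision: the integer ids must not be cast to the compute dtype, but the lookup result still follows it. A hedged sketch on a recent TF 2.x (where the global setter is exported as `set_global_policy`):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
emb = tf.keras.layers.Embedding(input_dim=100, output_dim=8)
ids = tf.constant([[1, 2, 3]])        # int32 ids; autocast=False leaves them untouched
out = emb(ids)
print(emb.embeddings.dtype)           # float32 -- variables stay in the variable dtype
print(out.dtype)                      # float16 -- output is cast to the compute dtype
tf.keras.mixed_precision.set_global_policy('float32')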
Example #12
def set_policy(policy):
    """Sets the global Policy.

  The global policy is the default policy used for layers, if no policy is
  passed to the layer constructor. If no global policy is set, layers will
  instead default to a Policy constructed from `tf.keras.backend.floatx()`.

  Only floating point policies can be set as the global policy, such as
  `'float32'` and `'mixed_float16'`. Non-floating point policies such as
  `'int32'` and `'complex64'` cannot be set as the global policy because most
  layers do not support such policies.

  See `tf.keras.mixed_precision.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy.
  """
    global _global_policy
    if not base_layer_utils.v2_dtype_behavior_enabled():
        raise ValueError(
            'The global policy can only be set in TensorFlow 2 or if '
            'V2 dtype behavior has been set. To enable V2 dtype '
            'behavior, call '
            '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"')
    if policy is not None and not isinstance(policy, Policy):
        policy = Policy(policy)
    is_mixed_policy = (policy is not None
                       and policy.compute_dtype != policy.variable_dtype)
    if is_mixed_policy:
        _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
    if (policy is not None and policy.compute_dtype is not None
            and not dtypes.as_dtype(policy.compute_dtype).is_floating):
        raise ValueError(
            'set_policy can only be used to set the global policy to '
            'floating-point policies, such as "float32" and '
            '"mixed_float16", but got policy: %s' % (policy.name, ))
    _global_policy = policy
    mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy
Example #13
def policy_defaults_to_floatx():
    """Returns True if `global_policy()` will use the current value of floatx."""
    return (_global_policy is None and
            base_layer_utils.v2_dtype_behavior_enabled())
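`policy_defaults_to_floatx` answers the question "would a later change to `tf.keras.backend.floatx()` still change the default layer dtype?". A sketch of the two regimes it distinguishes, written against the public API of a recent TF 2.x (names assumed: `set_global_policy`, `global_policy`):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy(None)        # no explicit policy
tf.keras.backend.set_floatx('float64')
print(tf.keras.mixed_precision.global_policy().name)    # float64 -- the default tracks floatx()

tf.keras.mixed_precision.set_global_policy('float32')   # explicit policy
tf.keras.backend.set_floatx('float16')
print(tf.keras.mixed_precision.global_policy().name)    # float32 -- floatx() no longer matters

tf.keras.backend.set_floatx('float32')                  # restore defaults
tf.keras.mixed_precision.set_global_policy(None)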