Example #1
    def test_get_config(self, output_dim, initializer, scale, trainable):
        rff_layer = kernel_layers.RandomFourierFeatures(
            output_dim,
            initializer,
            scale=scale,
            trainable=trainable,
            name='random_fourier_features',
        )
        expected_initializer = initializer
        if not isinstance(initializer, six.string_types):
            expected_initializer = initializers.serialize(initializer)

        expected_dtype = ('float32'
                          if base_layer_utils.v2_dtype_behavior_enabled() else
                          None)
        expected_config = {
            'output_dim': output_dim,
            'kernel_initializer': expected_initializer,
            'scale': scale,
            'name': 'random_fourier_features',
            'trainable': trainable,
            'dtype': expected_dtype,
        }
        self.assertLen(expected_config, len(rff_layer.get_config()))
        self.assertSameElements(list(expected_config.items()),
                                list(rff_layer.get_config().items()))
Example #2
    def test_get_config(self, output_dim, initializer, scale, trainable):
        rff_layer = kernel_layers.RandomFourierFeatures(
            output_dim,
            initializer,
            scale=scale,
            trainable=trainable,
            name="random_fourier_features",
        )
        expected_initializer = initializer
        if not isinstance(initializer, str):
            expected_initializer = initializers.serialize(initializer)

        expected_dtype = ("float32"
                          if base_layer_utils.v2_dtype_behavior_enabled() else
                          None)
        expected_config = {
            "output_dim": output_dim,
            "kernel_initializer": expected_initializer,
            "scale": scale,
            "name": "random_fourier_features",
            "trainable": trainable,
            "dtype": expected_dtype,
        }
        self.assertLen(expected_config, len(rff_layer.get_config()))
        self.assertSameElements(list(expected_config.items()),
                                list(rff_layer.get_config().items()))
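Both variants of the test above check the same thing: the 'dtype' entry of get_config() is 'float32' only when V2 dtype behavior is enabled. As a rough, hedged sketch against the public API (argument values below are arbitrary), the same config can also rebuild an equivalent layer:

import tensorflow as tf

# Sketch only: rebuild the layer from the config returned by get_config().
layer = tf.keras.layers.experimental.RandomFourierFeatures(
    output_dim=16, kernel_initializer='gaussian', scale=1.0, name='rff')
config = layer.get_config()
print(config['dtype'])  # 'float32' when TF2 dtype behavior is enabled
rebuilt = tf.keras.layers.experimental.RandomFourierFeatures.from_config(config)
print(rebuilt.output_dim, rebuilt.name)  # 16 rff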
Example #3
def global_policy():
    """Returns the global dtype policy.

    The global policy is the default `tf.keras.mixed_precision.Policy` used for
    layers, if no policy is passed to the layer constructor. If no policy has
    been set with `keras.mixed_precision.set_global_policy`, this will return a
    policy constructed from `tf.keras.backend.floatx()` (floatx defaults to
    float32).

    >>> tf.keras.mixed_precision.global_policy()
    <Policy "float32">
    >>> tf.keras.layers.Dense(10).dtype_policy  # Defaults to the global policy
    <Policy "float32">

    If TensorFlow 2 behavior has been disabled with
    `tf.compat.v1.disable_v2_behavior()`, this will instead return a special
    "_infer" policy which infers the dtype from the dtype of the first input
    the first time the layer is called. This behavior matches the behavior that
    existed in TensorFlow 1.

    See `tf.keras.mixed_precision.Policy` for more information on policies.

    Returns:
      The global Policy.
    """
    if _global_policy is None:
        if base_layer_utils.v2_dtype_behavior_enabled():
            return Policy(backend.floatx())
        else:
            return Policy('_infer')
    return _global_policy
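A quick hedged check of the fallback described in the docstring, through the public tf.keras.mixed_precision wrappers (this assumes no global policy has been set yet in the process):

import tensorflow as tf

# With no global policy set, global_policy() is built from floatx, so the
# names agree ('float32' by default).
assert tf.keras.mixed_precision.global_policy().name == tf.keras.backend.floatx()

# A freshly constructed layer picks this policy up as its dtype_policy.
print(tf.keras.layers.Dense(4).dtype_policy.name)  # float32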
Example #4
 def __init__(self, name=None, dtype=None, **kwargs):
     super().__init__(name=name, dtype=dtype, **kwargs)
     self.stateful = True  # All metric layers are stateful.
     self.built = True
     if not base_layer_utils.v2_dtype_behavior_enabled():
         # We only do this when the V2 behavior is not enabled, as when it is
         # enabled, the dtype already defaults to floatx.
         self._dtype = (backend.floatx()
                        if dtype is None else tf.as_dtype(dtype).name)
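The V1-only branch above normalizes the passed dtype with tf.as_dtype(...).name; a small illustration of that normalization (values are arbitrary):

import tensorflow as tf

# tf.as_dtype accepts strings and DType objects; .name is the canonical string.
print(tf.as_dtype('float64').name)    # float64
print(tf.as_dtype(tf.float16).name)   # float16
print(tf.keras.backend.floatx())      # float32, the fallback when dtype is None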
Example #5
def set_policy(policy):
    """Sets the global dtype policy.

    The global policy is the default `tf.keras.mixed_precision.Policy` used for
    layers, if no policy is passed to the layer constructor.

    >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
    >>> tf.keras.mixed_precision.global_policy()
    <Policy "mixed_float16">
    >>> tf.keras.layers.Dense(10).dtype_policy
    <Policy "mixed_float16">
    >>> # Global policy is not used if a policy is directly passed to constructor
    >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
    <Policy "float64">
    >>> tf.keras.mixed_precision.set_global_policy('float32')

    If no global policy is set, layers will instead default to a Policy
    constructed from `tf.keras.backend.floatx()`.

    To use mixed precision, the global policy should be set to
    `'mixed_float16'` or `'mixed_bfloat16'`, so that every layer uses a 16-bit
    compute dtype and float32 variable dtype by default.

    Only floating point policies can be set as the global policy, such as
    `'float32'` and `'mixed_float16'`. Non-floating point policies such as
    `'int32'` and `'complex64'` cannot be set as the global policy because most
    layers do not support such policies.

    See `tf.keras.mixed_precision.Policy` for more information.

    Args:
      policy: A Policy, or a string that will be converted to a Policy. Can
        also be None, in which case the global policy will be constructed from
        `tf.keras.backend.floatx()`
    """
    global _global_policy
    if not base_layer_utils.v2_dtype_behavior_enabled():
        raise ValueError(
            'The global policy can only be set in TensorFlow 2 or if '
            'V2 dtype behavior has been set. To enable V2 dtype '
            'behavior, call '
            '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"')
    if policy is not None and not isinstance(policy, Policy):
        policy = Policy(policy)
    is_mixed_policy = (policy is not None
                       and policy.compute_dtype != policy.variable_dtype)
    if is_mixed_policy:
        _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
    if (policy is not None and policy.compute_dtype is not None
            and not tf.as_dtype(policy.compute_dtype).is_floating):
        raise ValueError(
            'set_policy can only be used to set the global policy to '
            'floating-point policies, such as "float32" and '
            '"mixed_float16", but got policy: %s' % (policy.name, ))
    _global_policy = policy
    mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy
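A hedged usage sketch of the behavior documented above, using the public tf.keras.mixed_precision wrappers (the try/finally reset mirrors the tests below):

import tensorflow as tf

try:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    policy = tf.keras.mixed_precision.global_policy()
    print(policy.compute_dtype, policy.variable_dtype)  # float16 float32

    # Non-floating-point policies are rejected, as described in the docstring.
    try:
        tf.keras.mixed_precision.set_global_policy('int32')
    except ValueError as err:
        print('rejected:', err)
finally:
    # Passing None restores the floatx-based default policy.
    tf.keras.mixed_precision.set_global_policy(None)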
Example #6
 def test_policy_scope(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
         default_policy = 'float32'
     else:
         default_policy = '_infer'
     with mp_policy.policy_scope('mixed_float16'):
         self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
         with mp_policy.policy_scope('_infer'):
             self.assertEqual(mp_policy.global_policy().name, '_infer')
         self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     self.assertEqual(mp_policy.global_policy().name, default_policy)
Example #7
 def test_policy_scope(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
         default_policy = "float32"
     else:
         default_policy = "_infer"
     with mp_policy.policy_scope("mixed_float16"):
         self.assertEqual(mp_policy.global_policy().name, "mixed_float16")
         with mp_policy.policy_scope("_infer"):
             self.assertEqual(mp_policy.global_policy().name, "_infer")
         self.assertEqual(mp_policy.global_policy().name, "mixed_float16")
     self.assertEqual(mp_policy.global_policy().name, default_policy)
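policy_scope here is a test helper; a minimal sketch of an equivalent context manager built on the public API (an approximation, not the Keras implementation) could look like this:

import contextlib

import tensorflow as tf

@contextlib.contextmanager
def policy_scope(policy):
    # Save the current global policy, install the new one, restore on exit.
    old_policy = tf.keras.mixed_precision.global_policy()
    tf.keras.mixed_precision.set_global_policy(policy)
    try:
        yield
    finally:
        tf.keras.mixed_precision.set_global_policy(old_policy)

with policy_scope('mixed_float16'):
    assert tf.keras.mixed_precision.global_policy().name == 'mixed_float16'
print(tf.keras.mixed_precision.global_policy().name)  # float32 again by default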
Example #8
    def __init__(
        self,
        input_dim,
        output_dim,
        embeddings_initializer="uniform",
        embeddings_regularizer=None,
        activity_regularizer=None,
        embeddings_constraint=None,
        mask_zero=False,
        input_length=None,
        **kwargs,
    ):
        if "input_shape" not in kwargs:
            if input_length:
                kwargs["input_shape"] = (input_length,)
            else:
                kwargs["input_shape"] = (None,)
        if input_dim <= 0 or output_dim <= 0:
            raise ValueError(
                "Both `input_dim` and `output_dim` should be positive, "
                f"Received input_dim = {input_dim} "
                f"and output_dim = {output_dim}"
            )
        if (
            not base_layer_utils.v2_dtype_behavior_enabled()
            and "dtype" not in kwargs
        ):
            # In TF1, the dtype defaults to the input dtype which is typically
            # int32, so explicitly set it to floatx
            kwargs["dtype"] = backend.floatx()
        # We set autocast to False, as we do not want to cast floating-point
        # inputs to self.dtype. In call(), we cast to int32, and casting to
        # self.dtype before casting to int32 might cause the int32 values to be
        # different due to a loss of precision.
        kwargs["autocast"] = False
        super().__init__(**kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embeddings_initializer = initializers.get(embeddings_initializer)
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.embeddings_constraint = constraints.get(embeddings_constraint)
        self.mask_zero = mask_zero
        self.supports_masking = mask_zero
        self.input_length = input_length
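A short usage sketch of the public layer this constructor belongs to (shapes and sizes below are arbitrary); the integer inputs are exactly why autocast is disabled above:

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Embedding(input_dim=1000, output_dim=64, input_length=10)
token_ids = np.random.randint(0, 1000, size=(2, 10))  # int ids are never autocast
vectors = layer(token_ids)
print(vectors.shape)  # (2, 10, 64)
print(vectors.dtype)  # float32 under the default dtype policy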
Example #9
 def test_global_policy(self):
   if base_layer_utils.v2_dtype_behavior_enabled():
     default_policy = 'float32'
   else:
     default_policy = '_infer'
   self.assertEqual(mp_policy.global_policy().name, default_policy)
   try:
     mp_policy.set_global_policy('mixed_float16')
     self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     with tf.Graph().as_default():  # Policies are not associated with a graph
       self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     mp_policy.set_global_policy('_infer')
     self.assertEqual(mp_policy.global_policy().name, '_infer')
     policy = mp_policy.Policy('mixed_bfloat16')
     mp_policy.set_global_policy(policy)
     self.assertIs(mp_policy.global_policy(), policy)
   finally:
     mp_policy.set_global_policy(None)
Example #10
    def __init__(self,
                 input_dim,
                 output_dim,
                 embeddings_initializer='uniform',
                 embeddings_regularizer=None,
                 activity_regularizer=None,
                 embeddings_constraint=None,
                 mask_zero=False,
                 input_length=None,
                 **kwargs):
        if 'input_shape' not in kwargs:
            if input_length:
                kwargs['input_shape'] = (input_length, )
            else:
                kwargs['input_shape'] = (None, )
        if input_dim <= 0 or output_dim <= 0:
            raise ValueError(
                'Both `input_dim` and `output_dim` should be positive, '
                'found input_dim {} and output_dim {}'.format(
                    input_dim, output_dim))
        if (not base_layer_utils.v2_dtype_behavior_enabled()
                and 'dtype' not in kwargs):
            # In TF1, the dtype defaults to the input dtype which is typically int32,
            # so explicitly set it to floatx
            kwargs['dtype'] = K.floatx()
        # We set autocast to False, as we do not want to cast floating-point inputs
        # to self.dtype. In call(), we cast to int32, and casting to self.dtype
        # before casting to int32 might cause the int32 values to be different due
        # to a loss of precision.
        kwargs['autocast'] = False
        super(Embedding, self).__init__(**kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embeddings_initializer = initializers.get(embeddings_initializer)
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.embeddings_constraint = constraints.get(embeddings_constraint)
        self.mask_zero = mask_zero
        self.supports_masking = mask_zero
        self.input_length = input_length
Example #11
 def test_global_policy(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
         default_policy = "float32"
     else:
         default_policy = "_infer"
     self.assertEqual(mp_policy.global_policy().name, default_policy)
     try:
         mp_policy.set_global_policy("mixed_float16")
         self.assertEqual(mp_policy.global_policy().name, "mixed_float16")
         # Policies are not associated with a graph
         with tf.Graph().as_default():
             self.assertEqual(
                 mp_policy.global_policy().name, "mixed_float16"
             )
         mp_policy.set_global_policy("_infer")
         self.assertEqual(mp_policy.global_policy().name, "_infer")
         policy = mp_policy.Policy("mixed_bfloat16")
         mp_policy.set_global_policy(policy)
         self.assertIs(mp_policy.global_policy(), policy)
     finally:
         mp_policy.set_global_policy(None)
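The same graph-independence check, sketched against the public API (resetting with None restores the floatx-based default):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
with tf.Graph().as_default():
    # The global policy is process-wide state, not tied to any tf.Graph.
    assert tf.keras.mixed_precision.global_policy().name == 'mixed_bfloat16'
tf.keras.mixed_precision.set_global_policy(None)
print(tf.keras.mixed_precision.global_policy().name)  # float32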