Example #1
  def test_serialization(self):
    # Test policies that are equivalent to a single dtype
    for policy_name in 'float16', 'float32', 'int8', 'string', 'bool':
      policy = mp_policy.Policy(policy_name)
      config = mp_policy.serialize(policy)
      self.assertEqual(config, policy_name)
      new_policy = mp_policy.deserialize(config)
      self.assertEqual(str(policy), str(new_policy))

    # Test "_infer" policy
    policy = mp_policy.Policy('_infer')
    config = mp_policy.serialize(policy)
    self.assertIsNone(config)
    new_policy = mp_policy.deserialize(config)
    self.assertEqual(str(policy), str(new_policy))

    class MyPolicy(mp_policy.Policy):
      pass

    # Test policies that are not equivalent to a single dtype
    for policy in (
        mp_policy.Policy('mixed_float16'),
        mp_policy.Policy('mixed_bfloat16'),
        MyPolicy('float32')
    ):
      config = mp_policy.serialize(policy)
      self.assertEqual(config, {'class_name': policy.__class__.__name__,
                                'config': {'name': policy.name}})
      new_policy = mp_policy.deserialize(config,
                                         custom_objects={'MyPolicy': MyPolicy})
      self.assertEqual(str(policy), str(new_policy))
Example #2
    def test_device_compatibility_warning(self):
        if not tf.executing_eagerly():
            self.skipTest("Run in eager mode only.")

        device_compatibility_check._logged_compatibility_check = False
        with tf.compat.v1.test.mock.patch.object(
            tf_logging, "warning"
        ) as mock_warn:
            mp_policy.Policy("mixed_float16")
        if tf.config.list_physical_devices("GPU"):
            mock_warn.assert_not_called()
        else:
            self.assertRegex(
                mock_warn.call_args[0][0],
                r"Mixed precision compatibility check \(mixed_float16\): "
                r"WARNING.*",
            )

        if tf.config.list_physical_devices("GPU"):
            # Assert message is only logged once
            with tf.compat.v1.test.mock.patch.object(
                tf_logging, "warning"
            ) as mock_warn:
                mp_policy.Policy("mixed_float16")
            mock_warn.assert_not_called()
Example #3
    def test_config(self, strategy_fn):
        x = tf.constant([1.0], dtype=tf.float16)
        with strategy_fn().scope():
            for layer, dtype in (
                (mp_test_util.MultiplyLayer(), "float32"),
                (mp_test_util.MultiplyLayer(dtype="float64"), "float64"),
                (
                    mp_test_util.MultiplyLayer(dtype=policy.Policy("float64")),
                    "float64",
                ),
            ):
                config = layer.get_config()
                self.assertEqual(config["dtype"], dtype)
                self.assertIsInstance(config["dtype"], str)
                layer = mp_test_util.MultiplyLayer.from_config(config)
                self.assertEqual(layer.dtype, dtype)
                self.assertEqual(layer(x).dtype, dtype)
                self.assertEqual(layer.v.dtype, dtype)

            layer = mp_test_util.MultiplyLayer(dtype="mixed_float16")
            config = layer.get_config()
            self.assertEqual(
                config["dtype"],
                {
                    "class_name": "Policy",
                    "config": {
                        "name": "mixed_float16"
                    }
                },
            )
            layer = mp_test_util.MultiplyLayer.from_config(config)
            self.assertEqual(layer.dtype, "float32")
            self.assertEqual(layer(x).dtype, "float16")
            self.assertEqual(layer.v.dtype, "float32")
            config = layer.get_config()
            self.assertEqual(
                config["dtype"],
                {
                    "class_name": "Policy",
                    "config": {
                        "name": "mixed_float16"
                    }
                },
            )

            layer = mp_test_util.MultiplyLayer(dtype=policy.Policy("_infer"))
            config = layer.get_config()
            self.assertIsNone(config["dtype"])
            layer = mp_test_util.MultiplyLayer.from_config(config)
            # If a layer is serialized with the "_infer" policy, when
            # deserialized into TF 2 it will have the global policy instead of
            # "_infer". This is because "_infer" is serialized into None, and
            # passing dtype=None in TensorFlow 2 indicates to use the global
            # policy.
            self.assertEqual(layer.dtype, "float32")
            self.assertEqual(layer(x).dtype, "float32")
            self.assertEqual(layer.v.dtype, "float32")
Example #4
 def test_unsupported_strategy(self):
   strategy = create_central_storage_strategy()
   with strategy.scope(), self.assertRaisesRegex(
       ValueError, 'Mixed precision is not supported with the '
       'tf.distribute.Strategy: CentralStorageStrategy. Either '
       'stop using mixed precision by removing the use of the '
       '"mixed_float16" policy or use a different Strategy, e.g. '
       'a MirroredStrategy.'):
     mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))
   # Non-mixed policies are fine
   mp_test_util.MultiplyLayer(dtype=policy.Policy('float64'))
Example #5
 def test_v1_dtype_behavior(self):
     # Setting global policies are not allowed with V1 dtype behavior
     with self.assertRaisesRegex(
             ValueError, 'global policy can only be set in TensorFlow 2'):
         with mp_policy.policy_scope(mp_policy.Policy('_infer')):
             pass
     with self.assertRaisesRegex(
             ValueError, 'global policy can only be set in TensorFlow 2'):
         with mp_policy.policy_scope(mp_policy.Policy('float32')):
             pass
     with self.assertRaisesRegex(
             ValueError, 'global policy can only be set in TensorFlow 2'):
         with mp_policy.policy_scope(mp_policy.Policy('mixed_float16')):
             pass
Example #6
    def test_dtype_attributes(self):
        for dtype in 'int32', 'bool', 'float16', 'float32':
            policy = mp_policy.Policy(dtype)
            self.assertEqual(policy.name, dtype)
            self.assertEqual(policy.compute_dtype, dtype)
            self.assertEqual(policy.variable_dtype, dtype)

        for dtype in 'float16', 'bfloat16':
            policy = mp_policy.Policy('mixed_' + dtype)
            self.assertEqual(policy.name, 'mixed_' + dtype)
            self.assertEqual(policy.compute_dtype, dtype)
            self.assertEqual(policy.variable_dtype, 'float32')

        policy = mp_policy.Policy('_infer')
        self.assertEqual(policy.compute_dtype, None)
        self.assertEqual(policy.variable_dtype, None)
Example #7
    def test_dtype_attributes(self):
        for dtype in "int32", "bool", "float16", "float32":
            policy = mp_policy.Policy(dtype)
            self.assertEqual(policy.name, dtype)
            self.assertEqual(policy.compute_dtype, dtype)
            self.assertEqual(policy.variable_dtype, dtype)

        for dtype in "float16", "bfloat16":
            policy = mp_policy.Policy("mixed_" + dtype)
            self.assertEqual(policy.name, "mixed_" + dtype)
            self.assertEqual(policy.compute_dtype, dtype)
            self.assertEqual(policy.variable_dtype, "float32")

        policy = mp_policy.Policy("_infer")
        self.assertEqual(policy.compute_dtype, None)
        self.assertEqual(policy.variable_dtype, None)
Example #8
  def test_save_slot_variables_with_autocast_vars(self,
                                                  strategy_fn,
                                                  var_name='v'):
    p = policy.Policy('mixed_float16')
    with strategy_fn().scope(), policy.policy_scope(p):
      x = layers.Input(shape=(2,), batch_size=2)
      # Having a var_name other than 'v' tests that a fixed bug (b/134713714)
      # does not reoccur. The bug was that a crash would occur when saving a
      # checkpoint where an AutoCastVariable with a slot variable would have a
      # different name than the layer attribute's name (layer.v in this case).
      layer = mp_test_util.MultiplyLayer(assert_type=tf.float16,
                                         var_name=var_name)
      y = layer(x)
      model = models.Model(inputs=x, outputs=y)
      opt = gradient_descent.SGD(1., 1.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
                                                    initial_scale=1)
      model.compile(
          optimizer=opt,
          loss='mse',
          run_eagerly=testing_utils.should_run_eagerly())

    model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
    weights_file = os.path.join(self.get_temp_dir(), 'weights')
    model.save_weights(weights_file)
    saved_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))

    model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
    new_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))
    self.assertNotEqual(new_slot, saved_slot)

    model.load_weights(weights_file)
    restored_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))
    self.assertEqual(restored_slot, saved_slot)
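
For context, a minimal sketch of how an optimizer is usually wrapped for mixed precision outside of tests (this assumes TF 2.4+, where the wrapper is exported as tf.keras.mixed_precision.LossScaleOptimizer and dynamic loss scaling is the default; the test above instead uses a fixed scale of 1):

import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=1.0, momentum=1.0)
# The wrapper scales the loss before gradients are computed and unscales the
# gradients before they are applied, so float16 gradients do not underflow.
opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)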
Example #9
    def test_config(self, strategy_fn):
        x = tf.constant([1.], dtype=tf.float16)
        with strategy_fn().scope():
            for layer, dtype in (
                (mp_test_util.MultiplyLayer(), 'float32'),
                (mp_test_util.MultiplyLayer(dtype='float64'), 'float64'),
                (mp_test_util.MultiplyLayer(dtype=policy.Policy('float64')),
                 'float64'),
            ):
                config = layer.get_config()
                self.assertEqual(config['dtype'], dtype)
                self.assertIsInstance(config['dtype'], str)
                layer = mp_test_util.MultiplyLayer.from_config(config)
                self.assertEqual(layer.dtype, dtype)
                self.assertEqual(layer(x).dtype, dtype)
                self.assertEqual(layer.v.dtype, dtype)

            layer = mp_test_util.MultiplyLayer(dtype='mixed_float16')
            config = layer.get_config()
            self.assertEqual(config['dtype'], {
                'class_name': 'Policy',
                'config': {
                    'name': 'mixed_float16'
                }
            })
            layer = mp_test_util.MultiplyLayer.from_config(config)
            self.assertEqual(layer.dtype, 'float32')
            self.assertEqual(layer(x).dtype, 'float16')
            self.assertEqual(layer.v.dtype, 'float32')
            config = layer.get_config()
            self.assertEqual(config['dtype'], {
                'class_name': 'Policy',
                'config': {
                    'name': 'mixed_float16'
                }
            })

            layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer'))
            config = layer.get_config()
            self.assertIsNone(config['dtype'])
            layer = mp_test_util.MultiplyLayer.from_config(config)
            # If a layer is serialized with the "_infer" policy, when deserialized
            # into TF 2 it will have the global policy instead of "_infer". This is
            # because "_infer" is serialized into None, and passing dtype=None in
            # TensorFlow 2 indicates to use the global policy.
            self.assertEqual(layer.dtype, 'float32')
            self.assertEqual(layer(x).dtype, 'float32')
            self.assertEqual(layer.v.dtype, 'float32')
Example #10
 def test_build_and_call_layer_in_function(self):
   layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))
   @tf.function
   def f():
     return layer(1.)
   y = f()
   self.evaluate(tf.compat.v1.global_variables_initializer())
   self.assertEqual(y.dtype, 'float16')
   self.assertEqual(layer.v.dtype, 'float32')
   self.assertEqual(self.evaluate(y), 1.)
Example #11
 def test_dense_with_policy(self):
   inputs = tf.convert_to_tensor(np.random.randint(low=0, high=7, size=(2, 2)))
   layer = keras.layers.Dense(5, dtype=policy.Policy('mixed_float16'))
   outputs = layer(inputs)
   output_signature = layer.compute_output_signature(
       tf.TensorSpec(dtype='float16', shape=(2, 2)))
   self.assertEqual(output_signature.dtype, tf.float16)
   self.assertEqual(output_signature.shape, (2, 5))
   self.assertEqual(outputs.dtype, 'float16')
   self.assertEqual(layer.kernel.dtype, 'float32')
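
As a usage note, the same per-layer policy pattern in a small end-to-end model could look like the sketch below (assumes TF 2.x Keras; keeping the output layer in float32 follows the usual mixed precision recommendation for numeric stability and is not part of the test above):

import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

inputs = tf.keras.Input(shape=(16,))
# Computations run in float16, while the kernel and bias stay in float32.
x = layers.Dense(32, activation='relu',
                 dtype=mixed_precision.Policy('mixed_float16'))(inputs)
# The final layer keeps float32 so the softmax output is numerically stable.
outputs = layers.Dense(7, activation='softmax', dtype='float32')(x)
model = tf.keras.Model(inputs, outputs)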
Example #12
  def test_get_layer_policy(self):
    layer = core.Dense(4)
    self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float32')

    p = policy.Policy('mixed_float16')
    layer = core.Dense(4, dtype=p)
    self.assertIs(get_layer_policy.get_layer_policy(layer), p)

    layer = core.Dense(4, dtype='float64')
    self.assertEqual(get_layer_policy.get_layer_policy(layer).name, 'float64')
Example #13
 def test_batchnorm_mixed_precision_does_not_overflow(self, fused):
     norm = keras.layers.BatchNormalization(
         axis=-1,
         input_shape=(1, 1, 1),
         fused=fused,
         dtype=policy.Policy('mixed_float16'))
     x = np.array([-1000., 1000.]).reshape((2, 1, 1, 1))
     y = norm(x, training=True)
     expected_y = np.array([-1.0, 1.0]).reshape((2, 1, 1, 1))
     self.assertAllClose(keras.backend.eval(y), expected_y)
Example #14
 def test_batchnorm_mixed_precision(self):
     norm = keras.layers.BatchNormalization(
         axis=-1,
         input_shape=(4, 4, 3),
         momentum=0.8,
         dtype=policy.Policy('mixed_float16'))
     x = np.random.normal(size=(10, 4, 4, 3))
     y = norm(x)
     self.assertEqual(y.dtype, 'float16')
     self.assertEqual(norm.beta.dtype.base_dtype, 'float32')
     self.assertEqual(norm.gamma.dtype.base_dtype, 'float32')
Example #15
 def test_pass_invalid_optimizer_with_loss_scaling(self):
     with policy.policy_scope(policy.Policy("mixed_float16")):
         x = layers.Input(shape=(1,))
         y = mp_test_util.MultiplyLayer()(x)
         model = models.Model(x, y)
         if tf.executing_eagerly():
             error_msg = "Use a `tf.keras` Optimizer instead"
         else:
             error_msg = 'optimizer" must be an instance of '
         with self.assertRaisesRegex(ValueError, error_msg):
             model.compile(optimizer_v1.SGD(1.0), "mse")
Example #16
 def test_repr(self):
     # Test Policy repr
     for policy in (
         "float32",
         "int8",
         "mixed_float16",
         "mixed_bfloat16",
         "_infer",
     ):
         self.assertEqual(
             repr(mp_policy.Policy(policy)), '<Policy "%s">' % policy
         )
Example #17
 def test_passing_policy_to_layer(self, strategy_fn):
     x = tf.constant([1.], dtype=tf.float16)
     with strategy_fn().scope():
         # Passing a Policy to 'dtype' sets the policy for that layer.
         layer = mp_test_util.MultiplyLayer(
             assert_type=tf.float16, dtype=policy.Policy('mixed_float16'))
         # layer.dtype refers to the variable dtype
         self.assertEqual(layer.dtype, tf.float32)
         layer(x)
         self.assertEqual(layer.v.dtype, tf.float32)
         with policy.policy_scope('mixed_float16'):
             # Passing a Policy to dtype overrides the global Policy
             layer = mp_test_util.MultiplyLayer(
                 assert_type=tf.float64, dtype=policy.Policy('float64'))
             self.assertEqual(layer.dtype_policy.name, 'float64')
             self.assertIsInstance(layer.dtype_policy, policy.Policy)
             self.assertEqual(layer.compute_dtype, tf.float64)
             self.assertEqual(layer.dtype, tf.float64)
             self.assertEqual(layer.variable_dtype, tf.float64)
             self.assertEqual(layer(x).dtype, tf.float64)
             self.assertEqual(layer.v.dtype, tf.float64)
Example #18
  def test_from_config_policy_v1(self, strategy_fn):
    # Test that layers serialized in previous Keras versions with the
    # now-deleted PolicyV1 can be deserialized. In such cases, the PolicyV1 will
    # be converted to a Policy, since PolicyV1 no longer exists. Unlike Policy,
    # PolicyV1 had a "loss_scale" field, which is silently dropped when
    # deserialized.
    x = tf.constant([1.], dtype=tf.float16)
    with strategy_fn().scope():

      layer = mp_test_util.MultiplyLayer(dtype='mixed_float16')
      config = layer.get_config()
      # Change the serialized dtype policy to a PolicyV1
      config['dtype'] = {'class_name': 'PolicyV1',
                         'config': {'name': 'mixed_float16',
                                    'loss_scale': None}}
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float32')
      config = layer.get_config()
      # The loss_scale is silently dropped
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})

      layer = mp_test_util.MultiplyLayer(dtype='float64')
      config = layer.get_config()
      config['dtype'] = {'class_name': 'PolicyV1',
                         'config': {'name': 'float64',
                                    'loss_scale': {
                                        'class_name': 'FixedLossScale',
                                        'config': {'loss_scale_value': 2.0}}}}
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float64')
      self.assertEqual(layer(x).dtype, 'float64')
      self.assertEqual(layer.v.dtype, 'float64')
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float64')

      layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer'))
      config = layer.get_config()
      config['dtype'] = {'class_name': 'PolicyV1',
                         'config': {'name': '_infer',
                                    'loss_scale': {
                                        'class_name': 'FixedLossScale',
                                        'config': {'loss_scale_value': 2.0}}}}
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, None)
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float16')
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float16')
Example #19
 def test_global_policy_dtype_error(self):
     with self.assertRaisesRegex(
             ValueError,
             'set_policy can only be used to set the global policy to '
             'floating-point policies, such as "float32" and "mixed_float16", but '
             'got policy: int32'):
         mp_policy.set_policy('int32')
     with self.assertRaisesRegex(
             ValueError,
             'set_policy can only be used to set the global policy to '
             'floating-point policies, such as "float32" and "mixed_float16", but '
             'got policy: complex64'):
         mp_policy.set_policy(mp_policy.Policy('complex64'))
Example #20
  def test_layer_with_int_variable(self):
    class LayerWithIntVar(base_layer.Layer):

      def build(self, _):
        self.v = self.add_weight('v', dtype='int32', trainable=False)

      def call(self, inputs):
        # Only float variables should be autocasted. This will fail if self.v is
        # autocasted to float32
        return tf.cast(inputs, 'int32') + self.v

    x = tf.constant([1.])
    layer = LayerWithIntVar(dtype=policy.Policy('mixed_float16'))
    self.assertEqual(layer(x).dtype, 'int32')
Example #21
    def test_serialization(self):
        # Test policies that are equivalent to a single dtype
        for policy_name in "float16", "float32", "int8", "string", "bool":
            policy = mp_policy.Policy(policy_name)
            config = mp_policy.serialize(policy)
            self.assertEqual(config, policy_name)
            new_policy = mp_policy.deserialize(config)
            self.assertEqual(str(policy), str(new_policy))

        # Test "_infer" policy
        policy = mp_policy.Policy("_infer")
        config = mp_policy.serialize(policy)
        self.assertIsNone(config)
        new_policy = mp_policy.deserialize(config)
        self.assertEqual(str(policy), str(new_policy))

        class MyPolicy(mp_policy.Policy):
            pass

        # Test policies that are not equivalent to a single dtype
        for policy in (
            mp_policy.Policy("mixed_float16"),
            mp_policy.Policy("mixed_bfloat16"),
            MyPolicy("float32"),
        ):
            config = mp_policy.serialize(policy)
            self.assertEqual(
                config,
                {
                    "class_name": policy.__class__.__name__,
                    "config": {"name": policy.name},
                },
            )
            new_policy = mp_policy.deserialize(
                config, custom_objects={"MyPolicy": MyPolicy}
            )
            self.assertEqual(str(policy), str(new_policy))
Example #22
    def test_policy_errors(self):
        # Test passing invalid strings

        with self.assertRaisesRegex(
            ValueError, "Cannot convert value abc to a mixed precision Policy."
        ):
            mp_policy.Policy("abc")

        # Test passing a DType
        with self.assertRaisesRegex(
            TypeError,
            "'name' must be a string, not a DType. "
            "Instead, pass DType.name. Got: float16",
        ):
            mp_policy.Policy(tf.float16)

        # Test passing a non-DType invalid type
        with self.assertRaisesRegex(
            TypeError, "'name' must be a string, but got: 5"
        ):
            mp_policy.Policy(5)

        # Test passing a now-removed policy ending in float32_vars
        with self.assertRaisesRegex(
            ValueError,
            "Policies ending in '_float32_vars' have been removed "
            "from TensorFlow. Please use the 'mixed_float16' or "
            "'mixed_bfloat16' policy instead. Got policy name: "
            "'infer_float32_vars'",
        ):
            mp_policy.Policy("infer_float32_vars")
        with self.assertRaisesRegex(
            ValueError,
            "Policies ending in '_float32_vars' have been removed "
            "from TensorFlow. Please use the 'mixed_float16' policy "
            "instead. Got policy name: 'float16_with_float32_vars'",
        ):
            mp_policy.Policy("float16_with_float32_vars")
        with self.assertRaisesRegex(
            ValueError,
            "Policies ending in '_float32_vars' have been removed "
            "from TensorFlow. Please use the 'mixed_bfloat16' policy "
            "instead. Got policy name: 'bfloat16_with_float32_vars'",
        ):
            mp_policy.Policy("bfloat16_with_float32_vars")
        with self.assertRaisesRegex(
            ValueError,
            "Policies ending in '_float32_vars' have been removed "
            "from TensorFlow. Got policy name: "
            "'int8_with_float32_vars'",
        ):
            mp_policy.Policy("int8_with_float32_vars")
Example #23
 def test_global_policy_dtype_error(self):
     with self.assertRaisesRegex(
         ValueError,
         "set_global_policy can only be used to set the global policy to "
         'floating-point policies, such as "float32" and "mixed_float16", '
         "but got policy: int32",
     ):
         mp_policy.set_global_policy("int32")
     with self.assertRaisesRegex(
         ValueError,
         "set_global_policy can only be used to set the global policy to "
         'floating-point policies, such as "float32" and "mixed_float16", '
         "but got policy: complex64",
     ):
         mp_policy.set_global_policy(mp_policy.Policy("complex64"))
Example #24
    def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
        # For backwards compatibility, legacy layers do not use `ResourceVariable`
        # by default.
        self._use_resource_variables = False
        scope = kwargs.pop('_scope', None)
        self._reuse = kwargs.pop('_reuse', None)

        # Avoid an incorrect lint error
        self._trainable_weights = []
        self.built = False

        if dtype is None:
            # Indicates to infer dtype from inputs. When the V2 dtype behavior is
            # enabled, Keras layers default their dtype to floatx instead, so we pass
            # an "_infer" policy to keep the old V1 behavior.
            dtype = policy.Policy('_infer')

        if 'autocast' not in kwargs:
            kwargs['autocast'] = False

        # Mark that legacy layers should not be instrumented as Keras usage
        self._disable_keras_instrumentation = True

        super(Layer, self).__init__(trainable=trainable,
                                    name=name,
                                    dtype=dtype,
                                    **kwargs)

        if _is_in_keras_style_scope():
            if scope is not None:
                raise ValueError(
                    'scope argument not allowed when keras style layers are enabled, '
                    'but saw: {}'.format(scope))
            if self._reuse is not None:
                raise ValueError(
                    'reuse argument not allowed when keras style layers are enabled, '
                    'but saw: {}'.format(self._reuse))
            self._keras_style = True
        else:
            self._keras_style = False

        self._call_has_scope_arg = 'scope' in self._call_fn_args
        if scope:
            with tf.compat.v1.variable_scope(scope) as captured_scope:
                self._scope = captured_scope
        else:
            self._scope = None
        self._current_scope = None
Example #25
    def test_repr(self):
        # Test Policy repr
        for policy in ('float32', 'int8', 'mixed_float16', 'mixed_bfloat16',
                       '_infer'):
            self.assertEqual(repr(mp_policy.Policy(policy)),
                             '<Policy "%s">' % policy)

        # Test PolicyV1 repr
        for policy in ('float32', 'int8', 'mixed_bfloat16', '_infer'):
            self.assertEqual(repr(mp_policy.PolicyV1(policy)),
                             '<PolicyV1 "%s", loss_scale=None>' % policy)
        self.assertEqual(
            repr(mp_policy.PolicyV1('float16', loss_scale=2)),
            '<PolicyV1 "float16", loss_scale=FixedLossScale(2.0)>')
        self.assertStartsWith(
            repr(mp_policy.PolicyV1('mixed_float16')),
            '<PolicyV1 "mixed_float16", loss_scale=DynamicLossScale(')
Example #26
 def test_global_policy(self):
   if base_layer_utils.v2_dtype_behavior_enabled():
     default_policy = 'float32'
   else:
     default_policy = '_infer'
   self.assertEqual(mp_policy.global_policy().name, default_policy)
   try:
     mp_policy.set_global_policy('mixed_float16')
     self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     with tf.Graph().as_default():  # Policies are not associated with a graph
       self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
     mp_policy.set_global_policy('_infer')
     self.assertEqual(mp_policy.global_policy().name, '_infer')
     policy = mp_policy.Policy('mixed_bfloat16')
     mp_policy.set_global_policy(policy)
     self.assertIs(mp_policy.global_policy(), policy)
   finally:
     mp_policy.set_global_policy(None)
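
For illustration, a short hedged sketch of how the global policy propagates to layers created afterwards (assumes TF 2.4+, where the public entry points are tf.keras.mixed_precision.set_global_policy and global_policy):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
layer = tf.keras.layers.Dense(4)
assert layer.compute_dtype == 'float16'    # computations run in float16
assert layer.variable_dtype == 'float32'   # variables are kept in float32
assert layer.dtype == 'float32'            # layer.dtype reports the variable dtype
tf.keras.mixed_precision.set_global_policy('float32')  # restore the default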
Example #27
    def test_policy_errors(self):
        # Test passing invalid strings

        with self.assertRaisesRegex(
                ValueError,
                'Cannot convert value abc to a mixed precision Policy.'):
            mp_policy.Policy('abc')

        # Test passing a DType
        with self.assertRaisesRegex(
                TypeError, "'name' must be a string, not a DType. "
                'Instead, pass DType.name. Got: float16'):
            mp_policy.Policy(tf.float16)

        # Test passing a non-DType invalid type
        with self.assertRaisesRegex(TypeError,
                                    "'name' must be a string, but got: 5"):
            mp_policy.Policy(5)

        # Test passing a now-removed policy ending in float32_vars
        with self.assertRaisesRegex(
                ValueError,
                'Policies ending in \'_float32_vars\' have been removed '
                'from TensorFlow. Please use the \'mixed_float16\' or '
                '\'mixed_bfloat16\' policy instead. Got policy name: '
                '\'infer_float32_vars\''):
            mp_policy.Policy('infer_float32_vars')
        with self.assertRaisesRegex(
                ValueError,
                'Policies ending in \'_float32_vars\' have been removed '
                'from TensorFlow. Please use the \'mixed_float16\' policy '
                'instead. Got policy name: \'float16_with_float32_vars\''):
            mp_policy.Policy('float16_with_float32_vars')
        with self.assertRaisesRegex(
                ValueError,
                'Policies ending in \'_float32_vars\' have been removed '
                'from TensorFlow. Please use the \'mixed_bfloat16\' policy '
                'instead. Got policy name: \'bfloat16_with_float32_vars\''):
            mp_policy.Policy('bfloat16_with_float32_vars')
        with self.assertRaisesRegex(
                ValueError,
                'Policies ending in \'_float32_vars\' have been removed '
                'from TensorFlow. Got policy name: '
                '\'int8_with_float32_vars\''):
            mp_policy.Policy('int8_with_float32_vars')
Example #28
 def test_config(self):
     for policy in (
             mp_policy.Policy('float16'),
             mp_policy.Policy('float32'),
             mp_policy.Policy('int16'),
             mp_policy.Policy('mixed_float16'),
             mp_policy.Policy('mixed_bfloat16'),
             mp_policy.Policy('_infer'),
     ):
         config = policy.get_config()
         new_policy = mp_policy.Policy.from_config(config)
         # Comparing strings is the easiest way to ensure the policies are the
         # same, as policy does not override the == operator.
         self.assertEqual(str(policy), str(new_policy))
Example #29
 def test_global_policy(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
         default_policy = "float32"
     else:
         default_policy = "_infer"
     self.assertEqual(mp_policy.global_policy().name, default_policy)
     try:
         mp_policy.set_global_policy("mixed_float16")
         self.assertEqual(mp_policy.global_policy().name, "mixed_float16")
         # Policies are not associated with a graph
         with tf.Graph().as_default():
             self.assertEqual(
                 mp_policy.global_policy().name, "mixed_float16"
             )
         mp_policy.set_global_policy("_infer")
         self.assertEqual(mp_policy.global_policy().name, "_infer")
         policy = mp_policy.Policy("mixed_bfloat16")
         mp_policy.set_global_policy(policy)
         self.assertIs(mp_policy.global_policy(), policy)
     finally:
         mp_policy.set_global_policy(None)
Example #30
    def test_layer(self,
                   f32_layer_fn,
                   input_shape,
                   rtol=2e-3,
                   atol=2e-3,
                   input_data=None):
        """Tests a layer by comparing the float32 and mixed precision weights.

    A float32 layer, a mixed precision layer, and a distributed mixed precision
    layer are run. The three layers are identical other than their dtypes and
    distribution strategies. The outputs after predict() and weights after fit()
    are asserted to be close.

    Args:
      f32_layer_fn: A function returning a float32 layer. The other two layers
        will automatically be created from this
      input_shape: The shape of the input to the layer, including the batch
        dimension. Or a list of shapes if the layer takes multiple inputs.
      rtol: The relative tolerance to be asserted.
      atol: The absolute tolerance to be asserted.
      input_data: A Numpy array with the data of the input. If None, input data
        will be randomly generated
    """

        if (f32_layer_fn == reshaping.ZeroPadding2D
                and tf.test.is_built_with_rocm()):
            return
        if isinstance(input_shape[0], int):
            input_shapes = [input_shape]
        else:
            input_shapes = input_shape
        strategy = create_mirrored_strategy()
        f32_layer = f32_layer_fn()

        # Create the layers
        assert f32_layer.dtype == f32_layer._compute_dtype == 'float32'
        config = f32_layer.get_config()
        config['dtype'] = policy.Policy('mixed_float16')
        mp_layer = f32_layer.__class__.from_config(config)
        distributed_mp_layer = f32_layer.__class__.from_config(config)

        # Compute per_replica_input_shapes for the distributed model
        global_batch_size = input_shapes[0][0]
        assert global_batch_size % strategy.num_replicas_in_sync == 0, (
            'The number of replicas, %d, does not divide the global batch size of '
            '%d' % (strategy.num_replicas_in_sync, global_batch_size))
        per_replica_batch_size = (global_batch_size //
                                  strategy.num_replicas_in_sync)
        per_replica_input_shapes = [(per_replica_batch_size, ) + s[1:]
                                    for s in input_shapes]

        # Create the models
        f32_model = self._create_model_from_layer(f32_layer, input_shapes)
        mp_model = self._create_model_from_layer(mp_layer, input_shapes)
        with strategy.scope():
            distributed_mp_model = self._create_model_from_layer(
                distributed_mp_layer, per_replica_input_shapes)

        # Set all model weights to the same values
        f32_weights = f32_model.get_weights()
        mp_model.set_weights(f32_weights)
        distributed_mp_model.set_weights(f32_weights)

        # Generate input data
        if input_data is None:
            # Generate the inputs directly in float16 so the comparison does
            # not include the error introduced by the mixed precision layers
            # casting float32 inputs down to float16.
            input_data = [
                np.random.normal(size=s).astype('float16')
                for s in input_shapes
            ]
            if len(input_data) == 1:
                input_data = input_data[0]

        # Assert all models have close outputs.
        f32_output = f32_model.predict(input_data)
        mp_output = mp_model.predict(input_data)
        self.assertAllClose(mp_output, f32_output, rtol=rtol, atol=atol)
        self.assertAllClose(distributed_mp_model.predict(input_data),
                            f32_output,
                            rtol=rtol,
                            atol=atol)

        # Run fit() on models
        output = np.random.normal(
            size=f32_model.outputs[0].shape).astype('float16')
        for model in f32_model, mp_model, distributed_mp_model:
            model.fit(input_data, output, batch_size=global_batch_size)

        # Assert all models have close weights
        f32_weights = f32_model.get_weights()
        self.assertAllClose(mp_model.get_weights(),
                            f32_weights,
                            rtol=rtol,
                            atol=atol)
        self.assertAllClose(distributed_mp_model.get_weights(),
                            f32_weights,
                            rtol=rtol,
                            atol=atol)