def test_serialization(self):
    """Round-trips policies through `serialize` and `deserialize`."""
    # Policies that are equivalent to a single dtype serialize to their name.
    for dtype_name in ('float16', 'float32', 'int8', 'string', 'bool'):
        original = mp_policy.Policy(dtype_name)
        serialized = mp_policy.serialize(original)
        self.assertEqual(serialized, dtype_name)
        restored = mp_policy.deserialize(serialized)
        self.assertEqual(str(original), str(restored))

    # The "_infer" policy serializes to None.
    original = mp_policy.Policy('_infer')
    serialized = mp_policy.serialize(original)
    self.assertIsNone(serialized)
    restored = mp_policy.deserialize(serialized)
    self.assertEqual(str(original), str(restored))

    class MyPolicy(mp_policy.Policy):
        pass

    # Policies that are not equivalent to a single dtype serialize to a dict
    # holding the class name and the policy name.
    dict_serialized_policies = (
        mp_policy.Policy('mixed_float16'),
        mp_policy.Policy('mixed_bfloat16'),
        MyPolicy('float32'),
    )
    for original in dict_serialized_policies:
        serialized = mp_policy.serialize(original)
        expected = {
            'class_name': original.__class__.__name__,
            'config': {'name': original.name},
        }
        self.assertEqual(serialized, expected)
        restored = mp_policy.deserialize(
            serialized, custom_objects={'MyPolicy': MyPolicy})
        self.assertEqual(str(original), str(restored))
def test_device_compatibility_warning(self):
    """Checks the warning logged when a mixed_float16 policy is created."""
    if not tf.executing_eagerly():
        self.skipTest("Run in eager mode only.")

    mock = tf.compat.v1.test.mock
    # Reset the module-level flag before creating the policy (presumably it
    # records whether the compatibility check has already been logged).
    device_compatibility_check._logged_compatibility_check = False
    with mock.patch.object(tf_logging, "warning") as warning_mock:
        mp_policy.Policy("mixed_float16")

    if tf.config.list_physical_devices("GPU"):
        warning_mock.assert_not_called()
    else:
        # First positional argument of the most recent `warning` call.
        logged_message = warning_mock.call_args[0][0]
        self.assertRegex(
            logged_message,
            r"Mixed precision compatibility check \(mixed_float16\): "
            r"WARNING.*",
        )

    if tf.config.list_physical_devices("GPU"):
        # Assert message is only logged once
        with mock.patch.object(tf_logging, "warning") as warning_mock:
            mp_policy.Policy("mixed_float16")
        warning_mock.assert_not_called()
def test_config(self, strategy_fn):
    """Checks the "dtype" config entry for several layer dtype policies."""
    x = tf.constant([1.0], dtype=tf.float16)
    expected_mixed_config = {
        "class_name": "Policy",
        "config": {"name": "mixed_float16"},
    }
    with strategy_fn().scope():
        # Single-dtype policies are serialized as a plain string.
        single_dtype_cases = (
            (mp_test_util.MultiplyLayer(), "float32"),
            (mp_test_util.MultiplyLayer(dtype="float64"), "float64"),
            (
                mp_test_util.MultiplyLayer(dtype=policy.Policy("float64")),
                "float64",
            ),
        )
        for layer, dtype in single_dtype_cases:
            config = layer.get_config()
            serialized_dtype = config["dtype"]
            self.assertEqual(serialized_dtype, dtype)
            self.assertIsInstance(config["dtype"], str)
            layer = mp_test_util.MultiplyLayer.from_config(config)
            self.assertEqual(layer.dtype, dtype)
            self.assertEqual(layer(x).dtype, dtype)
            self.assertEqual(layer.v.dtype, dtype)

        # A mixed policy is serialized as a dict, before and after a
        # from_config round trip.
        layer = mp_test_util.MultiplyLayer(dtype="mixed_float16")
        config = layer.get_config()
        self.assertEqual(config["dtype"], expected_mixed_config)
        layer = mp_test_util.MultiplyLayer.from_config(config)
        self.assertEqual(layer.dtype, "float32")
        self.assertEqual(layer(x).dtype, "float16")
        self.assertEqual(layer.v.dtype, "float32")
        config = layer.get_config()
        self.assertEqual(config["dtype"], expected_mixed_config)

        layer = mp_test_util.MultiplyLayer(dtype=policy.Policy("_infer"))
        config = layer.get_config()
        self.assertIsNone(config["dtype"])
        layer = mp_test_util.MultiplyLayer.from_config(config)
        # If a layer is serialized with the "_infer" policy, when
        # deserialized into TF 2 it will have the global policy instead of
        # "_infer". This is because "_infer" is serialized into None, and
        # passing dtype=None in TensorFlow 2 indicates to use the global
        # policy.
        self.assertEqual(layer.dtype, "float32")
        self.assertEqual(layer(x).dtype, "float32")
        self.assertEqual(layer.v.dtype, "float32")
def test_unsupported_strategy(self):
    """Checks that a mixed policy is rejected under the created strategy."""
    strategy = create_central_storage_strategy()
    expected_error = (
        'Mixed precision is not supported with the '
        'tf.distribute.Strategy: CentralStorageStrategy. Either '
        'stop using mixed precision by removing the use of the '
        '"mixed_float16" policy or use a different Strategy, e.g. '
        'a MirroredStrategy.')
    with strategy.scope():
        with self.assertRaisesRegex(ValueError, expected_error):
            mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))
    # Non-mixed policies are fine
    mp_test_util.MultiplyLayer(dtype=policy.Policy('float64'))
def test_v1_dtype_behavior(self):
    """Checks that entering `policy_scope` raises for each policy tried."""
    # Setting global policies are not allowed with V1 dtype behavior
    expected_error = 'global policy can only be set in TensorFlow 2'
    for policy_name in ('_infer', 'float32', 'mixed_float16'):
        with self.assertRaisesRegex(ValueError, expected_error):
            with mp_policy.policy_scope(mp_policy.Policy(policy_name)):
                pass
def test_dtype_attributes(self):
    """Checks `name`, `compute_dtype` and `variable_dtype` of policies.

    Covers single-dtype policies, the two "mixed_" policies and the
    "_infer" policy.
    """
    # For a single-dtype policy, all three attributes equal the dtype name.
    for dtype in 'int32', 'bool', 'float16', 'float32':
        policy = mp_policy.Policy(dtype)
        self.assertEqual(policy.name, dtype)
        self.assertEqual(policy.compute_dtype, dtype)
        self.assertEqual(policy.variable_dtype, dtype)

    # A mixed policy computes in the named dtype but keeps float32 variables.
    for dtype in 'float16', 'bfloat16':
        policy = mp_policy.Policy('mixed_' + dtype)
        self.assertEqual(policy.name, 'mixed_' + dtype)
        self.assertEqual(policy.compute_dtype, dtype)
        self.assertEqual(policy.variable_dtype, 'float32')

    # The "_infer" policy has no dtypes. assertIsNone checks identity with
    # None and reports a clearer failure than assertEqual(..., None).
    policy = mp_policy.Policy('_infer')
    self.assertIsNone(policy.compute_dtype)
    self.assertIsNone(policy.variable_dtype)
def test_dtype_attributes(self):
    """Checks `name`, `compute_dtype` and `variable_dtype` of policies.

    Covers single-dtype policies, the two "mixed_" policies and the
    "_infer" policy.
    """
    # For a single-dtype policy, all three attributes equal the dtype name.
    for dtype in "int32", "bool", "float16", "float32":
        policy = mp_policy.Policy(dtype)
        self.assertEqual(policy.name, dtype)
        self.assertEqual(policy.compute_dtype, dtype)
        self.assertEqual(policy.variable_dtype, dtype)

    # A mixed policy computes in the named dtype but keeps float32 variables.
    for dtype in "float16", "bfloat16":
        policy = mp_policy.Policy("mixed_" + dtype)
        self.assertEqual(policy.name, "mixed_" + dtype)
        self.assertEqual(policy.compute_dtype, dtype)
        self.assertEqual(policy.variable_dtype, "float32")

    # The "_infer" policy has no dtypes. assertIsNone checks identity with
    # None and reports a clearer failure than assertEqual(..., None).
    policy = mp_policy.Policy("_infer")
    self.assertIsNone(policy.compute_dtype)
    self.assertIsNone(policy.variable_dtype)
def test_save_slot_variables_with_autocast_vars(self,
                                                strategy_fn,
                                                var_name='v'):
    """Checks that the 'momentum' slot survives save_weights/load_weights."""
    p = policy.Policy('mixed_float16')
    with strategy_fn().scope(), policy.policy_scope(p):
        x = layers.Input(shape=(2,), batch_size=2)
        # Having a var_name other than 'v' tests that a fixed bug (b/134713714)
        # does not reoccur. The bug was that a crash would occur when saving a
        # checkpoint where an AutoCastVariable with a slot variable would have
        # a different name than the layer attribute's name (layer.v in this
        # case).
        layer = mp_test_util.MultiplyLayer(
            assert_type=tf.float16, var_name=var_name)
        y = layer(x)
        model = models.Model(inputs=x, outputs=y)
        inner_opt = gradient_descent.SGD(1., 1.)
        opt = loss_scale_optimizer.LossScaleOptimizer(
            inner_opt, dynamic=False, initial_scale=1)
        model.compile(
            optimizer=opt,
            loss='mse',
            run_eagerly=testing_utils.should_run_eagerly())

    def fit_one_batch():
        model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)

    def momentum_slot_value():
        return backend.get_value(opt.get_slot(layer.v, 'momentum'))

    # Train, save the weights and remember the slot value at save time.
    fit_one_batch()
    weights_file = os.path.join(self.get_temp_dir(), 'weights')
    model.save_weights(weights_file)
    saved_slot = momentum_slot_value()

    # Training again must change the slot value.
    fit_one_batch()
    new_slot = momentum_slot_value()
    self.assertNotEqual(new_slot, saved_slot)

    # Loading the weights must bring back the saved slot value.
    model.load_weights(weights_file)
    restored_slot = momentum_slot_value()
    self.assertEqual(restored_slot, saved_slot)
def test_config(self, strategy_fn):
    """Checks the 'dtype' config entry for several layer dtype policies."""
    x = tf.constant([1.], dtype=tf.float16)
    mixed_policy_config = {
        'class_name': 'Policy',
        'config': {'name': 'mixed_float16'},
    }
    layer_cls = mp_test_util.MultiplyLayer
    with strategy_fn().scope():
        # Single-dtype policies are serialized as a plain string.
        layers_and_dtypes = (
            (layer_cls(), 'float32'),
            (layer_cls(dtype='float64'), 'float64'),
            (layer_cls(dtype=policy.Policy('float64')), 'float64'),
        )
        for layer, dtype in layers_and_dtypes:
            config = layer.get_config()
            self.assertEqual(config['dtype'], dtype)
            self.assertIsInstance(config['dtype'], str)
            layer = layer_cls.from_config(config)
            self.assertEqual(layer.dtype, dtype)
            self.assertEqual(layer(x).dtype, dtype)
            self.assertEqual(layer.v.dtype, dtype)

        # A mixed policy is serialized as a dict, before and after a
        # from_config round trip.
        layer = layer_cls(dtype='mixed_float16')
        config = layer.get_config()
        self.assertEqual(config['dtype'], mixed_policy_config)
        layer = layer_cls.from_config(config)
        self.assertEqual(layer.dtype, 'float32')
        self.assertEqual(layer(x).dtype, 'float16')
        self.assertEqual(layer.v.dtype, 'float32')
        config = layer.get_config()
        self.assertEqual(config['dtype'], mixed_policy_config)

        layer = layer_cls(dtype=policy.Policy('_infer'))
        config = layer.get_config()
        self.assertIsNone(config['dtype'])
        layer = layer_cls.from_config(config)
        # If a layer is serialized with the "_infer" policy, when deserialized
        # into TF 2 it will have the global policy instead of "_infer". This
        # is because "_infer" is serialized into None, and passing dtype=None
        # in TensorFlow 2 indicates to use the global policy.
        self.assertEqual(layer.dtype, 'float32')
        self.assertEqual(layer(x).dtype, 'float32')
        self.assertEqual(layer.v.dtype, 'float32')
def test_build_and_call_layer_in_function(self):
    """Builds and calls a mixed_float16 layer inside a `tf.function`."""
    mixed_policy = policy.Policy('mixed_float16')
    layer = mp_test_util.MultiplyLayer(dtype=mixed_policy)

    @tf.function
    def call_layer():
        return layer(1.)

    result = call_layer()
    self.evaluate(tf.compat.v1.global_variables_initializer())
    # The output has the float16 dtype while the variable stays float32.
    self.assertEqual(result.dtype, 'float16')
    self.assertEqual(layer.v.dtype, 'float32')
    self.assertEqual(self.evaluate(result), 1.)
def test_dense_with_policy(self):
    """Checks dtypes and output signature of a mixed_float16 Dense layer."""
    random_ints = np.random.randint(low=0, high=7, size=(2, 2))
    inputs = tf.convert_to_tensor(random_ints)
    dense = keras.layers.Dense(5, dtype=policy.Policy('mixed_float16'))
    outputs = dense(inputs)

    input_spec = tf.TensorSpec(dtype='float16', shape=(2, 2))
    signature = dense.compute_output_signature(input_spec)
    self.assertEqual(signature.dtype, tf.float16)
    self.assertEqual(signature.shape, (2, 5))

    # Outputs are float16 while the kernel variable is float32.
    self.assertEqual(outputs.dtype, 'float16')
    self.assertEqual(dense.kernel.dtype, 'float32')
def test_get_layer_policy(self):
    """Checks `get_layer_policy` for default, Policy and string dtypes."""
    lookup = get_layer_policy.get_layer_policy

    # No dtype given.
    default_layer = core.Dense(4)
    self.assertEqual(lookup(default_layer).name, 'float32')

    # A Policy instance passed as dtype is returned as the same object.
    p = policy.Policy('mixed_float16')
    mixed_layer = core.Dense(4, dtype=p)
    self.assertIs(lookup(mixed_layer), p)

    # A dtype given as a string.
    float64_layer = core.Dense(4, dtype='float64')
    self.assertEqual(lookup(float64_layer).name, 'float64')
def test_batchnorm_mixed_precision_does_not_overflow(self, fused):
    """Normalizes +/-1000 inputs under mixed_float16 and expects +/-1."""
    batch_norm = keras.layers.BatchNormalization(
        axis=-1,
        input_shape=(1, 1, 1),
        fused=fused,
        dtype=policy.Policy('mixed_float16'))
    batch_shape = (2, 1, 1, 1)
    inputs = np.array([-1000., 1000.]).reshape(batch_shape)
    outputs = batch_norm(inputs, training=True)
    expected_outputs = np.array([-1.0, 1.0]).reshape(batch_shape)
    self.assertAllClose(keras.backend.eval(outputs), expected_outputs)
def test_batchnorm_mixed_precision(self):
    """Checks output and variable dtypes of mixed_float16 batch norm."""
    mixed_policy = policy.Policy('mixed_float16')
    batch_norm = keras.layers.BatchNormalization(
        axis=-1, input_shape=(4, 4, 3), momentum=0.8, dtype=mixed_policy)
    inputs = np.random.normal(size=(10, 4, 4, 3))
    outputs = batch_norm(inputs)
    # The output is float16 while beta and gamma are float32.
    self.assertEqual(outputs.dtype, 'float16')
    self.assertEqual(batch_norm.beta.dtype.base_dtype, 'float32')
    self.assertEqual(batch_norm.gamma.dtype.base_dtype, 'float32')
def test_pass_invalid_optimizer_with_loss_scaling(self):
    """Checks that compiling with an `optimizer_v1.SGD` raises ValueError."""
    with policy.policy_scope(policy.Policy("mixed_float16")):
        inputs = layers.Input(shape=(1,))
        outputs = mp_test_util.MultiplyLayer()(inputs)
        model = models.Model(inputs, outputs)
        # The expected message differs between eager and graph execution.
        error_msg = (
            "Use a `tf.keras` Optimizer instead"
            if tf.executing_eagerly()
            else 'optimizer" must be an instance of '
        )
        with self.assertRaisesRegex(ValueError, error_msg):
            model.compile(optimizer_v1.SGD(1.0), "mse")
def test_repr(self):
    """Checks the `repr` of several policies."""
    # Test Policy repr
    policy_names = (
        "float32",
        "int8",
        "mixed_float16",
        "mixed_bfloat16",
        "_infer",
    )
    for name in policy_names:
        actual = repr(mp_policy.Policy(name))
        self.assertEqual(actual, '<Policy "{}">'.format(name))
def test_passing_policy_to_layer(self, strategy_fn):
    """Checks that a Policy passed as `dtype` becomes the layer's policy."""
    x = tf.constant([1.], dtype=tf.float16)
    with strategy_fn().scope():
        # Passing a Policy to 'dtype' sets the policy for that layer.
        mixed_layer = mp_test_util.MultiplyLayer(
            assert_type=tf.float16, dtype=policy.Policy('mixed_float16'))
        # layer.dtype refers to the variable dtype
        self.assertEqual(mixed_layer.dtype, tf.float32)
        mixed_layer(x)
        self.assertEqual(mixed_layer.v.dtype, tf.float32)

        with policy.policy_scope('mixed_float16'):
            # Passing a Policy to dtype overrides the global Policy
            float64_layer = mp_test_util.MultiplyLayer(
                assert_type=tf.float64, dtype=policy.Policy('float64'))
            layer_policy = float64_layer.dtype_policy
            self.assertEqual(layer_policy.name, 'float64')
            self.assertIsInstance(float64_layer.dtype_policy, policy.Policy)
            self.assertEqual(float64_layer.compute_dtype, tf.float64)
            self.assertEqual(float64_layer.dtype, tf.float64)
            self.assertEqual(float64_layer.variable_dtype, tf.float64)
            self.assertEqual(float64_layer(x).dtype, tf.float64)
            self.assertEqual(float64_layer.v.dtype, tf.float64)
def test_from_config_policy_v1(self, strategy_fn):
    """Checks that layer configs holding a PolicyV1 dtype still load."""
    # Test that layers serialized in previous Keras versions with the
    # now-deleted PolicyV1 can be deserialized. In such cases, the PolicyV1
    # will be converted to a Policy, since PolicyV1 no longer exists. Unlike
    # Policy, PolicyV1 had a "loss_scale" field, which is silently dropped
    # when deserialized.

    def policy_v1_config(name, loss_scale):
        # Build a fresh serialized PolicyV1 dict for each use.
        return {'class_name': 'PolicyV1',
                'config': {'name': name, 'loss_scale': loss_scale}}

    def fixed_loss_scale_config():
        # Build a fresh serialized FixedLossScale dict for each use.
        return {'class_name': 'FixedLossScale',
                'config': {'loss_scale_value': 2.0}}

    x = tf.constant([1.], dtype=tf.float16)
    with strategy_fn().scope():
        layer = mp_test_util.MultiplyLayer(dtype='mixed_float16')
        config = layer.get_config()
        # Change the serialized dtype policy to a PolicyV1
        config['dtype'] = policy_v1_config('mixed_float16', None)
        layer = mp_test_util.MultiplyLayer.from_config(config)
        self.assertEqual(layer.dtype, 'float32')
        self.assertEqual(layer(x).dtype, 'float16')
        self.assertEqual(layer.v.dtype, 'float32')
        config = layer.get_config()
        # The loss_scale is silently dropped
        self.assertEqual(config['dtype'],
                         {'class_name': 'Policy',
                          'config': {'name': 'mixed_float16'}})

        layer = mp_test_util.MultiplyLayer(dtype='float64')
        config = layer.get_config()
        config['dtype'] = policy_v1_config(
            'float64', fixed_loss_scale_config())
        layer = mp_test_util.MultiplyLayer.from_config(config)
        self.assertEqual(layer.dtype, 'float64')
        self.assertEqual(layer(x).dtype, 'float64')
        self.assertEqual(layer.v.dtype, 'float64')
        config = layer.get_config()
        self.assertEqual(config['dtype'], 'float64')

        layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer'))
        config = layer.get_config()
        config['dtype'] = policy_v1_config(
            '_infer', fixed_loss_scale_config())
        layer = mp_test_util.MultiplyLayer.from_config(config)
        # assertIsNone checks identity with None and reports a clearer
        # failure than assertEqual(..., None).
        self.assertIsNone(layer.dtype)
        self.assertEqual(layer(x).dtype, 'float16')
        self.assertEqual(layer.v.dtype, 'float16')
        # The policy must be exactly Policy, not a subclass.
        self.assertIs(type(layer.dtype_policy), policy.Policy)
        config = layer.get_config()
        self.assertEqual(config['dtype'], 'float16')
def test_global_policy_dtype_error(self):
    """Checks that `set_policy` rejects int32 and complex64 policies."""
    message_prefix = (
        'set_policy can only be used to set the global policy to '
        'floating-point policies, such as "float32" and "mixed_float16", but '
        'got policy: ')
    # Policy given by name.
    with self.assertRaisesRegex(ValueError, message_prefix + 'int32'):
        mp_policy.set_policy('int32')
    # Policy given as a Policy instance.
    with self.assertRaisesRegex(ValueError, message_prefix + 'complex64'):
        mp_policy.set_policy(mp_policy.Policy('complex64'))
def test_layer_with_int_variable(self):
    """Calls a mixed_float16 layer that adds an int32 weight to its input."""

    class LayerWithIntVar(base_layer.Layer):

        def build(self, _):
            self.v = self.add_weight('v', dtype='int32', trainable=False)

        def call(self, inputs):
            # Only float variables should be autocasted. This will fail if
            # self.v is autocasted to float32
            int_inputs = tf.cast(inputs, 'int32')
            return int_inputs + self.v

    float_input = tf.constant([1.])
    int_var_layer = LayerWithIntVar(dtype=policy.Policy('mixed_float16'))
    output = int_var_layer(float_input)
    self.assertEqual(output.dtype, 'int32')
def test_serialization(self):
    """Round-trips policies through `serialize` and `deserialize`."""

    def assert_round_trip(source_policy, config, **deserialize_kwargs):
        # Deserialize `config` and compare string forms with the source.
        rebuilt = mp_policy.deserialize(config, **deserialize_kwargs)
        self.assertEqual(str(source_policy), str(rebuilt))

    # Test policies that are equivalent to a single dtype
    single_dtype_names = ("float16", "float32", "int8", "string", "bool")
    for policy_name in single_dtype_names:
        source = mp_policy.Policy(policy_name)
        config = mp_policy.serialize(source)
        self.assertEqual(config, policy_name)
        assert_round_trip(source, config)

    # Test "_infer" policy
    source = mp_policy.Policy("_infer")
    config = mp_policy.serialize(source)
    self.assertIsNone(config)
    assert_round_trip(source, config)

    class MyPolicy(mp_policy.Policy):
        pass

    # Test policies that are not equivalent to a single dtype
    dict_config_policies = (
        mp_policy.Policy("mixed_float16"),
        mp_policy.Policy("mixed_bfloat16"),
        MyPolicy("float32"),
    )
    for source in dict_config_policies:
        config = mp_policy.serialize(source)
        expected_config = {
            "class_name": source.__class__.__name__,
            "config": {"name": source.name},
        }
        self.assertEqual(config, expected_config)
        assert_round_trip(
            source, config, custom_objects={"MyPolicy": MyPolicy}
        )
def test_policy_errors(self):
    """Checks the errors raised by `Policy` for invalid names."""

    def assert_rejected(error_type, message_regex, name):
        # Constructing Policy(name) must raise the given error.
        with self.assertRaisesRegex(error_type, message_regex):
            mp_policy.Policy(name)

    # Test passing invalid strings
    assert_rejected(
        ValueError,
        "Cannot convert value abc to a mixed precision Policy.",
        "abc",
    )

    # Test passing a DType
    assert_rejected(
        TypeError,
        "'name' must be a string, not a DType. "
        "Instead, pass DType.name. Got: float16",
        tf.float16,
    )

    # Test passing a non-DType invalid type
    assert_rejected(TypeError, "'name' must be a string, but got: 5", 5)

    # Test passing a now-removed policy ending in float32_vars
    assert_rejected(
        ValueError,
        "Policies ending in '_float32_vars' have been removed from "
        "TensorFlow. Please use the 'mixed_float16' or 'mixed_bfloat16' "
        "policy instead. Got policy name: 'infer_float32_vars'",
        "infer_float32_vars",
    )
    assert_rejected(
        ValueError,
        "Policies ending in '_float32_vars' have been removed from "
        "TensorFlow. Please use the 'mixed_float16' policy instead. "
        "Got policy name: 'float16_with_float32_vars'",
        "float16_with_float32_vars",
    )
    assert_rejected(
        ValueError,
        "Policies ending in '_float32_vars' have been removed from "
        "TensorFlow. Please use the 'mixed_bfloat16' policy instead. "
        "Got policy name: 'bfloat16_with_float32_vars'",
        "bfloat16_with_float32_vars",
    )
    assert_rejected(
        ValueError,
        "Policies ending in '_float32_vars' have been removed from "
        "TensorFlow. Got policy name: 'int8_with_float32_vars'",
        "int8_with_float32_vars",
    )
def test_global_policy_dtype_error(self):
    """Checks that `set_global_policy` rejects int32 and complex64."""
    message_template = (
        "set_global_policy can only be used to set the global policy to "
        'floating-point policies, such as "float32" and "mixed_float16", '
        "but got policy: {}"
    )

    # Policy given by name.
    int32_message = message_template.format("int32")
    with self.assertRaisesRegex(ValueError, int32_message):
        mp_policy.set_global_policy("int32")

    # Policy given as a Policy instance.
    complex64_message = message_template.format("complex64")
    with self.assertRaisesRegex(ValueError, complex64_message):
        mp_policy.set_global_policy(mp_policy.Policy("complex64"))
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
    """Initializes the legacy layer.

    Args:
        trainable: Forwarded to the parent constructor.
        name: Forwarded to the parent constructor.
        dtype: Forwarded to the parent constructor. When None, an "_infer"
            policy is passed instead.
        **kwargs: Forwarded to the parent constructor after the private
            `_scope` and `_reuse` entries are removed. `autocast` defaults
            to False when absent.

    Raises:
        ValueError: If keras style layers are enabled and a scope or a
            reuse value was supplied.
    """
    # For backwards compatibility, legacy layers do not use `ResourceVariable`
    # by default.
    self._use_resource_variables = False
    scope = kwargs.pop('_scope', None)
    self._reuse = kwargs.pop('_reuse', None)

    # Avoid an incorrect lint error
    self._trainable_weights = []
    self.built = False

    if dtype is None:
        # Indicates to infer dtype from inputs. When the V2 dtype behavior is
        # enabled, Keras layers default their dtype to floatx instead, so we
        # pass an "_infer" policy to keep the old V1 behavior.
        dtype = policy.Policy('_infer')

    kwargs.setdefault('autocast', False)

    # Mark that legacy layers should not be instrumented as Keras usage
    self._disable_keras_instrumentation = True

    super(Layer, self).__init__(
        trainable=trainable, name=name, dtype=dtype, **kwargs)

    if not _is_in_keras_style_scope():
        self._keras_style = False
    else:
        # Neither a scope nor a reuse value may be combined with keras style.
        if scope is not None:
            scope_error = (
                'scope argument not allowed when keras style layers are '
                'enabled, but saw: {}'.format(scope))
            raise ValueError(scope_error)
        if self._reuse is not None:
            reuse_error = (
                'reuse argument not allowed when keras style layers are '
                'enabled, but saw: {}'.format(self._reuse))
            raise ValueError(reuse_error)
        self._keras_style = True

    self._call_has_scope_arg = 'scope' in self._call_fn_args
    if not scope:
        self._scope = None
    else:
        # Enter the variable scope once to capture the scope object.
        with tf.compat.v1.variable_scope(scope) as captured_scope:
            self._scope = captured_scope
    self._current_scope = None
def test_repr(self):
    """Checks the `repr` of Policy and PolicyV1 instances."""
    # Test Policy repr
    policy_names = (
        'float32', 'int8', 'mixed_float16', 'mixed_bfloat16', '_infer')
    for name in policy_names:
        actual = repr(mp_policy.Policy(name))
        self.assertEqual(actual, '<Policy "{}">'.format(name))

    # Test PolicyV1 repr
    for name in ('float32', 'int8', 'mixed_bfloat16', '_infer'):
        actual = repr(mp_policy.PolicyV1(name))
        self.assertEqual(
            actual, '<PolicyV1 "{}", loss_scale=None>'.format(name))

    fixed_scale_repr = repr(mp_policy.PolicyV1('float16', loss_scale=2))
    self.assertEqual(
        fixed_scale_repr,
        '<PolicyV1 "float16", loss_scale=FixedLossScale(2.0)>')

    # Only the prefix is checked for the mixed_float16 PolicyV1.
    dynamic_scale_repr = repr(mp_policy.PolicyV1('mixed_float16'))
    self.assertStartsWith(
        dynamic_scale_repr,
        '<PolicyV1 "mixed_float16", loss_scale=DynamicLossScale(')
def test_global_policy(self):
    """Checks the default global policy and `set_global_policy`."""
    default_policy = (
        'float32' if base_layer_utils.v2_dtype_behavior_enabled()
        else '_infer')
    self.assertEqual(mp_policy.global_policy().name, default_policy)
    try:
        mp_policy.set_global_policy('mixed_float16')
        self.assertEqual(mp_policy.global_policy().name, 'mixed_float16')
        # Policies are not associated with a graph
        with tf.Graph().as_default():
            name_in_new_graph = mp_policy.global_policy().name
            self.assertEqual(name_in_new_graph, 'mixed_float16')

        mp_policy.set_global_policy('_infer')
        self.assertEqual(mp_policy.global_policy().name, '_infer')

        # A Policy instance is stored as the same object.
        bfloat16_policy = mp_policy.Policy('mixed_bfloat16')
        mp_policy.set_global_policy(bfloat16_policy)
        self.assertIs(mp_policy.global_policy(), bfloat16_policy)
    finally:
        # Always reset the global policy, even when an assertion fails.
        mp_policy.set_global_policy(None)
def test_policy_errors(self):
    """Checks the errors raised by `Policy` for invalid names."""
    # Test passing invalid strings
    with self.assertRaisesRegex(
            ValueError,
            'Cannot convert value abc to a mixed precision Policy.'):
        mp_policy.Policy('abc')

    # Test passing a DType
    dtype_message = ("'name' must be a string, not a DType. "
                     'Instead, pass DType.name. Got: float16')
    with self.assertRaisesRegex(TypeError, dtype_message):
        mp_policy.Policy(tf.float16)

    # Test passing a non-DType invalid type
    with self.assertRaisesRegex(
            TypeError, "'name' must be a string, but got: 5"):
        mp_policy.Policy(5)

    # Test passing a now-removed policy ending in float32_vars
    removed_prefix = ("Policies ending in '_float32_vars' have been removed "
                      'from TensorFlow. ')
    removed_policy_cases = (
        ('infer_float32_vars',
         "Please use the 'mixed_float16' or 'mixed_bfloat16' policy "
         "instead. Got policy name: 'infer_float32_vars'"),
        ('float16_with_float32_vars',
         "Please use the 'mixed_float16' policy "
         "instead. Got policy name: 'float16_with_float32_vars'"),
        ('bfloat16_with_float32_vars',
         "Please use the 'mixed_bfloat16' policy "
         "instead. Got policy name: 'bfloat16_with_float32_vars'"),
        ('int8_with_float32_vars',
         "Got policy name: 'int8_with_float32_vars'"),
    )
    for removed_name, message_suffix in removed_policy_cases:
        with self.assertRaisesRegex(
                ValueError, removed_prefix + message_suffix):
            mp_policy.Policy(removed_name)
def test_config(self):
    """Round-trips policies through `get_config` and `from_config`."""
    policy_names = (
        'float16',
        'float32',
        'int16',
        'mixed_float16',
        'mixed_bfloat16',
        '_infer',
    )
    # Build every policy up front, before any round trip is attempted.
    policies = [mp_policy.Policy(name) for name in policy_names]
    for original in policies:
        restored = mp_policy.Policy.from_config(original.get_config())
        # Comparing strings is the easiest way to ensure the policies are the
        # same, as policy does not override the == operator.
        self.assertEqual(str(original), str(restored))
def test_global_policy(self):
    """Checks the default global policy and `set_global_policy`."""

    def current_policy_name():
        return mp_policy.global_policy().name

    if base_layer_utils.v2_dtype_behavior_enabled():
        expected_default = "float32"
    else:
        expected_default = "_infer"
    self.assertEqual(current_policy_name(), expected_default)

    try:
        mp_policy.set_global_policy("mixed_float16")
        self.assertEqual(current_policy_name(), "mixed_float16")
        # Policies are not associated with a graph
        with tf.Graph().as_default():
            self.assertEqual(current_policy_name(), "mixed_float16")

        mp_policy.set_global_policy("_infer")
        self.assertEqual(current_policy_name(), "_infer")

        # A Policy instance is stored as the same object.
        bfloat16_policy = mp_policy.Policy("mixed_bfloat16")
        mp_policy.set_global_policy(bfloat16_policy)
        self.assertIs(mp_policy.global_policy(), bfloat16_policy)
    finally:
        # Always reset the global policy, even when an assertion fails.
        mp_policy.set_global_policy(None)
def test_layer(self,
               f32_layer_fn,
               input_shape,
               rtol=2e-3,
               atol=2e-3,
               input_data=None):
    """Compares a float32 layer against mixed precision copies of it.

    Three layers are exercised: the float32 layer, a mixed_float16 layer
    rebuilt from its config, and a second mixed_float16 layer whose model
    is created under a distribution strategy scope. Outputs of predict()
    and weights after fit() are asserted to be close to the float32 ones.

    Args:
        f32_layer_fn: A function returning a float32 layer. The other two
            layers are created from this layer's config.
        input_shape: The shape of the input to the layer, including the
            batch dimension, or a list of shapes if the layer takes
            multiple inputs.
        rtol: The relative tolerance to be asserted.
        atol: The absolute tolerance to be asserted.
        input_data: A Numpy array with the data of the input. If None,
            input data will be randomly generated.
    """
    # Returns without checking anything for ZeroPadding2D on ROCm builds.
    if (f32_layer_fn == reshaping.ZeroPadding2D
            and tf.test.is_built_with_rocm()):
        return

    # Normalize to a list of shapes, one per input.
    input_shapes = (
        [input_shape] if isinstance(input_shape[0], int) else input_shape)
    strategy = create_mirrored_strategy()
    f32_layer = f32_layer_fn()

    # Create the layers
    assert f32_layer.dtype == f32_layer._compute_dtype == 'float32'
    config = f32_layer.get_config()
    config['dtype'] = policy.Policy('mixed_float16')
    layer_cls = f32_layer.__class__
    mp_layer = layer_cls.from_config(config)
    distributed_mp_layer = layer_cls.from_config(config)

    # Compute per_replica_input_shapes for the distributed model
    global_batch_size = input_shapes[0][0]
    num_replicas = strategy.num_replicas_in_sync
    assert global_batch_size % num_replicas == 0, (
        'The number of replicas, %d, does not divide the global batch size of '
        '%d' % (num_replicas, global_batch_size))
    per_replica_batch_size = global_batch_size // num_replicas
    per_replica_input_shapes = [
        (per_replica_batch_size,) + shape[1:] for shape in input_shapes
    ]

    # Create the models
    f32_model = self._create_model_from_layer(f32_layer, input_shapes)
    mp_model = self._create_model_from_layer(mp_layer, input_shapes)
    with strategy.scope():
        distributed_mp_model = self._create_model_from_layer(
            distributed_mp_layer, per_replica_input_shapes)

    # Set all model weights to the same values
    f32_weights = f32_model.get_weights()
    mp_model.set_weights(f32_weights)
    distributed_mp_model.set_weights(f32_weights)

    # Generate input data
    if input_data is None:
        # Cast inputs to float16 to avoid measuring error from having f16
        # layers cast to float16.
        input_data = [
            np.random.normal(size=shape).astype('float16')
            for shape in input_shapes
        ]
        if len(input_data) == 1:
            input_data = input_data[0]

    tolerances = dict(rtol=rtol, atol=atol)

    # Assert all models have close outputs.
    f32_output = f32_model.predict(input_data)
    mp_output = mp_model.predict(input_data)
    self.assertAllClose(mp_output, f32_output, **tolerances)
    self.assertAllClose(
        distributed_mp_model.predict(input_data), f32_output, **tolerances)

    # Run fit() on models
    target_shape = f32_model.outputs[0].shape
    output = np.random.normal(size=target_shape).astype('float16')
    for model in (f32_model, mp_model, distributed_mp_model):
        model.fit(input_data, output, batch_size=global_batch_size)

    # Assert all models have close weights
    f32_weights = f32_model.get_weights()
    self.assertAllClose(mp_model.get_weights(), f32_weights, **tolerances)
    self.assertAllClose(
        distributed_mp_model.get_weights(), f32_weights, **tolerances)