def test_mixed_precision_embedding(self):
  """Checks that Embedding follows a 'mixed_float16' global policy.

  The layer created while the global policy is 'mixed_float16' should carry
  that policy and produce float16 outputs.
  """
  try:
    policy.set_policy('mixed_float16')
    layer = keras.layers.Embedding(input_dim=5, output_dim=2)
    self.assertEqual(layer._dtype_policy.name, 'mixed_float16')
    outputs = layer(np.array([0, 1, 2]))
    self.assertEqual(outputs.dtype, 'float16')
  finally:
    # Reset to the default global policy rather than forcing 'float32': with
    # V1 dtype behavior the default is '_infer', so hard-coding 'float32'
    # would leak a modified global policy into subsequent tests.
    policy.set_policy(None)
def test_global_policy_dtype_error(self):
  """Checks set_policy rejects non-floating-point global policies.

  Both a plain dtype string ('int32') and a Policy object ('complex64') must
  raise a ValueError naming the offending policy.
  """
  message_prefix = (
      'set_policy can only be used to set the global policy to '
      'floating-point policies, such as "float32" and "mixed_float16", but '
      'got policy: ')
  # (dtype name, whether to pass it wrapped in a Policy object)
  for dtype_name, wrap_in_policy in (('int32', False), ('complex64', True)):
    with self.assertRaisesRegex(ValueError, message_prefix + dtype_name):
      # Build the argument inside the context so construction happens at the
      # same point as before.
      mp_policy.set_policy(
          mp_policy.Policy(dtype_name) if wrap_in_policy else dtype_name)
def test_error_if_graph_rewrite_enabled(self):
  """Checks a mixed global policy is refused once the graph rewrite is on.

  While the mixed precision graph rewrite is enabled, set_policy must reject
  'mixed_float16', but entering a non-mixed policy_scope is still allowed.
  The rewrite is disabled again no matter how the test exits.
  """
  expected_error = ('cannot be set to "mixed_float16", .* the mixed '
                    'precision graph rewrite has already been enabled')
  try:
    sgd = gradient_descent.SGD(1.)
    mixed_precision.enable_mixed_precision_graph_rewrite(sgd)
    with self.assertRaisesRegex(ValueError, expected_error):
      mp_policy.set_policy('mixed_float16')
    # A non-mixed policy does not conflict with the rewrite.
    with mp_policy.policy_scope('float64'):
      pass
  finally:
    mixed_precision.disable_mixed_precision_graph_rewrite()
def test_global_policy(self):
  """Exercises reading and setting the global policy.

  Verifies the default policy name ('float32' with V2 dtype behavior,
  '_infer' otherwise), that set_policy accepts names and Policy objects, and
  that the global policy is unaffected by entering a new graph. The global
  policy is reset to its default on exit.
  """
  def current_policy_name():
    return mp_policy.global_policy().name

  expected_default = ('float32' if base_layer_utils.v2_dtype_behavior_enabled()
                      else '_infer')
  self.assertEqual(current_policy_name(), expected_default)
  try:
    mp_policy.set_policy('mixed_float16')
    self.assertEqual(current_policy_name(), 'mixed_float16')
    # Policies are not associated with a graph.
    with ops.Graph().as_default():
      self.assertEqual(current_policy_name(), 'mixed_float16')
    mp_policy.set_policy('_infer')
    self.assertEqual(current_policy_name(), '_infer')
    bfloat16_policy = mp_policy.Policy('mixed_bfloat16')
    mp_policy.set_policy(bfloat16_policy)
    # The exact object passed in is what global_policy() hands back.
    self.assertIs(mp_policy.global_policy(), bfloat16_policy)
  finally:
    mp_policy.set_policy(None)