示例#1
0
 def test_mixed_precision_embedding(self):
     """An Embedding layer under a 'mixed_float16' global policy emits float16.

     Sets the global policy to 'mixed_float16', builds an Embedding layer,
     checks that the layer picked up that policy, and checks that calling it
     on integer indices yields float16 outputs. The global policy is reset in
     `finally` so the change does not leak into other tests.
     """
     try:
         policy.set_policy('mixed_float16')
         layer = keras.layers.Embedding(input_dim=5, output_dim=2)
         self.assertEqual(layer._dtype_policy.name, 'mixed_float16')
         outputs = layer(np.array([0, 1, 2]))
         self.assertEqual(outputs.dtype, 'float16')
     finally:
         # Reset with None instead of hard-coding 'float32'. The default global
         # policy is '_infer' when V2 dtype behavior is disabled, so forcing
         # 'float32' would leave a non-default policy behind. This matches how
         # test_global_policy restores the global policy.
         policy.set_policy(None)
示例#2
0
 def test_global_policy_dtype_error(self):
     """set_policy rejects non-floating-point policies as the global policy.

     Both a policy given by name ('int32') and a Policy object ('complex64')
     must raise ValueError with the expected message.
     """
     expected_prefix = (
         'set_policy can only be used to set the global policy to '
         'floating-point policies, such as "float32" and "mixed_float16", but '
         'got policy: ')
     # Policy given as a string name.
     with self.assertRaisesRegex(ValueError, expected_prefix + 'int32'):
         mp_policy.set_policy('int32')
     # Policy given as an object; it is constructed inside the `with` so that
     # construction runs under the same assertion context as set_policy.
     with self.assertRaisesRegex(ValueError, expected_prefix + 'complex64'):
         mp_policy.set_policy(mp_policy.Policy('complex64'))
示例#3
0
 def test_error_if_graph_rewrite_enabled(self):
     """set_policy('mixed_float16') fails once the graph rewrite is enabled.

     Entering a non-mixed policy scope ('float64') must still succeed. The
     graph rewrite is always disabled again in `finally`.
     """
     rewrite_conflict_msg = ('cannot be set to "mixed_float16", .* the mixed '
                             'precision graph rewrite has already been enabled')
     try:
         sgd = gradient_descent.SGD(1.)
         mixed_precision.enable_mixed_precision_graph_rewrite(sgd)
         with self.assertRaisesRegex(ValueError, rewrite_conflict_msg):
             mp_policy.set_policy('mixed_float16')
         # A non-mixed policy must not trigger the conflict error.
         with mp_policy.policy_scope('float64'):
             pass
     finally:
         mixed_precision.disable_mixed_precision_graph_rewrite()
示例#4
0
 def test_global_policy(self):
     """global_policy() tracks set_policy() and is not tied to a graph.

     First checks the default name: 'float32' with V2 dtype behavior,
     otherwise '_infer'. Then sets the policy by name and checks it is still
     visible inside a fresh default Graph. Next sets '_infer', and finally
     sets a Policy object, which global_policy() must return by identity. The
     global policy is reset with set_policy(None) in `finally`.
     """
     def current_name():
         return mp_policy.global_policy().name

     expected_default = ('float32'
                         if base_layer_utils.v2_dtype_behavior_enabled()
                         else '_infer')
     self.assertEqual(current_name(), expected_default)
     try:
         mp_policy.set_policy('mixed_float16')
         self.assertEqual(current_name(), 'mixed_float16')
         # Policies are not associated with a graph, so the global policy is
         # still visible under a different default graph.
         with ops.Graph().as_default():
             self.assertEqual(current_name(), 'mixed_float16')

         mp_policy.set_policy('_infer')
         self.assertEqual(current_name(), '_infer')

         bfloat16_policy = mp_policy.Policy('mixed_bfloat16')
         mp_policy.set_policy(bfloat16_policy)
         self.assertIs(mp_policy.global_policy(), bfloat16_policy)
     finally:
         mp_policy.set_policy(None)