コード例 #1
0
 def test_mixed_precision_embedding(self):
     """Embedding layer picks up the global 'mixed_float16' policy.

     Sets the global policy to 'mixed_float16', builds an Embedding layer,
     and checks that the layer's dtype policy is 'mixed_float16' and that its
     outputs are float16. The global policy is reset afterwards.
     """
     try:
         policy.set_policy('mixed_float16')
         layer = keras.layers.Embedding(input_dim=5, output_dim=2)
         self.assertEqual(layer._dtype_policy.name, 'mixed_float16')
         outputs = layer(np.array([0, 1, 2]))
         self.assertEqual(outputs.dtype, 'float16')
     finally:
         # Clear the global policy (None) instead of forcing 'float32', so
         # the default policy is back in effect for subsequent tests.
         # NOTE(review): assumes no global policy was set before this test.
         policy.set_policy(None)
コード例 #2
0
 def test_error_if_graph_rewrite_enabled(self):
   """Setting 'infer_float32_vars' must fail once the graph rewrite is on.

   Enables the mixed precision graph rewrite, then checks that
   `mp_policy.set_policy('infer_float32_vars')` raises a ValueError whose
   message says the rewrite has already been enabled. The rewrite is always
   disabled afterwards so it does not leak into other tests.
   """
   try:
     mixed_precision.enable_mixed_precision_graph_rewrite(
         gradient_descent.SGD(1.))
     with self.assertRaisesRegex(
         ValueError, 'the mixed precision graph rewrite has already been '
                     'enabled'):
       mp_policy.set_policy('infer_float32_vars')
   finally:
     mixed_precision.disable_mixed_precision_graph_rewrite()
コード例 #3
0
 def test_error_if_graph_rewrite_enabled(self):
   """Mixed policies are rejected while the graph rewrite is enabled.

   With the mixed precision graph rewrite turned on, setting the global
   policy to "mixed_float16" must raise a ValueError, while entering a scope
   with the non-mixed "float64" policy must still work. The rewrite is
   disabled again however the test exits.
   """
   rewrite_conflict_pattern = (
       'cannot be set to "mixed_float16", .* the mixed '
       'precision graph rewrite has already been enabled')
   try:
     optimizer = gradient_descent.SGD(1.)
     mixed_precision.enable_mixed_precision_graph_rewrite(optimizer)

     with self.assertRaisesRegex(ValueError, rewrite_conflict_pattern):
       mp_policy.set_policy('mixed_float16')

     # Non-mixed policies are allowed, so this scope must open cleanly.
     with mp_policy.policy_scope('float64'):
       pass
   finally:
     mixed_precision.disable_mixed_precision_graph_rewrite()
コード例 #4
0
 def test_global_policy_dtype_error(self):
   """set_policy must reject non-floating-point policies.

   Two cases are checked: a policy passed by name ('int32') and a Policy
   object ('complex64'). Each must raise a ValueError whose message names
   the rejected policy.
   """
   message_prefix = (
       'set_policy can only be used to set the global policy to '
       'floating-point policies, such as "float32" and "mixed_float16", but '
       'got policy: ')

   def expect_rejected(policy_name, as_policy_object):
     # The argument is built inside the assertion context, right where
     # set_policy is called.
     with self.assertRaisesRegex(ValueError, message_prefix + policy_name):
       if as_policy_object:
         mp_policy.set_policy(mp_policy.Policy(policy_name))
       else:
         mp_policy.set_policy(policy_name)

   expect_rejected('int32', as_policy_object=False)
   expect_rejected('complex64', as_policy_object=True)
コード例 #5
0
 def test_global_policy(self):
   """Exercise reading and writing the global mixed precision policy.

   Checks the default policy name, which depends on whether V2 dtype
   behavior is enabled. Then checks that policies set by name, or as a
   Policy object, are what `global_policy()` reports, including from inside
   a fresh default graph. The global policy is reset to None at the end.
   """
   expected_default = ('float32'
                       if base_layer_utils.v2_dtype_behavior_enabled()
                       else '_infer')

   def current_policy_name():
     return mp_policy.global_policy().name

   self.assertEqual(current_policy_name(), expected_default)
   try:
     mp_policy.set_policy('mixed_float16')
     self.assertEqual(current_policy_name(), 'mixed_float16')
     # Policies are not associated with a graph, so a new default graph
     # still sees the same global policy.
     with ops.Graph().as_default():
       self.assertEqual(current_policy_name(), 'mixed_float16')

     mp_policy.set_policy('_infer')
     self.assertEqual(current_policy_name(), '_infer')

     # When a Policy instance is set, global_policy() returns that object.
     bfloat16_policy = mp_policy.Policy('mixed_bfloat16')
     mp_policy.set_policy(bfloat16_policy)
     self.assertIs(mp_policy.global_policy(), bfloat16_policy)
   finally:
     mp_policy.set_policy(None)
コード例 #6
0
 def test_global_policy(self):
   """Exercise reading and writing the global policy ('infer' API).

   The default policy must be named 'infer'. The test then switches between
   'infer_float32_vars' and 'infer', checking each policy's name and
   default_variable_dtype, and checks that a Policy object passed to
   set_policy is returned as-is. The original policy object is restored at
   the end.
   """
   def current_policy():
     return mp_policy.global_policy()

   self.assertEqual(current_policy().name, 'infer')
   saved_policy = current_policy()
   try:
     mp_policy.set_policy('infer_float32_vars')
     self.assertEqual(current_policy().name, 'infer_float32_vars')
     self.assertEqual(current_policy().default_variable_dtype, 'float32')
     # Policies are not associated with a graph.
     with ops.Graph().as_default():
       self.assertEqual(current_policy().name, 'infer_float32_vars')

     mp_policy.set_policy('infer')
     self.assertEqual(current_policy().name, 'infer')
     self.assertEqual(current_policy().default_variable_dtype, None)

     explicit_policy = mp_policy.Policy('infer_float32_vars')
     mp_policy.set_policy(explicit_policy)
     self.assertIs(current_policy(), explicit_policy)
   finally:
     mp_policy.set_policy(saved_policy)