def tearDown(self):
    """Undo the process-wide state touched by the mixed precision tests.

    Restores the environment variable named by ``self.IGNORE_PERF_VAR`` to the
    value recorded in ``self._original_ignore_perf_value`` (presumably captured
    in setUp — confirm), or removes the variable when no value was recorded.
    Then disables the mixed precision graph rewrite and runs the parent
    class's tearDown.
    """
    # Set the IGNORE_PERF_VAR variable back to its original value.
    if self._original_ignore_perf_value is not None:
        os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
    else:
        # pop() instead of `del`: if the variable is already gone, `del` would
        # raise KeyError and skip the cleanup below, leaving the graph rewrite
        # enabled for every later test.
        os.environ.pop(self.IGNORE_PERF_VAR, None)
    mixed_precision.disable_mixed_precision_graph_rewrite()
    super(MixedPrecisionTest, self).tearDown()
def tearDown(self):
    """Undo the process-wide state touched by the mixed precision tests.

    Restores the environment variable named by ``self.IGNORE_PERF_VAR`` to the
    value recorded in ``self._original_ignore_perf_value`` (presumably captured
    in setUp — confirm), or removes the variable when no value was recorded.
    Then disables the mixed precision graph rewrite and runs the parent
    class's tearDown.
    """
    # Set the IGNORE_PERF_VAR variable back to its original value.
    if self._original_ignore_perf_value is not None:
        os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
    else:
        # pop() instead of `del`: if the variable is already gone, `del` would
        # raise KeyError and skip the cleanup below, leaving the graph rewrite
        # enabled for every later test.
        os.environ.pop(self.IGNORE_PERF_VAR, None)
    mixed_precision.disable_mixed_precision_graph_rewrite()
    super(MixedPrecisionTest, self).tearDown()
def test_error_if_graph_rewrite_enabled(self):
    """Setting a policy fails once the mixed precision graph rewrite is on.

    Enables the graph rewrite with an SGD optimizer, then checks that
    ``mp_policy.set_policy('infer_float32_vars')`` raises a ValueError whose
    message says the rewrite has already been enabled. The rewrite is always
    disabled afterwards, even if the assertion fails.
    """
    try:
        mixed_precision.enable_mixed_precision_graph_rewrite(
            gradient_descent.SGD(1.))
        # assertRaisesRegexp is a deprecated alias that was removed in
        # Python 3.12; assertRaisesRegex is the supported name.
        with self.assertRaisesRegex(
                ValueError, 'the mixed precision graph rewrite has already been '
                'enabled'):
            mp_policy.set_policy('infer_float32_vars')
    finally:
        # The rewrite is global state; never let it leak into other tests.
        mixed_precision.disable_mixed_precision_graph_rewrite()
def test_error_if_graph_rewrite_enabled(self):
    """A mixed policy conflicts with an enabled graph rewrite; float64 does not.

    With the mixed precision graph rewrite turned on, setting the
    "mixed_float16" policy must raise a ValueError that mentions the rewrite.
    Entering a "float64" policy scope must succeed. The rewrite is turned off
    again no matter how the checks end.
    """
    expected_message = ('cannot be set to "mixed_float16", .* the mixed '
                        'precision graph rewrite has already been enabled')
    try:
        mixed_precision.enable_mixed_precision_graph_rewrite(
            gradient_descent.SGD(1.))
        with self.assertRaisesRegex(ValueError, expected_message):
            mp_policy.set_policy('mixed_float16')
        # Non-mixed policies do not clash with the rewrite, so this scope has
        # to open and close without raising.
        with mp_policy.policy_scope('float64'):
            pass
    finally:
        mixed_precision.disable_mixed_precision_graph_rewrite()
def test_grappler_pass_enabled(self):
    """Checks that enabling the graph rewrite turns on the float16 grappler pass.

    The computation below yields 2**20, which float32 can hold but which is
    above the float16 maximum (65504). So an Inf result shows the ops were
    rewritten to float16, and a finite 2**20 shows they were not. Covers
    eager mode (inside and outside a def_function.function) and graph mode
    (Session with and without a user ConfigProto), before and after disabling
    the rewrite.
    """
    opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
    enable_mixed_precision_graph_rewrite(opt, 123.)
    var = variables.Variable([[1.0]])

    def overflow_in_float16():
        # var is [[1.0]]; scaling by 2**10 and multiplying the 1x1 matrix by
        # itself gives 2**20, reshaped to a scalar.
        scaled = var * 2**10
        squared = math_ops.matmul(scaled, scaled)
        return array_ops.reshape(squared, ())

    if not context.executing_eagerly():
        def run_in_new_session(**session_kwargs):
            # Build the ops, initialize `var`, and evaluate in a fresh Session.
            with session.Session(**session_kwargs) as sess:
                result = overflow_in_float16()
                sess.run(var.initializer)
                return sess.run(result)

        self.assertEqual(run_in_new_session(), float('Inf'))
        # Session must also enable the auto_mixed_precision grappler pass in a
        # ConfigProto supplied by the user.
        self.assertEqual(
            run_in_new_session(config=config_pb2.ConfigProto()), float('Inf'))
        # Once the rewrite is disabled the result stays in float32.
        mixed_precision.disable_mixed_precision_graph_rewrite()
        self.assertAlmostEqual(run_in_new_session(), 2**20)
        return

    traced_fn = def_function.function(overflow_in_float16)
    self.assertEqual(traced_fn().numpy(), float('Inf'))
    # Outside a def_function.function, the grappler pass will not be applied.
    self.assertAlmostEqual(overflow_in_float16().numpy(), 2**20)
    # After disabling, the traced function computes in float32 again.
    mixed_precision.disable_mixed_precision_graph_rewrite()
    self.assertEqual(traced_fn().numpy(), 2**20)
def test_grappler_pass_enabled(self):
    """Checks that enabling the graph rewrite turns on the float16 grappler pass.

    The helper computation evaluates to 2 ** 20. That value fits in float32
    but exceeds the largest finite float16 (65504), so Inf means the float16
    rewrite ran and 2 ** 20 means it did not. Eager mode and graph (Session)
    mode are both exercised, before and after disabling the rewrite.
    """
    opt = gradient_descent_v2.SGD(1.0)
    enable_mixed_precision_graph_rewrite(opt, 123.)
    var = variables.Variable([[1.0]])

    def overflow_in_float16():
        # [[1.0]] * 2 ** 10, matmul'd with itself, then flattened to a scalar.
        product = math_ops.matmul(var * 2 ** 10, var * 2 ** 10) if False else None
        del product
        scaled = var * 2 ** 10
        return array_ops.reshape(math_ops.matmul(scaled, scaled), ())

    if context.executing_eagerly():
        compiled = def_function.function(overflow_in_float16)
        self.assertEqual(compiled().numpy(), float('Inf'))
        # Outside a def_function.function, the grappler pass will not be applied.
        self.assertAlmostEqual(overflow_in_float16().numpy(), 2 ** 20)
        # Disabling mixed precision makes the function compute in float32.
        mixed_precision.disable_mixed_precision_graph_rewrite()
        self.assertEqual(compiled().numpy(), 2 ** 20)
    else:
        def evaluate_in_session(**session_kwargs):
            # Fresh Session: create the ops, initialize `var`, return the value.
            with session.Session(**session_kwargs) as sess:
                value = overflow_in_float16()
                sess.run(var.initializer)
                return sess.run(value)

        # The second entry checks that Session also turns on the
        # auto_mixed_precision grappler pass in a ConfigProto passed by the user.
        for session_kwargs in ({}, {'config': config_pb2.ConfigProto()}):
            self.assertEqual(evaluate_in_session(**session_kwargs), float('Inf'))
        # With the rewrite disabled the result is the finite float32 value.
        mixed_precision.disable_mixed_precision_graph_rewrite()
        self.assertAlmostEqual(evaluate_in_session(), 2 ** 20)