Esempio n. 1
0
    def tearDown(self):
        """Restore the environment variable and graph-rewrite state after a test.

        Puts ``IGNORE_PERF_VAR`` back to the value held in
        ``self._original_ignore_perf_value`` (presumably captured in ``setUp``
        -- confirm against the rest of the class), then disables the mixed
        precision graph rewrite and delegates to the parent ``tearDown``.
        """
        # Set the IGNORE_PERF_VAR variable back to its original value.
        if self._original_ignore_perf_value is not None:
            os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
        else:
            # A saved value of None means the variable should end up unset.
            # pop() with a default is used instead of `del` so that a variable
            # that is already absent does not raise KeyError and skip the
            # remaining cleanup below.
            os.environ.pop(self.IGNORE_PERF_VAR, None)

        mixed_precision.disable_mixed_precision_graph_rewrite()
        super(MixedPrecisionTest, self).tearDown()
  def tearDown(self):
    """Restore the environment variable and graph-rewrite state after a test.

    Puts ``IGNORE_PERF_VAR`` back to the value held in
    ``self._original_ignore_perf_value`` (presumably captured in ``setUp`` --
    confirm against the rest of the class), then disables the mixed precision
    graph rewrite and delegates to the parent ``tearDown``.
    """
    # Set the IGNORE_PERF_VAR variable back to its original value.
    if self._original_ignore_perf_value is not None:
      os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
    else:
      # A saved value of None means the variable should end up unset. pop()
      # with a default is used instead of `del` so that a variable that is
      # already absent does not raise KeyError and skip the cleanup below.
      os.environ.pop(self.IGNORE_PERF_VAR, None)

    mixed_precision.disable_mixed_precision_graph_rewrite()
    super(MixedPrecisionTest, self).tearDown()
Esempio n. 3
0
 def test_error_if_graph_rewrite_enabled(self):
   """Setting a policy while the graph rewrite is enabled must raise.

   Enables the mixed precision graph rewrite, then expects
   ``mp_policy.set_policy('infer_float32_vars')`` to raise ValueError with a
   message saying the rewrite has already been enabled.
   """
   try:
     mixed_precision.enable_mixed_precision_graph_rewrite(
         gradient_descent.SGD(1.))
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(
         ValueError, 'the mixed precision graph rewrite has already been '
                     'enabled'):
       mp_policy.set_policy('infer_float32_vars')
   finally:
     # Always turn the rewrite back off, even if the assertion above fails.
     mixed_precision.disable_mixed_precision_graph_rewrite()
Esempio n. 4
0
 def test_error_if_graph_rewrite_enabled(self):
   """Check policy setting while the mixed precision graph rewrite is on.

   With the rewrite enabled, setting the 'mixed_float16' policy is expected
   to raise ValueError, while entering a 'float64' policy scope is expected
   to succeed.
   """
   expected_error = (
       'cannot be set to "mixed_float16", .* the mixed '
       'precision graph rewrite has already been enabled')
   try:
     optimizer = gradient_descent.SGD(1.)
     mixed_precision.enable_mixed_precision_graph_rewrite(optimizer)
     with self.assertRaisesRegex(ValueError, expected_error):
       mp_policy.set_policy('mixed_float16')
     # Non-mixed policies are allowed
     with mp_policy.policy_scope('float64'):
       pass
   finally:
     # Turn the rewrite back off regardless of how the checks above ended.
     mixed_precision.disable_mixed_precision_graph_rewrite()
Esempio n. 5
0
    def test_grappler_pass_enabled(self):
        """Verify the graph rewrite takes effect and can be switched off again.

        A 1x1 variable holding 1.0 is scaled by 2**10 and multiplied with
        itself. The test expects Inf while the rewrite is enabled (2**20 is
        beyond the float16 range) and 2**20 after the rewrite is disabled.
        """
        enable_mixed_precision_graph_rewrite(
            gradient_descent_v1.GradientDescentOptimizer(1.0), 123.)

        var = variables.Variable([[1.0]])
        infinity = float('Inf')
        full_precision_result = 2**20

        def overflow_in_float16():
            scaled = var * 2**10
            product = math_ops.matmul(scaled, scaled)
            return array_ops.reshape(product, ())

        def evaluate_in(sess):
            # Build the computation, initialize `var`, and fetch the result
            # using the given session.
            result = overflow_in_float16()
            sess.run(var.initializer)
            return sess.run(result)

        if not context.executing_eagerly():
            with session.Session() as sess:
                self.assertEqual(evaluate_in(sess), infinity)

            # Test Session will enable the auto_mixed_precision grappler pass
            # in a ConfigProto passed by the user
            with session.Session(config=config_pb2.ConfigProto()) as sess:
                self.assertEqual(evaluate_in(sess), infinity)

            # Test disabling mixed precision.
            mixed_precision.disable_mixed_precision_graph_rewrite()
            with session.Session() as sess:
                self.assertAlmostEqual(evaluate_in(sess), full_precision_result)
            return

        compiled = def_function.function(overflow_in_float16)
        self.assertEqual(compiled().numpy(), infinity)
        # Outside a def_function.function, the grappler pass will not be
        # applied.
        self.assertAlmostEqual(
            overflow_in_float16().numpy(), full_precision_result)

        # Test disabling mixed precision.
        mixed_precision.disable_mixed_precision_graph_rewrite()
        self.assertEqual(compiled().numpy(), full_precision_result)
  def test_grappler_pass_enabled(self):
    """Verify the graph rewrite takes effect and can be switched off again.

    A 1x1 variable holding 1.0 is scaled by 2 ** 10 and multiplied with
    itself. The test expects Inf while the rewrite is enabled (2 ** 20 is
    beyond the float16 range) and 2 ** 20 after the rewrite is disabled.
    """
    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0), 123.)

    var = variables.Variable([[1.0]])
    overflowed = float('Inf')
    exact = 2 ** 20

    def overflow_in_float16():
      scaled = var * 2 ** 10
      squared = math_ops.matmul(scaled, scaled)
      return array_ops.reshape(squared, ())

    def check_in_session(sess, assertion, expected):
      # Build the computation, initialize `var`, run it in `sess`, and apply
      # `assertion` to the fetched value while the session is still open.
      fetched = overflow_in_float16()
      sess.run(var.initializer)
      assertion(sess.run(fetched), expected)

    if context.executing_eagerly():
      traced = def_function.function(overflow_in_float16)
      self.assertEqual(traced().numpy(), overflowed)
      # Outside a def_function.function, the grappler pass will not be applied.
      self.assertAlmostEqual(overflow_in_float16().numpy(), exact)

      # Test disabling mixed precision.
      mixed_precision.disable_mixed_precision_graph_rewrite()
      self.assertEqual(traced().numpy(), exact)
    else:
      with session.Session() as sess:
        check_in_session(sess, self.assertEqual, overflowed)

      # Test Session will enable the auto_mixed_precision grappler pass in a
      # ConfigProto passed by the user
      with session.Session(config=config_pb2.ConfigProto()) as sess:
        check_in_session(sess, self.assertEqual, overflowed)

      # Test disabling mixed precision.
      mixed_precision.disable_mixed_precision_graph_rewrite()
      with session.Session() as sess:
        check_in_session(sess, self.assertAlmostEqual, exact)