Example #1
 def test_error_if_policy_is_set(self):
   with policy.policy_scope('infer_float32_vars'):
     with self.assertRaisesRegexp(
         ValueError, 'a keras mixed precision Policy has been set'):
       mixed_precision.enable_mixed_precision_graph_rewrite(
           gradient_descent_v2.SGD(1.0))
   # Test that no error is thrown when the policy is currently the default.
   mixed_precision.enable_mixed_precision_graph_rewrite(
       gradient_descent_v2.SGD(1.0))
Example #2
 def test_error_if_graph_rewrite_enabled(self):
   try:
     mixed_precision.enable_mixed_precision_graph_rewrite(
         gradient_descent.SGD(1.))
     with self.assertRaisesRegexp(
         ValueError, 'the mixed precision graph rewrite has already been '
                     'enabled'):
       mp_policy.set_policy('infer_float32_vars')
   finally:
     mixed_precision.disable_mixed_precision_graph_rewrite()
Example #3
  def test_wrap_optimizer(self):
    opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
    opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
    self.assertIsInstance(
        opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer)
    self.assertEqual(self.evaluate(opt._loss_scale()), 123.)

    opt = gradient_descent_v2.SGD(1.0)
    opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
    self.assertIsInstance(opt, loss_scale_optimizer_v2.LossScaleOptimizer)
    self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
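Example #3 exercises the internal test modules; in user code the same wrapping behavior is reached through the public export of this function. A minimal sketch, assuming the tf.train.experimental alias documented for the TF versions these tests target (the exact symbol path may differ across releases) and the 'dynamic' loss scale:

import tensorflow as tf

# Hypothetical user-side call: wraps the optimizer in a loss-scale optimizer
# and turns on the auto_mixed_precision grappler pass for subsequent graphs.
opt = tf.keras.optimizers.SGD(1.0)
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    opt, loss_scale='dynamic')
# The returned object is used exactly like the original optimizer; the loss
# scale protects float16 gradients from underflow, as the test above checks
# via the LossScaleOptimizer instance assertions.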
Example #4
 def test_error_if_graph_rewrite_enabled(self):
   try:
     mixed_precision.enable_mixed_precision_graph_rewrite(
         gradient_descent.SGD(1.))
     with self.assertRaisesRegex(
         ValueError, 'cannot be set to "mixed_float16", .* the mixed '
         'precision graph rewrite has already been enabled'):
       mp_policy.set_policy('mixed_float16')
     with mp_policy.policy_scope('float64'):
       pass  # Non-mixed policies are allowed
   finally:
     mixed_precision.disable_mixed_precision_graph_rewrite()
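Examples #1, #2, and #4 encode a mutual exclusion: the graph rewrite and a mixed Keras dtype policy cannot both be active. A short sketch of the user-facing side of that rule, assuming the tf.keras.mixed_precision.experimental policy API from the same TF era (names and paths changed in later releases):

import tensorflow as tf

# Pick one mechanism: either the Keras mixed precision policy...
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
# ...or the graph rewrite. Calling
# tf.train.experimental.enable_mixed_precision_graph_rewrite() while a mixed
# policy is set raises a ValueError, mirroring Example #1; setting a mixed
# policy after enabling the rewrite fails the other way, as Example #4 shows.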
Example #5
  def test_warn_if_session_already_exists(self, mock_warn):
    # Set this to False, so Sessions created in previous tests do not trigger
    # the warning.
    mixed_precision_global_state.non_mixed_precision_session_created = False

    with session.Session():
      mixed_precision.enable_mixed_precision_graph_rewrite(
          gradient_descent_v2.SGD(1.0))
      mock_warn.assert_any_call(
          'You already have existing Sessions that do not use mixed precision. '
          'enable_mixed_precision_graph_rewrite() will not affect these '
          'Sessions.')
Example #6
  def test_do_not_warn_if_session_does_not_already_exist(self, mock_warn):
    # Set this to False, so Sessions created in previous tests do not trigger
    # the warning.
    mixed_precision_global_state.non_mixed_precision_session_created = False

    mixed_precision.enable_mixed_precision_graph_rewrite(
        gradient_descent_v2.SGD(1.0))
    with session.Session():
      # Make sure the "You already have existing Sessions" warning was not
      # issued, since the Session was only created after
      # enable_mixed_precision_graph_rewrite.
      for call_arg in mock_warn.call_args_list:
        msg = call_arg[0][0]
        self.assertNotIn('You already have existing Sessions that do not use '
                         'mixed precision', msg)
Example #7
  def test_optimizer_errors(self):
    opt = 1
    expected_regex = (
        '"opt" must be an instance of a tf.train.Optimizer or '
        'a tf.keras.optimizers.Optimizer, but got')
    with self.assertRaisesRegexp(ValueError, expected_regex):
      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
    self.assertFalse(config.get_optimizer_experimental_options().get(
        'auto_mixed_precision', False))

    opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
    opt = loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(
        opt, 'dynamic')
    with self.assertRaisesRegexp(
        ValueError, '"opt" must not already be an instance of a '
        'MixedPrecisionLossScaleOptimizer.'):
      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
    self.assertFalse(config.get_optimizer_experimental_options().get(
        'auto_mixed_precision', False))

    opt = gradient_descent_v2.SGD(1.0)
    opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic')
    with self.assertRaisesRegexp(
        ValueError, '"opt" must not already be an instance of a '
        'LossScaleOptimizer.'):
      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
    self.assertFalse(config.get_optimizer_experimental_options().get(
        'auto_mixed_precision', False))
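The assertFalse checks in Example #7 read the grappler option through the internal config module; the same flag can be inspected with the public config API. A small sketch, assuming tf.config.optimizer is available:

import tensorflow as tf

# Returns a dict of experimental grappler options; 'auto_mixed_precision' is
# absent (or False) until the graph rewrite has been enabled successfully.
opts = tf.config.optimizer.get_experimental_options()
print(opts.get('auto_mixed_precision', False))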
Example #8
  def test_grappler_pass_enabled(self):
    opt = gradient_descent_v2.SGD(1.0)
    mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)

    var = variables.Variable([[1.0]])

    def overflow_in_float16():
      out = var * 2**10
      out = math_ops.matmul(out, out)
      return array_ops.reshape(out, ())

    if context.executing_eagerly():
      f = def_function.function(overflow_in_float16)
      self.assertEqual(f().numpy(), float('Inf'))
      # Outside a def_function.function, the grappler pass will not be applied.
      self.assertAlmostEqual(overflow_in_float16().numpy(), 2**20)

      # Test disabling mixed precision.
      mixed_precision.disable_mixed_precision_graph_rewrite()
      self.assertEqual(f().numpy(), 2**20)
    else:
      with session.Session() as sess:
        out = overflow_in_float16()
        sess.run(var.initializer)
        self.assertEqual(sess.run(out), float('Inf'))

      # Test that a Session enables the auto_mixed_precision grappler pass even
      # when the user passes in their own ConfigProto.
      with session.Session(config=config_pb2.ConfigProto()) as sess:
        out = overflow_in_float16()
        sess.run(var.initializer)
        self.assertEqual(sess.run(out), float('Inf'))

      # Test disabling mixed precision.
      mixed_precision.disable_mixed_precision_graph_rewrite()
      with session.Session() as sess:
        out = overflow_in_float16()
        sess.run(var.initializer)
        self.assertAlmostEqual(sess.run(out), 2**20)
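The 2**20 constant is what makes Example #8 work: 2**10 * 2**10 = 1,048,576, which fits comfortably in float32 but exceeds the float16 maximum of 65,504, so the matmul overflows to Inf only when the rewrite runs it in float16. A quick standalone check of that arithmetic, using only NumPy:

import numpy as np

print(np.finfo(np.float16).max)                       # 65504.0
print(np.float16(2.0 ** 10) * np.float16(2.0 ** 10))  # inf: 2**20 overflows float16
print(np.float32(2.0 ** 10) * np.float32(2.0 ** 10))  # 1048576.0, exact in float32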