def test_no_grad(test_case):
    """Check ``flow.no_grad`` both as a context manager and as a decorator.

    While inside the guarded region ``flow.is_grad_enabled()`` must report
    ``False``; once control leaves the region it must report ``True`` again.
    """
    grad_on = flow.is_grad_enabled
    expect_on = test_case.assertTrue
    expect_off = test_case.assertFalse

    # Context-manager form.
    with flow.no_grad():
        expect_off(grad_on())
    expect_on(grad_on())

    # Decorator form: the body of the wrapped function should see grad off.
    @flow.no_grad()
    def body_without_grad():
        expect_off(grad_on())

    body_without_grad()
    expect_on(grad_on())
Beispiel #2
0
    def test_grad_enable(test_case):
        """Check ``flow.grad_enable`` as a context manager and as a decorator.

        Grad mode is already on when this test starts (the top-level checks
        below assert exactly that), so on their own they cannot tell
        ``grad_enable`` apart from a no-op. The nested section therefore runs
        under ``flow.no_grad()``, where ``grad_enable`` has to switch grad mode
        back on. It also checks that the surrounding no-grad state is restored
        once ``grad_enable`` exits.
        """
        # Top level: grad mode is already on; grad_enable must keep it on.
        with flow.grad_enable():
            test_case.assertTrue(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        @flow.grad_enable()
        def func():
            test_case.assertTrue(flow.is_grad_enabled())

        func()
        test_case.assertTrue(flow.is_grad_enabled())

        # Nested under no_grad: the only setting in which grad_enable changes
        # anything, so this is what actually exercises the API.
        with flow.no_grad():
            test_case.assertFalse(flow.is_grad_enabled())

            with flow.grad_enable():
                test_case.assertTrue(flow.is_grad_enabled())
            # Leaving grad_enable must fall back to the enclosing no_grad mode.
            test_case.assertFalse(flow.is_grad_enabled())

            # Decorator form must also re-enable grad inside the wrapped call.
            func()
            test_case.assertFalse(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())
    def test_set_grad_enabled(test_case):
        """Check ``flow.set_grad_enabled`` for both ``True`` and ``False``.

        Each mode is exercised first as a context manager, then as a decorator.
        Inside the guarded region ``flow.is_grad_enabled()`` must match the
        requested mode; after leaving it, grad mode must read as enabled again.
        """
        for requested in (True, False):
            # Pick the assertion that matches the mode being requested.
            expect_requested = (
                test_case.assertTrue if requested else test_case.assertFalse
            )

            with flow.set_grad_enabled(requested):
                expect_requested(flow.is_grad_enabled())
            test_case.assertTrue(flow.is_grad_enabled())

            @flow.set_grad_enabled(requested)
            def guarded():
                expect_requested(flow.is_grad_enabled())

            # Called within the same iteration, so the closure sees the
            # current ``expect_requested``.
            guarded()
            test_case.assertTrue(flow.is_grad_enabled())
    def test_inference_mode(test_case):
        """Check ``flow.inference_mode`` with inference switched on and off.

        ``inference_mode(True)`` must report grad mode as off inside the
        guarded region, and ``inference_mode(False)`` must report it as on. In
        both cases grad mode must read as enabled after the region ends. Each
        setting is exercised as a context manager, then as a decorator.
        """
        for inference in (True, False):
            # Grad is expected off exactly when inference mode is requested.
            check_inside = (
                test_case.assertFalse if inference else test_case.assertTrue
            )

            with flow.inference_mode(inference):
                check_inside(flow.is_grad_enabled())
            test_case.assertTrue(flow.is_grad_enabled())

            @flow.inference_mode(inference)
            def under_inference():
                check_inside(flow.is_grad_enabled())

            under_inference()
            test_case.assertTrue(flow.is_grad_enabled())
 def func():
     """Assert that grad mode is currently disabled.

     NOTE(review): ``test_case`` is neither a parameter nor defined anywhere
     visible in this scope. This looks like a copy of the decorated helper
     inside ``test_no_grad``, pasted at a mis-indented top level. Calling it
     as-is would presumably raise ``NameError`` unless a global ``test_case``
     exists elsewhere — confirm before relying on it.
     """
     test_case.assertFalse(flow.is_grad_enabled())
 def test_grad_mode(test_case):
     test_case.assertTrue(flow.is_grad_enabled())