コード例 #1
0
    def test_no_grad(test_case):
        """Check that ``flow.no_grad`` disables grad mode and then restores it.

        Covers both usages: as a ``with`` context manager and as a function
        decorator. The restoration checks assume grad mode is enabled when the
        test starts, so that precondition is asserted explicitly.
        """
        # Precondition: the "restored" checks below only make sense if grad
        # mode starts out enabled.
        test_case.assertTrue(flow.is_grad_enabled())

        with flow.no_grad():
            test_case.assertFalse(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        # Record each invocation. Assertions inside a decorated function only
        # execute if the decorator really calls it; without this, a wrapper that
        # never invoked ``func`` would let the test pass vacuously.
        calls = []

        @flow.no_grad()
        def func():
            test_case.assertFalse(flow.is_grad_enabled())
            calls.append("func")

        func()
        test_case.assertEqual(calls, ["func"])
        test_case.assertTrue(flow.is_grad_enabled())
コード例 #2
0
ファイル: test_autograd_mode.py プロジェクト: zzk0/oneflow
    def test_grad_enable(test_case):
        """Check that ``flow.grad_enable`` turns grad mode on and then restores it.

        Covers the context-manager and decorator forms, both from an
        already-enabled state and nested inside ``flow.no_grad()``. The nested
        cases are what give the test teeth: starting from an enabled state, a
        ``grad_enable`` that did nothing would still pass every other check.
        """
        # Precondition for the "restored" checks below.
        test_case.assertTrue(flow.is_grad_enabled())

        with flow.grad_enable():
            test_case.assertTrue(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        # Nested under no_grad: grad_enable must switch grad mode on, and
        # leaving it must restore the outer (disabled) mode, not force it on.
        with flow.no_grad():
            with flow.grad_enable():
                test_case.assertTrue(flow.is_grad_enabled())
            test_case.assertFalse(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        # Track invocations so the assertion inside the decorated body cannot
        # be skipped silently by a wrapper that never calls it.
        calls = []

        @flow.grad_enable()
        def func():
            test_case.assertTrue(flow.is_grad_enabled())
            calls.append("func")

        func()
        test_case.assertTrue(flow.is_grad_enabled())

        # Decorated call from inside no_grad: grad is on within the call and
        # the outer disabled mode is back once it returns.
        with flow.no_grad():
            func()
            test_case.assertFalse(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        test_case.assertEqual(calls, ["func", "func"])
コード例 #3
0
    def test_set_grad_enabled(test_case):
        """Check ``flow.set_grad_enabled(mode)`` for ``mode=True`` and ``mode=False``.

        Each mode is exercised as a context manager and as a decorator, and
        grad mode must be enabled again afterwards. The ``True`` context-manager
        case is also run nested inside ``flow.no_grad()``. Starting from an
        enabled state, a ``set_grad_enabled(True)`` that did nothing would
        otherwise go unnoticed.
        """
        # Precondition for the "restored" checks below.
        test_case.assertTrue(flow.is_grad_enabled())

        # --- mode=True ---
        with flow.set_grad_enabled(True):
            test_case.assertTrue(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        # Nested under no_grad: must switch grad on, then restore "disabled".
        with flow.no_grad():
            with flow.set_grad_enabled(True):
                test_case.assertTrue(flow.is_grad_enabled())
            test_case.assertFalse(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        # Track invocations so assertions inside decorated bodies cannot be
        # skipped silently by a wrapper that never calls them.
        calls = []

        @flow.set_grad_enabled(True)
        def enabled_func():
            test_case.assertTrue(flow.is_grad_enabled())
            calls.append("enabled")

        enabled_func()
        test_case.assertTrue(flow.is_grad_enabled())

        # --- mode=False ---
        with flow.set_grad_enabled(False):
            test_case.assertFalse(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        @flow.set_grad_enabled(False)
        def disabled_func():
            test_case.assertFalse(flow.is_grad_enabled())
            calls.append("disabled")

        disabled_func()
        test_case.assertTrue(flow.is_grad_enabled())

        test_case.assertEqual(calls, ["enabled", "disabled"])
コード例 #4
0
    def test_inference_mode(test_case):
        """Check ``flow.inference_mode(mode)`` for ``mode=True`` and ``mode=False``.

        From an enabled starting state, ``inference_mode(True)`` must disable
        grad mode and ``inference_mode(False)`` must keep it enabled. Each is
        exercised as a context manager and as a decorator, and grad mode must
        be enabled again afterwards.
        """
        # Precondition for the "restored" checks below.
        test_case.assertTrue(flow.is_grad_enabled())

        # Track invocations so assertions inside decorated bodies cannot be
        # skipped silently by a wrapper that never calls them.
        calls = []

        # --- mode=True ---
        with flow.inference_mode(True):
            test_case.assertFalse(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        @flow.inference_mode(True)
        def inference_func():
            test_case.assertFalse(flow.is_grad_enabled())
            calls.append("inference")

        inference_func()
        test_case.assertTrue(flow.is_grad_enabled())

        # --- mode=False ---
        with flow.inference_mode(False):
            test_case.assertTrue(flow.is_grad_enabled())
        test_case.assertTrue(flow.is_grad_enabled())

        @flow.inference_mode(False)
        def non_inference_func():
            test_case.assertTrue(flow.is_grad_enabled())
            calls.append("non_inference")

        non_inference_func()
        test_case.assertTrue(flow.is_grad_enabled())

        test_case.assertEqual(calls, ["inference", "non_inference"])
コード例 #5
0
 def func():
     """Fail the surrounding test unless grad mode is currently disabled.

     Scraped fragment: ``test_case`` and ``flow`` are free names expected
     from an enclosing scope. Presumably this is wrapped by a grad-disabling
     decorator such as ``flow.no_grad()``; confirm against the original source.
     """
     # Bind the assertion first so name/attribute lookups happen in the same
     # order as a single-expression call would perform them.
     assert_false = test_case.assertFalse
     assert_false(flow.is_grad_enabled())
コード例 #6
0
 def test_grad_mode(test_case):
     test_case.assertTrue(flow.is_grad_enabled())