def test_no_grad(test_case):
    """Verify ``flow.no_grad`` turns grad recording off within its scope.

    Both usage forms are covered -- a ``with`` block and a decorator -- and
    after each one the ambient (enabled) mode must be back in effect.
    """

    def _expect_disabled():
        test_case.assertFalse(flow.is_grad_enabled())

    # Context-manager form.
    with flow.no_grad():
        _expect_disabled()
    test_case.assertTrue(flow.is_grad_enabled())

    # Decorator form, spelled out: ``@flow.no_grad()`` is sugar for this call.
    flow.no_grad()(_expect_disabled)()
    test_case.assertTrue(flow.is_grad_enabled())
def test_grad_enable(test_case):
    """Check ``flow.grad_enable`` as a context manager and as a decorator.

    Grad recording is already on in the ambient state of this test, so each
    form is additionally exercised inside ``flow.no_grad()``. Without that,
    the assertions would pass even if ``grad_enable`` had no effect. The
    nested checks also verify that the previous (disabled) mode is restored
    when ``grad_enable`` exits.
    """
    # Context-manager form, entered from the enabled ambient state.
    with flow.grad_enable():
        test_case.assertTrue(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    # Context-manager form, entered from a disabled state: this is the case
    # that actually demonstrates grad_enable switching recording on.
    with flow.no_grad():
        with flow.grad_enable():
            test_case.assertTrue(flow.is_grad_enabled())
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    @flow.grad_enable()
    def func():
        test_case.assertTrue(flow.is_grad_enabled())

    # Decorator form, called from the enabled ambient state.
    func()
    test_case.assertTrue(flow.is_grad_enabled())

    # Decorator form, called from a disabled region: the wrapper must enable
    # grad for the call and hand back the disabled mode afterwards.
    with flow.no_grad():
        func()
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())
def test_set_grad_enabled(test_case):
    """Check ``flow.set_grad_enabled`` for both ``True`` and ``False``.

    Each value is tested as a context manager and as a decorator, and the
    ambient mode must be restored afterwards. Grad recording is already on
    in the ambient state, so ``set_grad_enabled(True)`` is also exercised
    inside ``flow.no_grad()``. Otherwise its assertions would pass even if
    it had no effect.
    """
    # --- set_grad_enabled(True) ---
    with flow.set_grad_enabled(True):
        test_case.assertTrue(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    # Entered from a disabled state: proves the switch really happens and
    # that the previous (disabled) mode comes back on exit.
    with flow.no_grad():
        with flow.set_grad_enabled(True):
            test_case.assertTrue(flow.is_grad_enabled())
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    @flow.set_grad_enabled(True)
    def enabled_func():
        test_case.assertTrue(flow.is_grad_enabled())

    enabled_func()
    test_case.assertTrue(flow.is_grad_enabled())

    # Decorated call from a disabled region.
    with flow.no_grad():
        enabled_func()
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    # --- set_grad_enabled(False) ---
    with flow.set_grad_enabled(False):
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    # Distinct helper names keep the two decorated variants from shadowing
    # each other.
    @flow.set_grad_enabled(False)
    def disabled_func():
        test_case.assertFalse(flow.is_grad_enabled())

    disabled_func()
    test_case.assertTrue(flow.is_grad_enabled())
def test_inference_mode(test_case):
    """Check ``flow.inference_mode`` for both ``True`` and ``False``.

    Each value is tested as a context manager and as a decorator, and the
    ambient mode must be restored afterwards. Grad recording is already on
    in the ambient state, so ``inference_mode(False)`` is also exercised
    inside ``flow.no_grad()``. Otherwise its assertions would pass even if
    it had no effect.
    """
    # --- inference_mode(True): grad recording must be off inside ---
    with flow.inference_mode(True):
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    @flow.inference_mode(True)
    def inference_func():
        test_case.assertFalse(flow.is_grad_enabled())

    inference_func()
    test_case.assertTrue(flow.is_grad_enabled())

    # --- inference_mode(False): grad recording must be on inside ---
    with flow.inference_mode(False):
        test_case.assertTrue(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    # Entered from a disabled state: proves inference_mode(False) actually
    # turns recording on and restores the disabled mode on exit.
    with flow.no_grad():
        with flow.inference_mode(False):
            test_case.assertTrue(flow.is_grad_enabled())
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())

    # Distinct helper names keep the two decorated variants from shadowing
    # each other.
    @flow.inference_mode(False)
    def training_func():
        test_case.assertTrue(flow.is_grad_enabled())

    training_func()
    test_case.assertTrue(flow.is_grad_enabled())

    # Decorated call from a disabled region.
    with flow.no_grad():
        training_func()
        test_case.assertFalse(flow.is_grad_enabled())
    test_case.assertTrue(flow.is_grad_enabled())
# Asserts that grad recording is currently disabled.
# NOTE(review): this top-level `func` reads `test_case`, which is neither a
# parameter nor bound anywhere visible in this chunk. Calling it would raise
# NameError unless a module-level `test_case` exists elsewhere. It looks like
# a stray copy of a nested helper from a grad-mode test; confirm, then move
# it into its test or remove it.
def func(): test_case.assertFalse(flow.is_grad_enabled())
def test_grad_mode(test_case):
    """Grad recording should be on when no grad-mode context is active."""
    recording = flow.is_grad_enabled()
    test_case.assertTrue(recording)