class TestErrorWithInitFromStaticMode(unittest.TestCase):
    """Checks that ProgramTranslator refuses to run while static mode is on."""

    def setUp(self):
        self.program_translator = ProgramTranslator()
        self.x = np.random.randn(10, 32).astype('float32')

    def test_raise_error(self):
        """get_output/get_program raise RuntimeError in static mode.

        Both calls are expected to fail with a message containing
        "only available in dynamic mode".
        """
        # paddle.enable_static() flips a framework-wide mode. Remember the
        # mode we started in so it can be restored afterwards; otherwise the
        # switch would leak into tests that run later in the same process.
        was_dygraph = fluid.in_dygraph_mode()
        # disable imperative
        paddle.enable_static()
        try:
            net = Net()

            self.program_translator.enable(True)
            # NOTE(review): assertRaisesRegexp is a deprecated alias of
            # assertRaisesRegex (removed in Python 3.12); presumably kept for
            # older-Python compatibility — confirm before switching.
            with self.assertRaisesRegexp(RuntimeError,
                                         "only available in dynamic mode"):
                self.program_translator.get_output(net.forward, self.x)

            with self.assertRaisesRegexp(RuntimeError,
                                         "only available in dynamic mode"):
                self.program_translator.get_program(net.forward, self.x)
        finally:
            # Only re-enter dygraph mode if that is where we started.
            if was_dygraph:
                paddle.disable_static()
class TestEnableDeclarative(unittest.TestCase):
    """Checks that ProgramTranslator.enable() toggles between translated
    (static) execution and plain dygraph execution."""

    def setUp(self):
        self.x = np.random.randn(30, 10, 32).astype('float32')
        self.weight = np.random.randn(32, 64).astype('float32')
        self.program_translator = ProgramTranslator()

    def tearDown(self):
        # Most tests below finish with enable(False). Re-enable translation
        # so that state cannot leak into later tests.
        # NOTE(review): this matters if ProgramTranslator() returns shared
        # state (e.g. a singleton) and enabled is the expected default —
        # confirm against ProgramTranslator. It is harmless otherwise.
        self.program_translator.enable(True)

    def test_raise_error(self):
        """With translation enabled, calling NetWithError raises ValueError."""
        with fluid.dygraph.guard():
            self.program_translator.enable(True)
            net = NetWithError()
            with self.assertRaises(ValueError):
                net(fluid.dygraph.to_variable(self.x))

    def test_enable_disable_get_output(self):
        """get_output gives matching results with translation on and off."""
        self.program_translator.enable(True)
        with fluid.dygraph.guard():
            static_output = self.program_translator.get_output(
                simple_func, self.x, self.weight)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_output = self.program_translator.get_output(
                simple_func, self.x, self.weight)
            self.assertTrue(
                np.allclose(
                    static_output.numpy(), dygraph_output.numpy(), atol=1e-4))

    def test_enable_disable_get_func(self):
        """get_func yields a fluid.Variable-producing callable when enabled
        and a VarBase-producing callable when disabled."""
        self.program_translator.enable(True)
        with fluid.dygraph.guard():
            static_func = self.program_translator.get_func(simple_func)
            self.assertTrue(callable(static_func))
            static_output = static_func(self.x, self.weight)
            self.assertIsInstance(static_output, fluid.Variable)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_func = self.program_translator.get_func(simple_func)
            self.assertTrue(callable(dygraph_func))
            dygraph_output = dygraph_func(self.x, self.weight)
            self.assertIsInstance(dygraph_output, fluid.core.VarBase)

    def test_enable_disable_get_program(self):
        """get_program returns a 4-tuple when enabled and a VarBase when
        disabled."""
        self.program_translator.enable(True)
        static_output = self.program_translator.get_program(
            simple_func, self.x, self.weight)
        self.assertIsInstance(static_output, tuple)
        self.assertEqual(len(static_output), 4)
        # NOTE(review): the two Programs are presumably the main and startup
        # programs — confirm against ProgramTranslator.get_program.
        self.assertIsInstance(static_output[0], fluid.Program)
        self.assertIsInstance(static_output[1], fluid.Program)
        # Check all inputs and outputs are Variable
        for var in static_output[2]:
            self.assertIsInstance(var, fluid.Variable)

        for var in static_output[3]:
            self.assertIsInstance(var, fluid.Variable)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_output = self.program_translator.get_program(
                simple_func, self.x, self.weight)
            self.assertIsInstance(dygraph_output, fluid.core.VarBase)

    def test_enable_disable_declarative(self):
        """decorated_simple_func gives matching results with translation on
        and off."""
        self.program_translator.enable(True)
        with fluid.dygraph.guard():
            static_output = decorated_simple_func(self.x, self.weight)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_output = decorated_simple_func(self.x, self.weight)
            self.assertTrue(
                np.allclose(
                    static_output.numpy(), dygraph_output.numpy(), atol=1e-4))