Пример #1
0
class TestErrorWithInitFromStaticMode(unittest.TestCase):
    """ProgramTranslator's dygraph-only helpers must reject static mode."""

    def setUp(self):
        self.program_translator = ProgramTranslator()
        self.x = np.random.randn(10, 32).astype('float32')

    def test_raise_error(self):
        """get_output/get_program raise RuntimeError when not in dygraph."""
        # Remember the current mode so the global switch below does not
        # leak into tests that run after this one in the same process.
        was_dygraph = fluid.in_dygraph_mode()
        # disable imperative
        paddle.enable_static()
        if was_dygraph:
            self.addCleanup(paddle.disable_static)
        net = Net()

        self.program_translator.enable(True)
        # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(RuntimeError,
                                    "only available in dynamic mode"):
            self.program_translator.get_output(net.forward, self.x)

        with self.assertRaisesRegex(RuntimeError,
                                    "only available in dynamic mode"):
            self.program_translator.get_program(net.forward, self.x)
Пример #2
0
class TestEnableDeclarative(unittest.TestCase):
    """Check that ProgramTranslator.enable() toggles static conversion."""

    def setUp(self):
        self.x = np.random.randn(30, 10, 32).astype('float32')
        self.weight = np.random.randn(32, 64).astype('float32')
        self.program_translator = ProgramTranslator()
        # Several tests below finish with enable(False). Re-enable afterwards
        # so a disabled translator does not leak into later tests.
        # NOTE(review): assumes ProgramTranslator() is shared process-wide
        # and that "enabled" is its default state -- confirm.
        self.addCleanup(self.program_translator.enable, True)

    def test_raise_error(self):
        """With conversion enabled, calling NetWithError raises ValueError."""
        with fluid.dygraph.guard():
            self.program_translator.enable(True)
            net = NetWithError()
            with self.assertRaises(ValueError):
                net(fluid.dygraph.to_variable(self.x))

    def test_enable_disable_get_output(self):
        """get_output gives matching results with conversion on and off."""
        self.program_translator.enable(True)
        with fluid.dygraph.guard():
            static_output = self.program_translator.get_output(
                simple_func, self.x, self.weight)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_output = self.program_translator.get_output(
                simple_func, self.x, self.weight)
            # Same tolerance as np.allclose(..., atol=1e-4), but a failure
            # reports the mismatching elements instead of just "False".
            np.testing.assert_allclose(static_output.numpy(),
                                       dygraph_output.numpy(),
                                       rtol=1e-05,
                                       atol=1e-4,
                                       equal_nan=False)

    def test_enable_disable_get_func(self):
        """get_func returns a static or a dygraph callable per enable()."""
        self.program_translator.enable(True)
        with fluid.dygraph.guard():
            static_func = self.program_translator.get_func(simple_func)
            self.assertTrue(callable(static_func))
            static_output = static_func(self.x, self.weight)
            self.assertIsInstance(static_output, fluid.Variable)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_func = self.program_translator.get_func(simple_func)
            self.assertTrue(callable(dygraph_func))
            dygraph_output = dygraph_func(self.x, self.weight)
            self.assertIsInstance(dygraph_output, fluid.core.VarBase)

    def test_enable_disable_get_program(self):
        """get_program returns a 4-tuple when enabled, dygraph output if not."""
        self.program_translator.enable(True)
        static_output = self.program_translator.get_program(
            simple_func, self.x, self.weight)
        self.assertIsInstance(static_output, tuple)
        self.assertEqual(len(static_output), 4)
        self.assertIsInstance(static_output[0], fluid.Program)
        self.assertIsInstance(static_output[1], fluid.Program)
        # Check all inputs and outputs are Variable
        for var in static_output[2]:
            self.assertIsInstance(var, fluid.Variable)

        for var in static_output[3]:
            self.assertIsInstance(var, fluid.Variable)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_output = self.program_translator.get_program(
                simple_func, self.x, self.weight)
            self.assertIsInstance(dygraph_output, fluid.core.VarBase)

    def test_enable_disable_declarative(self):
        """A decorated function agrees with itself with conversion on/off."""
        self.program_translator.enable(True)
        with fluid.dygraph.guard():
            static_output = decorated_simple_func(self.x, self.weight)

        self.program_translator.enable(False)
        with fluid.dygraph.guard():
            dygraph_output = decorated_simple_func(self.x, self.weight)
            # Same tolerance as np.allclose(..., atol=1e-4), with a
            # descriptive failure message.
            np.testing.assert_allclose(static_output.numpy(),
                                       dygraph_output.numpy(),
                                       rtol=1e-05,
                                       atol=1e-4,
                                       equal_nan=False)