def test_with_node_var_type_map(self):
        """`x > 1` needs transforming only when the type map says x is a Tensor."""
        cmp_expr = gast.parse("x > 1").body[0].value

        # Each case: the declared type set for "x" and the expected verdict.
        cases = (
            ({NodeVarType.TENSOR}, self.assertTrue),
            ({NodeVarType.NUMPY_NDARRAY}, self.assertFalse),
        )
        for x_types, check in cases:
            type_map = {"x": x_types}
            check(
                is_control_flow_to_transform(
                    cmp_expr, var_name_to_type=type_map))
    def test_paddle_api(self):
        """An `if` whose test calls a fluid.layers API on x must be transformed."""
        source = textwrap.dedent("""
            def foo(x):
                if fluid.layers.shape(x)[0] > 16:
                    x = x + 1
                return x
        """)
        tree = gast.parse(source)
        analysis = StaticAnalysisVisitor(tree)
        # body[0] is `def foo`; its body[0] is the `if` statement.
        if_test = tree.body[0].body[0].test

        self.assertTrue(is_control_flow_to_transform(if_test, analysis))
    def test_shape_with_andOr(self):
        """A mixed and/or test that reads a shape result must be transformed."""
        source = textwrap.dedent("""
            def foo(x):
                batch_size = fluid.layers.shape(x)
                if x is not None and batch_size[0] > 16 or 2 > 1:
                    x = x + 1
                return x
        """)
        tree = gast.parse(source)
        analysis = StaticAnalysisVisitor(tree)
        # body[0] is `def foo`; statement 0 is the assignment, 1 is the `if`.
        if_test = tree.body[0].body[1].test

        self.assertTrue(is_control_flow_to_transform(if_test, analysis))
    def test_paddle_api_with_andOr(self):
        """Both `or` and `and` chains containing a fluid API call are transformed."""
        code_or = """
            def foo(x):
                if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
                    x = x + 1
                return x
        """

        code_and = """
            def foo(x):
                if 2 > 1 and fluid.layers.shape(x)[0] > 16 and x is not None :
                    x = x + 1
                return x
        """
        for snippet in (code_or, code_and):
            tree = gast.parse(textwrap.dedent(snippet))
            analysis = StaticAnalysisVisitor(tree)
            # The `if` is the first statement inside `def foo`.
            if_test = tree.body[0].body[0].test
            self.assertTrue(is_control_flow_to_transform(if_test, analysis))
Example #5
0
 def visit_For(self, node):
     """Visit a `for` loop and rewrite its list appends when it needs transforming.

     Returns the (possibly modified in place) node.
     """
     # Children are visited first via generic_visit (presumably the
     # ast.NodeTransformer-style default; the base class is not shown here).
     self.generic_visit(node)
     needs_transform = is_control_flow_to_transform(
         node, self.static_analysis_visitor, self.scope_var_type_dict)
     if needs_transform:
         self._transform_list_append_in_control_flow(node)
     return node
Example #6
0
 def is_control_flow_loop(self, node):
     """Return whether loop `node` must be converted, per the static analysis."""
     return is_control_flow_to_transform(node, self.static_analysis_visitor)
Example #7
0
 def visit_While(self, node):
     """Visit a `while` loop and rewrite its list appends when it needs transforming.

     Returns the (possibly modified in place) node.
     """
     self.generic_visit(node)
     # Pass the type map by keyword. The second positional parameter of
     # is_control_flow_to_transform is the static-analysis visitor (visit_For
     # and the tests call it that way), so passing the dict positionally
     # would bind it to the wrong parameter.
     if is_control_flow_to_transform(
             node, var_name_to_type=self.scope_var_type_dict):
         self._transform_list_append_in_control_flow(node)
     return node
 def test_raise_error(self):
     """A non-AST input (a plain string) must be rejected with a clear message."""
     node = "a + b"
     # Call the function directly inside the context manager.
     # `assertRaises(TypeError, f(x))` evaluates f(x) eagerly and never
     # exercises the intended inner check.
     with self.assertRaises(Exception) as e:
         is_control_flow_to_transform(node)
     self.assertIn("The type of input node must be gast.AST",
                   str(e.exception))
    def test_if_with_or(self):
        """An `or` whose left operand calls a fluid API must be transformed."""
        expr_src = "1 < fluid.layers.sum(x).numpy()[2] or x+y < 0"
        bool_expr = gast.parse(expr_src).body[0].value
        self.assertTrue(is_control_flow_to_transform(bool_expr))
Example #10
0
    def test_if_with_and(self):
        """An `and` containing a `.numpy()` call must be transformed."""
        expr_src = "x and 1 < x.numpy()[1]"
        bool_expr = gast.parse(expr_src).body[0].value
        self.assertTrue(is_control_flow_to_transform(bool_expr))
Example #11
0
    def test_is_None4(self):
        """An `and` whose left operand is a fluid API call must be transformed."""
        expr_src = "fluid.layers.sum(x) and 2>1"
        bool_expr = gast.parse(expr_src).body[0].value
        self.assertTrue(is_control_flow_to_transform(bool_expr))
Example #12
0
 def test_expr2(self):
     """An expression using `x.numpy()` (x is a Tensor) must be transformed."""
     expr_value = gast.parse("a + x.numpy()").body[0].value
     self.assertTrue(is_control_flow_to_transform(expr_value))
Example #13
0
    def check_false_case(self, code):
        """Assert that the first expression statement in `code` is not transformed.

        `code` is dedented before parsing, so it may be an indented
        triple-quoted snippet.
        """
        first_value = gast.parse(textwrap.dedent(code)).body[0].value
        self.assertFalse(is_control_flow_to_transform(first_value))