예제 #1
0
def create_fill_constant_node(name, value):
    """
    Create a gast assignment node of the form
    `<name> = paddle.fluid.layers.fill_constant(shape=[1], dtype=..., value=..., name='<name>')`.

    Args:
        name(str): variable name that is assigned; also passed as the `name`
            argument of `fill_constant`.
        value(bool|float|int): the constant. bool maps to dtype 'bool',
            float to 'float64' and int to 'int64'.

    Returns:
        gast.Assign: the parsed assignment statement, or None when `value`
        is not a bool, float or int.
    """
    func_code = "{} = paddle.fluid.layers.fill_constant(shape=[1], ".format(
        name)
    # NOTE: bool must be tested before int, because bool is a subclass of int.
    if isinstance(value, bool):
        func_code += "dtype='bool', value={}, name='{}')".format(value, name)
        return gast.parse(func_code).body[0]
    if isinstance(value, float):
        value_code = "{}".format(value)
        # str() of non-finite floats is 'inf' / '-inf' / 'nan', which would be
        # parsed as (undefined) variable names; spell them as float literals.
        if value_code in ("inf", "-inf", "nan"):
            value_code = "float('{}')".format(value_code)
        func_code += "dtype='float64', value={}, name='{}')".format(value_code,
                                                                   name)
        return gast.parse(func_code).body[0]

    if isinstance(value, int):
        func_code += "dtype='int64', value={}, name='{}')".format(value, name)
        return gast.parse(func_code).body[0]

    # Unsupported value type: no node is created.
    return None
예제 #2
0
    def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
                                               return_name,
                                               parent_node_of_return):
        """
        Wrap every statement after `node` in `stmt_list` into
        `if not <return_name>: ...`, modifying `stmt_list` in place.

        Args:
            stmt_list(list): statement list that should contain `node`.
            node(gast.AST): the statement after which the remaining
                statements are guarded.
            return_name(str): name of the boolean flag tested by the new `if`.
            parent_node_of_return(gast.AST): parent of the return statement.
                When it is a `gast.If`, the assignment
                `<return_name> = _jst.create_bool_as_type(<parent test>, False)`
                is inserted at the position of `node`.

        Returns:
            bool: False when `index_in_list` reports an index outside
            `stmt_list` (presumably `node` was not found), True otherwise.
            True is also returned when `node` is already the last statement
            and nothing needs to be wrapped.
        """
        i = index_in_list(stmt_list, node)
        if i < 0 or i >= len(stmt_list):
            return False
        if i == len(stmt_list) - 1:
            # No need to add, we consider this as added successfully
            return True

        # The flag is only read by `not <return_name>`, so it must be in Load
        # context (a Store context here produces an invalid expression AST).
        if_stmt = gast.If(test=gast.UnaryOp(op=gast.Not(),
                                            operand=gast.Name(
                                                id=return_name,
                                                ctx=gast.Load(),
                                                annotation=None,
                                                type_comment=None)),
                          body=stmt_list[i + 1:],
                          orelse=[])

        stmt_list[i + 1:] = [if_stmt]

        # Here assume that the parent node of return is gast.If
        if isinstance(parent_node_of_return, gast.If):
            # Prepend control flow boolean nodes such as '__return@1 = False'
            node_str = "{} = _jst.create_bool_as_type({}, False)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip())
            assign_false_node = gast.parse(node_str).body[0]

            stmt_list[i:i] = [assign_false_node]
        return True
예제 #3
0
    def test_nested_loop_vars(self):
        """
        For every loop of `self.nested_for_loop_func`, visited in `gast.walk`
        order, compare the loop variable names and the created variable names
        reported by `NameVisitor.get_loop_var_names` with the expected lists.
        """
        source = inspect.getsource(self.nested_for_loop_func)
        root = gast.parse(source)
        visitor = NameVisitor(root)

        # Expected results, indexed by the position of the loop in walk order.
        self.loop_var_names = [{"j", "two"}, {"i", "three", "b"}, {"i", "j"}]
        self.create_var_names = [set(), {"b"}, set()]

        loops = (candidate for candidate in gast.walk(root)
                 if isinstance(candidate, (gast.While, gast.For)))
        for idx, loop in enumerate(loops):
            got_loop_names, got_create_names = visitor.get_loop_var_names(
                loop)
            want_loop_names = self.loop_var_names[idx]
            want_create_names = self.create_var_names[idx]
            self.assertEqual(
                got_loop_names,
                want_loop_names,
                msg="loop_var_names : {}, \nexpected loop_var_names : {}".
                format(got_loop_names, want_loop_names))
            self.assertEqual(
                got_create_names,
                want_create_names,
                msg=
                "i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}"
                .format(idx, got_create_names, want_create_names))
예제 #4
0
    def _replace_pop(self, node):
        """
        Rewrite a `<obj>.pop(...)` call into `_jst.convert_pop(<obj>, ...)`.

        Examples of the rewrite:

            list_a = [0,1,2,3,4]
            x = list_a.pop()  # --> convert_pop(list_a)
            y = list_a.pop(1) # --> convert_pop(list_a, 1)

            dict_a = {"red":0, "blue":1, "yellow":2}
            m = dict_a.pop("red")           # --> convert_pop(dict_a, "red")
            n = dict_a.pop("black", 3)      # --> convert_pop(dict_a, "black", 3)

        Calls with more than two positional arguments are returned unchanged.
        """
        assert isinstance(node, gast.Call)
        assert isinstance(node.func, gast.Attribute)

        receiver_src = ast_to_source_code(node.func.value).strip()
        arg_srcs = [ast_to_source_code(arg).strip() for arg in node.args]

        # NOTE(liym27):
        # 0 args -> list pop; 1 arg -> list or dict pop; 2 args -> dict pop.
        # Anything longer cannot be a list/dict pop, so keep it as is.
        if len(arg_srcs) > 2:
            return node

        convert_src = "_jst.convert_pop({}, {})".format(receiver_src,
                                                        ",".join(arg_srcs))
        return gast.parse(convert_src).body[0].value
예제 #5
0
 def _create_tensor_array(self, value_node):
     """
     Return the expression node `paddle.tensor.create_array('float32', <init>)`,
     where `<init>` is the (stripped) source of `value_node`.
     """
     # Although `dtype='float32'`, other types such as `int32` can also be supported
     create_array_src = "paddle.tensor.create_array('float32', {})".format(
         ast_to_source_code(value_node).strip())
     return gast.parse(create_array_src).body[0].value
def create_convert_shape_node(var_shape_node,
                              slice_node=None,
                              in_control_flow=False):
    """
    Build an expression calling `paddle.jit.dy2static.convert_var_shape` for
    a `<expr>.shape` attribute, or for a subscript of it such as `<expr>.shape[0]`.

    Args:
        var_shape_node(gast.Attribute|gast.Subscript): the shape access node.
        slice_node(gast.Subscript, optional): subscript applied to the shape
            attribute; set internally when `var_shape_node` is a Subscript.
        in_control_flow(bool): forwarded as `in_control_flow=` to the call.

    Returns:
        gast.AST: `convert_var_shape(...)`, possibly wrapped in a Subscript.
    """
    assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))

    if isinstance(var_shape_node, gast.Attribute):
        has_slice = slice_node is not None
        numeric_slice = has_slice and slice_is_num(slice_node)

        call_args = [ast_to_source_code(var_shape_node.value).strip()]
        # (1) A numeric slice (e.g. 1, -2; gast.Index or gast.Constant) is
        #     passed as the 'idx' argument of convert_var_shape.
        # (2) A bounded slice (e.g. 2:-1) is applied afterwards, i.e.
        #     `convert_var_shape(x)[slice]`.
        if numeric_slice:
            call_args.append(ast_to_source_code(slice_node.slice).strip())

        call_src = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
            ",".join(call_args), in_control_flow)
        shape_call = gast.parse(call_src).body[0].value

        if has_slice and not numeric_slice:
            return gast.Subscript(
                value=shape_call, slice=slice_node.slice, ctx=gast.Load())
        return shape_call

    if isinstance(var_shape_node, gast.Subscript):
        # Work on a copy so the caller's tree is not shared with the result.
        subscript_copy = copy.deepcopy(var_shape_node)
        return create_convert_shape_node(subscript_copy.value, subscript_copy,
                                         in_control_flow)
예제 #7
0
    def _create_bool_op_node(self, nodes, api_type):
        '''
        Fold the operands of a BoolOp into nested
        `paddle.jit.dy2static.convert_logical_<api_type>(lambda:a, lambda:b)` calls.

        NOTE(liym27):
           Each operand is wrapped in a lambda so it is evaluated lazily and in
           the original order: in `convert_logical_and(lambda:x>1, lambda:y<1)`,
           `y<1` must not be evaluated when `x>1` is False.
        '''
        assert len(
            nodes
        ) > 1, "The length of BoolOp should be at least 2, but received {}.".format(
            len(nodes))
        if len(nodes) > 2:
            # Reduce to exactly two operands: (first two) and (the rest).
            left_node = self._create_bool_op_node(nodes[:2], api_type)
            remaining = nodes[2:]
            if len(remaining) == 1:
                right_node = remaining[0]
            else:
                right_node = self._create_bool_op_node(remaining, api_type)
            nodes = [left_node, right_node]

        sources = [ast_to_source_code(operand) for operand in nodes]
        call_src = "paddle.jit.dy2static.convert_logical_{}(lambda:{}, lambda:{})".format(
            api_type, sources[0], sources[1])
        # NOTE: gast.parse return Module(body=[expr(...)])
        return gast.parse(call_src).body[0].value
예제 #8
0
 def test_construct_node_wrapper(self):
     """Check the node-wrapper tree built by StaticAnalysisVisitor for each test function."""
     for test_func in test_funcs:
         root = gast.parse(inspect.getsource(test_func))
         analysis = StaticAnalysisVisitor(root)
         self._check_wrapper(analysis.get_node_wrapper_root(),
                             analysis.get_node_to_wrapper_map())
예제 #9
0
 def _to_array_write_node(self, node):
     """
     Turn `<array>.append(<x>)` into
     `paddle.tensor.array_write(x=<x>, i=paddle.tensor.array_length(<array>), array=<array>)`.
     """
     assert isinstance(node, gast.Call)
     # NOTE: astor output is used as-is (not stripped), as in the original code.
     array_src = astor.to_source(gast.gast_to_ast(node.func.value))
     value_src = astor.to_source(gast.gast_to_ast(node.args[0]))
     index_src = "paddle.tensor.array_length({})".format(array_src)
     write_src = "paddle.tensor.array_write(x={}, i={}, array={})".format(
         value_src, index_src, array_src)
     return gast.parse(write_src).body[0].value
예제 #10
0
    def visit_Assert(self, node):
        convert_assert_node = gast.parse(
            'paddle.jit.dy2static.convert_assert({test}, {msg})'.format(
                test=ast_to_source_code(node.test),
                msg=ast_to_source_code(node.msg)
                if node.msg else "")).body[0].value

        return gast.Expr(value=convert_assert_node)
 def visit_Attribute(self, node):
     """
     Replace `<expr>.shape` with
     `paddle.jit.dy2static.convert_var_shape_simple(<expr>)`.

     Child nodes are visited first, so a `.shape` nested inside `<expr>`
     (e.g. in `foo(x.shape).shape`) is converted as well. `paddle.shape`
     itself is the API function, not a tensor attribute, and is left
     untouched. Without this check, `paddle.shape(x)` would become
     `convert_var_shape_simple(paddle)(x)`.
     """
     self.generic_visit(node)
     if node.attr == 'shape':
         args = ast_to_source_code(node.value).strip()
         if args != 'paddle':
             convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format(
                 args)
             api_shape_node = gast.parse(convert_var_shape_func).body[0].value
             return api_shape_node
     return node
예제 #12
0
 def visit_UnaryOp(self, node):
     """Rewrite `not <operand>` into `_jst.convert_logical_not(<operand>)`; other unary ops are kept."""
     self.generic_visit(node)
     if not isinstance(node.op, gast.Not):
         return node
     operand_src = ast_to_source_code(node.operand)
     # NOTE: gast.parse returns Module(body=[expr(value=...)])
     return gast.parse(
         "_jst.convert_logical_not({})".format(operand_src)).body[0].value
예제 #13
0
def code_gast_ast(source):
    """
    Parse `source` (after dedenting) into a gast tree, apply
    `GastNodeTransformer` to it, convert the result back to a standard
    `ast` tree and return its `ast.dump` string.
    """
    gast_root = gast.parse(textwrap.dedent(source))
    transformed_root = GastNodeTransformer(gast_root).apply()
    return ast.dump(gast.gast_to_ast(transformed_root))
예제 #14
0
def create_convert_ifelse_node(return_name_ids,
                               pred,
                               true_func,
                               false_func,
                               is_if_expr=False):
    """
    Build the node that replaces a Python `if/else` statement (or a
    conditional expression) with
    `_jst.convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars)`.

    Args:
        return_name_ids: names assigned in the branches; falsy when nothing
            is assigned.
        pred: the condition node.
        true_func, false_func: for an `if` statement, function-def nodes whose
            `.name` and `.args.args` are used; for a conditional expression
            (`is_if_expr=True`), expression nodes wrapped as `lambda : <expr>`.
        is_if_expr(bool): True when converting a conditional expression.

    Returns:
        The second node returned by `create_assign_node` when
        `return_name_ids` is non-empty, otherwise a `gast.Expr` wrapping the call.
    """

    def load_tuple(name_ids):
        # Tuple of Load-context names; an empty tuple when there are none.
        elts = []
        if name_ids:
            elts = [
                gast.Name(id=name_id,
                          ctx=gast.Load(),
                          annotation=None,
                          type_comment=None) for name_id in name_ids
            ]
        return gast.Tuple(elts=elts, ctx=gast.Load())

    if is_if_expr:
        true_args = load_tuple([])
        false_args = load_tuple([])
        true_fn_src = "lambda : {}".format(ast_to_source_code(true_func))
        false_fn_src = "lambda : {}".format(ast_to_source_code(false_func))
    else:
        true_args = gast.Tuple(elts=true_func.args.args, ctx=gast.Load())
        false_args = gast.Tuple(elts=false_func.args.args, ctx=gast.Load())
        true_fn_src = true_func.name
        false_fn_src = false_func.name

    return_vars = load_tuple(return_name_ids)

    call_template = (
        '_jst.convert_ifelse('
        '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'
    )
    call_src = call_template.format(
        pred=ast_to_source_code(pred),
        true_fn=true_fn_src,
        false_fn=false_fn_src,
        true_args=ast_to_source_code(true_args),
        false_args=ast_to_source_code(false_args),
        return_vars=ast_to_source_code(return_vars))
    convert_ifelse_call = gast.parse(call_src).body[0].value

    if not return_name_ids:
        # No variables can be returned if no assign statement in if.body.
        return gast.Expr(value=convert_ifelse_call)
    _, assign_node = create_assign_node(return_name_ids, convert_ifelse_call)
    return assign_node
예제 #15
0
    def visit_Call(self, node):
        """
        Rewrite `<castable_type>(x, ...)` (e.g. a builtin type conversion
        listed in `self._castable_type`) into
        `paddle.jit.dy2static.convert_var_dtype(x, '<castable_type>')`.
        Calls without positional arguments are left unchanged.
        """
        self.generic_visit(node)
        callee_src = ast_to_source_code(node.func).strip()
        if callee_src not in self._castable_type or not node.args:
            return node

        first_arg_src = ast_to_source_code(node.args[0]).strip()
        convert_src = "paddle.jit.dy2static.convert_var_dtype({}, '{}')".format(
            first_arg_src, callee_src)
        return gast.parse(convert_src).body[0].value
예제 #16
0
 def test_loop_vars(self):
     """
     For each function in `self.loop_funcs`, every While/For loop must report
     the expected loop variable names and created variable names.
     """
     for idx, func in enumerate(self.loop_funcs):
         root = gast.parse(inspect.getsource(func))
         visitor = NameVisitor(root)
         for candidate in gast.walk(root):
             if not isinstance(candidate, (gast.While, gast.For)):
                 continue
             loop_names, create_names = visitor.get_loop_var_names(candidate)
             self.assertEqual(loop_names, self.loop_var_names[idx])
             self.assertEqual(create_names, self.create_var_names[idx])
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
    """
    Build `paddle.jit.dy2static.choose_shape_attr_or_api(<attr>, <api or None>[, idx])`.

    `<api or None>` evaluates `api_shape_name` in the caller's globals via
    `eval_if_exist_else_none`. A numeric slice is passed as an extra argument;
    any other slice is applied to the resulting call with a Subscript.
    """
    has_slice = slice_node is not None
    numeric_slice = has_slice and slice_is_num(slice_node)

    call_args = [
        attr_shape_name,
        "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
            api_shape_name),
    ]
    if numeric_slice:
        call_args.append(ast_to_source_code(slice_node.slice).strip())

    choose_src = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
        ",".join(call_args))
    choose_call = gast.parse(choose_src).body[0].value
    if has_slice and not numeric_slice:
        return gast.Subscript(
            value=choose_call, slice=slice_node.slice, ctx=gast.Load())
    return choose_call
예제 #18
0
    def test_paddle_api(self):
        """An `if` test built on `fluid.layers.shape(x)` must be transformed."""
        source = textwrap.dedent("""
            def foo(x):
                if fluid.layers.shape(x)[0] > 16:
                    x = x + 1
                return x
        """)
        module = gast.parse(source)
        analysis = StaticAnalysisVisitor(module)
        if_test = module.body[0].body[0].test

        self.assertTrue(is_control_flow_to_transform(if_test, analysis))
예제 #19
0
    def test_shape_with_andOr(self):
        """An and/or `if` test that reads a shape variable must be transformed."""
        source = textwrap.dedent("""
            def foo(x):
                batch_size = fluid.layers.shape(x)
                if x is not None and batch_size[0] > 16 or 2 > 1:
                    x = x + 1
                return x
        """)
        module = gast.parse(source)
        analysis = StaticAnalysisVisitor(module)
        # body[1] is the `if` statement (body[0] is the batch_size assignment).
        if_test = module.body[0].body[1].test

        self.assertTrue(is_control_flow_to_transform(if_test, analysis))
예제 #20
0
    def test_with_node_var_type_map(self):
        """`x > 1` is transformed only when `x` is typed as a Tensor."""
        compare_node = gast.parse("x > 1").body[0].value

        tensor_types = {"x": {NodeVarType.TENSOR}}
        self.assertTrue(
            is_control_flow_to_transform(compare_node,
                                         var_name_to_type=tensor_types))

        ndarray_types = {"x": {NodeVarType.NUMPY_NDARRAY}}
        self.assertFalse(
            is_control_flow_to_transform(compare_node,
                                         var_name_to_type=ndarray_types))
예제 #21
0
    def get_code(self, dygraph_func):
        """
        Translate a dygraph function and return the source code of the
        resulting static function.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            str: the string code of translated static function.

        Examples:
            .. code-block:: python

                import paddle


                def func(x):
                    if paddle.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v


                prog_trans = paddle.jit.ProgramTranslator()

                code = prog_trans.get_code(func)
                print(type(code)) # <class 'str'>

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_code"

        # Parse the decorator-unwrapped, dedented source into a gast tree.
        source = textwrap.dedent(inspect.getsource(unwrap(dygraph_func)))
        dygraph_root = gast.parse(source)

        # Apply the dygraph-to-static transformations and emit source again.
        static_wrapper = DygraphToStaticAst().get_static_ast(dygraph_root)
        return ast_to_source_code(static_wrapper.node)
예제 #22
0
    def test_log_api(self):
        # Exercise the logging_utils API (for CI coverage).
        logging_utils.set_verbosity(1, True)

        logging_utils.warn("warn")
        logging_utils.error("error")

        for level in (1, 2):
            logging_utils.log(level, "log level {}".format(level))

        parsed = gast.parse("x = 3")
        logging_utils.set_code_level(1, True)
        logging_utils.log_transformed_code(1, parsed, "TestTransformer")

        all_level = logging_utils.LOG_AllTransformer
        logging_utils.set_code_level(all_level, True)
        logging_utils.log_transformed_code(all_level, parsed,
                                           "TestTransformer")
예제 #23
0
    def visit_Call(self, node):
        """
        Wrap the callee of a call as `_jst.convert_call(<callee>)`, unless
        `_no_need_convert_call` says otherwise or the callee is `pdb.set_trace`.
        The node is modified in place and returned.
        """
        self.generic_visit(node)

        if self._no_need_convert_call(node):
            return node

        callee_src = ast_to_source_code(node.func).strip()

        # NOTE(liym27): `pdb.set_trace` is never converted, even if the
        # conversion would eventually fail, so that it stays clear where the
        # breakpoint is called from.
        if PDB_SET not in callee_src:
            wrapped_src = "_jst.convert_call({})".format(callee_src)
            node.func = gast.parse(wrapped_src).body[0].value

        return node
예제 #24
0
    def _transform_slice_to_tensor_write(self, node):
        """
        Rewrite `<list>[<i>] = <value>` into
        `<list> = paddle.tensor.array_write(x=<value>, i=paddle.cast(x=_jst.to_static_variable(<i>),dtype='int64'), array=<list>)`.

        Args:
            node(gast.Assign): assignment whose first target is a subscript
                of a plain name.

        Returns:
            gast.Assign: the rewritten assignment. When the subscript is a
            slice (e.g. `a[1:3] = v`) or not a single number, `node` is
            returned unchanged. Previously these cases raised
            UnboundLocalError because `assign_node` was never set.
        """
        assert isinstance(node, gast.Assign)
        target_node = node.targets[0]

        target_name = target_node.value.id
        slice_node = target_node.slice

        # Slice assignment and non-numeric indices are not supported yet.
        if isinstance(slice_node, gast.Slice) or not slice_is_num(target_node):
            return node

        value_code = ast_to_source_code(node.value)
        i = "paddle.cast(" \
            "x=_jst.to_static_variable({})," \
            "dtype='int64')".format(ast_to_source_code(slice_node))
        assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
            .format(target_name, value_code, i, target_name)
        return gast.parse(assign_code).body[0]
예제 #25
0
def create_and_update_origin_info_map(transformed_node,
                                      static_func,
                                      is_global=True):
    """
    Creates an original-information map between the transformed static function and the original dygraph function.

    The source of `static_func` is re-parsed and walked in lockstep with
    `transformed_node` (via `ast_walk`). For every pair of nodes that both
    carry origin info, the static node's line location is mapped to the
    dygraph node's origin info. The computed entries are always merged into
    the module-level `global_origin_info_map`.

    Args:
        transformed_node(gast.AST): The AST node of transformed dygraph function with attached source information of original dygraph function.
        static_func(Callable): The static function transformed by dygraph function corresponding to transformed_node.
        is_global(bool): If True (default), return the whole
            `global_origin_info_map` after the update; otherwise return only
            the entries computed by this call.

    Returns:
        The original information map, keyed by
        `static_info.location.line_location`.
    """

    origin_info_map = {}
    static_source = inspect.getsource(static_func)
    static_node = gast.parse(static_source)
    static_node = attach_origin_info(static_node, static_func)

    for t_node, s_node in ast_walk(transformed_node, static_node):
        # Both trees are expected to have the same structure, so nodes
        # visited together must be of exactly the same type.
        assert type(t_node) == type(s_node), \
            "The node types should be the same, but received type(t_node) is {}, and type(s_node) is {}." \
                .format(type(t_node), type(s_node))
        dygraph_info = getattr(t_node, ORIGI_INFO, None)
        static_info = getattr(s_node, ORIGI_INFO, None)

        if dygraph_info is None or static_info is None:
            continue
        static_loc = static_info.location.line_location
        exist_origin_info = origin_info_map.get(static_loc)

        # When several dygraph nodes map to the same static line, an existing
        # entry is replaced only if the new dygraph node has a larger lineno
        # AND a smaller col_offset than the existing one.
        # NOTE(review): confirm this selection rule is the intended one.
        if exist_origin_info is not None:
            if exist_origin_info.location.lineno >= dygraph_info.location.lineno:
                continue
            if exist_origin_info.location.col_offset <= dygraph_info.location.col_offset:
                continue

        origin_info_map[static_loc] = dygraph_info

    global_origin_info_map.update(origin_info_map)
    if is_global:
        return global_origin_info_map

    return origin_info_map
예제 #26
0
    def test_log_transformed_code(self):
        """Transformed code is logged at code level 1 and at LOG_AllTransformer."""
        source_code = "x = 3"
        ast_code = gast.parse(source_code)

        stream = io.StringIO()
        log = self.translator_logger.logger
        stdout_handler = logging.StreamHandler(stream)
        log.addHandler(stdout_handler)

        try:
            with mock.patch.object(sys, 'stdout', stream):
                paddle.jit.set_code_level(1)
                logging_utils.log_transformed_code(1, ast_code,
                                                   "BasicApiTransformer")

                paddle.jit.set_code_level()
                logging_utils.log_transformed_code(
                    logging_utils.LOG_AllTransformer, ast_code,
                    "All Transformers")
        finally:
            # The logger is shared; detach the handler so later tests do not
            # keep writing into this test's StringIO.
            log.removeHandler(stdout_handler)

        self.assertIn(source_code, stream.getvalue())
예제 #27
0
    def visit_Call(self, node):
        """
        Convert a dygraph `paddle.grad(...)` call into
        `paddle.static.gradients(...)`.

        Positional and keyword arguments of `paddle.grad` are mapped onto the
        keyword arguments of `paddle.static.gradients`. Parameters without a
        static counterpart are dropped with a warning. The node is modified in
        place and returned.
        """
        self.generic_visit(node)
        if not is_grad_api_node(node):
            return node

        # Parameter order of paddle.grad; used to name positional arguments.
        dygraph_grad_parameters = [
            "outputs", "inputs", "grad_outputs", "retain_graph",
            "create_graph", "only_inputs", "allow_unused", "no_grad_vars"
        ]
        to_static_grad_param = {
            "outputs": "targets",
            "inputs": "inputs",
            "grad_outputs": "target_gradients",
            "no_grad_vars": "no_grad_set"
        }
        static_keywords = []

        for kw in node.keywords:
            if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param:
                # `kw.arg` is None for `**kwargs`, so format instead of
                # concatenating to keep the warning itself from raising.
                warnings.warn(
                    "paddle.grad has unsupported parameter in jit: {}, jit will discard it"
                    .format(kw.arg))
                continue
            dygraph_grad_parameters.remove(kw.arg)
            kw.arg = to_static_grad_param[kw.arg]
            static_keywords.append(kw)

        for i, arg_value in enumerate(node.args):
            arg_name = dygraph_grad_parameters[i]
            if arg_name not in to_static_grad_param:
                # Report the positional parameter's own name. The old code
                # reused the stale loop variable `kw`, which named an
                # unrelated keyword or was unbound.
                warnings.warn(
                    "paddle.grad has unsupported parameter in jit: {}, jit will discard it"
                    .format(arg_name))
                continue
            static_keywords.append(
                gast.keyword(arg=to_static_grad_param[arg_name],
                             value=arg_value))

        node.func = gast.parse('paddle.static.gradients').body[0].value
        node.keywords = static_keywords
        node.args = []
        return node
예제 #28
0
    def _convert(self, func):
        """
        Convert a dygraph function into a static function.

        Transformed AST wrappers are cached by dedented source code. Two
        functions with identical source therefore share the transformation.
        For example, `foo` in A.py and `foo` in B.py below produce the same
        source, so whichever is converted second reuses the cached
        transformed AST of the first:

            # A.py
            def foo(x, y):
                z = x + y
                return z

            # B.py
            def foo(x, y):
                z = x + y
                return z
        """
        # Note: In Python2, inspecting a decorated function directly raises
        # OSError; function.__wrapped__ holds the actual function.
        func = unwrap(func)
        source_code = func_to_source_code(func)

        # TODO(liym27):
        #  Methods of different classes may share the same source_code and
        #  would then share a cache entry. Maybe use (__class__, source_code)
        #  as the key.
        ast_cache = self._code_to_ast_caches
        if source_code not in ast_cache:
            dygraph_root = attach_origin_info(gast.parse(source_code), func)
            ast_cache[source_code] = self._dygraph_to_static.get_static_ast(
                dygraph_root)
        root_wrapper = ast_cache[source_code]

        # Get static function from AST (the file name is not needed here).
        static_func, _ = ast_to_func(root_wrapper.node, func)

        create_and_update_origin_info_map(root_wrapper.node, static_func)
        return static_func
예제 #29
0
    def test_var_env(self):
        """
        The single sub scope of each test function must report exactly the
        expected variable types from `result_var_type`.
        """
        for idx, func in enumerate(test_funcs):
            expected_types = result_var_type[idx]
            root = gast.parse(inspect.getsource(func))
            print(gast.dump(root))
            analysis = StaticAnalysisVisitor(root)
            var_env = analysis.get_var_env()

            # There must be 1 sub scope for the test function
            sub_scopes = var_env.cur_scope.sub_scopes
            self.assertEqual(1, len(sub_scopes))
            var_env.cur_scope = sub_scopes[0]

            actual_types = var_env.get_scope_var_type()
            print(actual_types)
            self.assertEqual(len(actual_types), len(expected_types))
            for var_name in actual_types:
                print("Test var name %s" % (var_name))
                self.assertIn(var_name, expected_types)
                self.assertEqual(actual_types[var_name],
                                 expected_types[var_name])
예제 #30
0
    def visit_Compare(self, node):
        """
        When the left side and every comparator of a comparison are
        `_jst.convert_var_shape...` calls, replace the whole comparison with
        `_jst.convert_shape_compare(left, 'op1', right1, 'op2', right2, ...)`.
        Python and Paddle compare shapes differently. Other comparisons are
        returned unchanged.
        """
        self.generic_visit(node)
        shape_prefix = "_jst.convert_var_shape"
        left_src = ast_to_source_code(node.left).strip()
        if not left_src.startswith(shape_prefix):
            return node

        pieces = [left_src]
        for idx, comparator in enumerate(node.comparators):
            comparator_src = ast_to_source_code(comparator).strip()
            if not comparator_src.startswith(shape_prefix):
                # Mixed comparison: leave it to the normal Python semantics.
                return node
            op_src = cmpop_node_to_str(node.ops[idx])
            pieces.append("'" + op_src + "'")
            pieces.append(comparator_src)

        compare_src = "_jst.convert_shape_compare({})".format(
            ", ".join(pieces))
        return gast.parse(compare_src).body[0].value