예제 #1
0
def create_convert_ifelse_node(return_name_ids,
                               pred,
                               true_func,
                               false_func,
                               is_if_expr=False):
    """
    Build a `_jst.convert_ifelse(...)` call that replaces a Python if/else.

    The generated call has the shape::

        _jst.convert_ifelse(pred, true_fn, false_fn,
                            true_args, false_args, return_vars)

    Args:
        return_name_ids: Names of the variables produced by the branches.
            When non-empty, the call result is assigned to them via
            `create_assign_node`. When empty or None, the call becomes a bare
            expression statement.
        pred: AST node of the condition.
        true_func: For an if statement, the function node holding the true
            branch (its `name` and `args.args` are used). For an
            if-expression, the true-branch expression node.
        false_func: Same as `true_func`, for the false branch.
        is_if_expr: When True, each branch is rendered as a zero-argument
            lambda and both argument tuples are empty.

    Returns:
        The statement node that takes the place of the original if/else.
    """
    if is_if_expr:
        # An if-expression has no helper function, so each branch is wrapped
        # in a lambda that takes no arguments.
        true_fn_src = "lambda : {}".format(ast_to_source_code(true_func))
        false_fn_src = "lambda : {}".format(ast_to_source_code(false_func))
        true_arg_elts = []
        false_arg_elts = []
    else:
        true_fn_src = true_func.name
        false_fn_src = false_func.name
        true_arg_elts = true_func.args.args
        false_arg_elts = false_func.args.args

    true_args = gast.Tuple(elts=true_arg_elts, ctx=gast.Load())
    false_args = gast.Tuple(elts=false_arg_elts, ctx=gast.Load())

    # Load-context tuple naming the branch outputs (empty when there are none).
    return_vars = gast.Tuple(
        elts=[
            gast.Name(id=name_id,
                      ctx=gast.Load(),
                      annotation=None,
                      type_comment=None)
            for name_id in (return_name_ids or ())
        ],
        ctx=gast.Load())

    call_template = (
        '_jst.convert_ifelse('
        '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'
    )
    call_src = call_template.format(
        pred=ast_to_source_code(pred),
        true_fn=true_fn_src,
        false_fn=false_fn_src,
        true_args=ast_to_source_code(true_args),
        false_args=ast_to_source_code(false_args),
        return_vars=ast_to_source_code(return_vars))
    convert_ifelse_layer = gast.parse(call_src).body[0].value

    if not return_name_ids:
        # Nothing is assigned inside the branches, so there is nothing to
        # bind the call result to.
        return gast.Expr(value=convert_ifelse_layer)

    _, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer)
    return cond_node
예제 #2
0
    def create_name_nodes(name_ids):
        """
        Wrap identifiers in a Load-context `gast.Tuple` of `gast.Name` nodes.

        Args:
            name_ids: Identifier strings, in order. None or any other falsy
                value produces an empty tuple node.

        Returns:
            A `gast.Tuple` (Load context) holding one `gast.Name` per
            identifier. Each name has Load context, no annotation and no type
            comment.
        """
        return gast.Tuple(
            elts=[
                gast.Name(id=identifier,
                          ctx=gast.Load(),
                          annotation=None,
                          type_comment=None)
                for identifier in (name_ids or ())
            ],
            ctx=gast.Load())
예제 #3
0
    def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
                                     max_return_length, parent_node_of_return):
        """
        Replace `return_node` and the statements after it in `stmt_list`.

        The slice of `stmt_list` starting at `return_node` is overwritten with
        assignment statements:

        * if `parent_node_of_return` is a `gast.If`, a flag assignment
          `<return_name> = _jst.create_bool_as_type(<if test>, True)`;
        * an assignment of the returned value to the enclosing function's
          return-value variable, whose name is allocated on first use. When
          the return size is smaller than `max_return_length`, the value is
          padded with freshly generated RETURN_NO_VALUE placeholder names.
          Those names are also recorded in `self.return_no_value_name` for the
          function. A bare `return` that needs no padding gets no value
          assignment.

        Args:
            stmt_list: Statement list searched for `return_node`; mutated in
                place.
            return_node: The `gast.Return` node to replace.
            return_name: Name of the flag variable assigned in the `gast.If`
                case.
            max_return_length: Largest return size in the current function;
                must be >= 0.
            parent_node_of_return: Parent of `return_node`; only checked for
                being a `gast.If`.

        Returns:
            True if `return_node` was found and replaced, False otherwise.
        """
        assert max_return_length >= 0, "Input illegal max_return_length"
        i = index_in_list(stmt_list, return_node)
        if i == -1:
            return False

        assign_nodes = []
        # Here assume that the parent node of return is gast.If
        if isinstance(parent_node_of_return, gast.If):
            # Prepend control flow boolean nodes such as '__return@1 = True'
            node_str = "{} = _jst.create_bool_as_type({}, True)".format(
                return_name,
                ast_to_source_code(parent_node_of_return.test).strip())
            assign_nodes.append(gast.parse(node_str).body[0])

        cur_func_node = self.function_def[-1]

        def ensure_return_value_name():
            # Allocate the function's return-value variable on first use.
            if self.return_value_name[cur_func_node] is None:
                self.return_value_name[cur_func_node] = unique_name.generate(
                    RETURN_VALUE_PREFIX)

        def assign_to_return_value(value_node):
            # `<return value var> = value_node`; the name must already exist.
            return gast.Assign(
                targets=[
                    gast.Name(id=self.return_value_name[cur_func_node],
                              ctx=gast.Store(),
                              annotation=None,
                              type_comment=None)
                ],
                value=value_node)

        return_length = get_return_size(return_node)
        if return_length < max_return_length:
            # Pad with RETURN_NO_VALUE placeholders. max_return_length is >= 1
            # here because return_length is at least 0. The return-value name
            # is allocated before the placeholder names, so the order of
            # unique_name.generate calls stays the same.
            ensure_return_value_name()
            no_value_names = [
                unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
                for _ in range(max_return_length - return_length)
            ]
            self.return_no_value_name[cur_func_node].extend(no_value_names)

            if max_return_length == 1:
                # Non-tuple case: the single placeholder is the whole value.
                value_node = gast.Name(id=no_value_names[0],
                                       ctx=gast.Load(),
                                       annotation=None,
                                       type_comment=None)
            else:
                # Tuple case: real returned values first, then placeholders.
                fill_tuple = [
                    gast.Name(id=n,
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None) for n in no_value_names
                ]
                if return_node.value is not None:
                    if isinstance(return_node.value, gast.Tuple):
                        fill_tuple[:0] = return_node.value.elts
                    else:
                        fill_tuple.insert(0, return_node.value)
                value_node = gast.Tuple(elts=fill_tuple, ctx=gast.Load())
            assign_nodes.append(assign_to_return_value(value_node))
        elif return_node.value is not None:
            # No placeholder padding needed: assign the value directly.
            ensure_return_value_name()
            assign_nodes.append(assign_to_return_value(return_node.value))

        # Drop the return together with any statements after it in this list.
        stmt_list[i:] = assign_nodes
        return True
예제 #4
0
    def visit_FunctionDef(self, node):
        """
        Collapse the returns of `node` into a single trailing return.

        Registers `node` in the per-function bookkeeping. While the return
        analysis reports more than one return, it visits the body and then
        re-analyses it. If the function's maximum return length is 0, `node`
        is returned as is. Otherwise:

        * if a return-value variable was allocated for `node`, it appends
          `return <var>`. It also prepends fill-constant (0.0) nodes plus
          `<var> = <init>`, where `<init>` is a single name when the max
          return length is 1 and a tuple of names otherwise;
        * it prepends one fill-constant node with RETURN_NO_VALUE_MAGIC_NUM
          for every placeholder name recorded for `node`.

        Args:
            node: The `gast.FunctionDef` being visited.

        Returns:
            `node`, modified in place.
        """
        self.function_def.append(node)
        self.return_value_name[node] = None
        self.return_name[node] = []
        self.return_no_value_name[node] = []

        self.pre_analysis = ReturnAnalysisVisitor(node)
        max_return_length = self.pre_analysis.get_func_max_return_length(node)
        # Repeat the transformation, re-analysing after every pass, until the
        # analysis reports at most one return in `node`.
        while self.pre_analysis.get_func_return_count(node) > 1:
            self.generic_visit(node)
            self.pre_analysis = ReturnAnalysisVisitor(node)

        if max_return_length == 0:
            self.function_def.pop()
            return node

        value_name = self.return_value_name[node]
        if value_name is not None:
            # Trailing `return <value_name>`.
            node.body.append(
                gast.Return(value=gast.Name(id=value_name,
                                            ctx=gast.Load(),
                                            annotation=None,
                                            type_comment=None)))

            init_names = [
                unique_name.generate(RETURN_VALUE_INIT_NAME)
                for _ in range(max_return_length)
            ]
            fill_nodes = [
                create_fill_constant_node(init_name, 0.0)
                for init_name in init_names
            ]
            init_loads = [
                gast.Name(id=init_name,
                          ctx=gast.Load(),
                          annotation=None,
                          type_comment=None) for init_name in init_names
            ]
            # Multiple values are initialised as a tuple because control flow
            # requires some inputs or outputs to have the same structure.
            if len(init_loads) == 1:
                initial_value = init_loads[0]
            else:
                initial_value = gast.Tuple(elts=init_loads, ctx=gast.Load())
            init_assign = gast.Assign(
                targets=[
                    gast.Name(id=value_name,
                              ctx=gast.Store(),
                              annotation=None,
                              type_comment=None)
                ],
                value=initial_value)
            # Body prologue: the fill-constant nodes, then `value_name = ...`.
            node.body[:0] = fill_nodes + [init_assign]

        placeholder_nodes = [
            create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)
            for name in self.return_no_value_name[node]
        ]
        # Each placeholder goes in front of the previously registered one, so
        # they sit at the top of the body in reverse registration order.
        node.body[:0] = placeholder_nodes[::-1]

        self.function_def.pop()
        return node