def create_convert_ifelse_node(return_name_ids,
                               pred,
                               true_func,
                               false_func,
                               is_if_expr=False):
    """
    Create `paddle.jit.dy2static.convert_ifelse(
            pred, true_fn, false_fn, true_args, false_args, return_vars)`
    to replace original `python if/else` statement.

    Args:
        return_name_ids: names of the variables produced by the if/else.
            When non-empty, the generated call is assigned to them via
            ``create_assign_node``; when empty (or None) the call is emitted
            as a bare expression statement.
        pred: AST node of the condition; its source is spliced in as the
            first argument.
        true_func: when ``is_if_expr`` is True, the AST node of the "true"
            branch expression (wrapped into a zero-argument lambda).
            Otherwise an object exposing ``.name`` and ``.args.args``
            (presumably a ``gast.FunctionDef`` — confirm with callers).
        false_func: same as ``true_func`` for the "false" branch.
        is_if_expr: True when replacing a conditional expression
            (``a if cond else b``) rather than an ``if`` statement.

    Returns:
        The second element returned by ``create_assign_node`` when
        ``return_name_ids`` is non-empty, otherwise a ``gast.Expr`` wrapping
        the ``_jst.convert_ifelse(...)`` call.
    """

    def create_name_nodes(name_ids):
        # Builds a loading tuple `(n1, n2, ...)`; falsy input yields `()`.
        if not name_ids:
            return gast.Tuple(elts=[], ctx=gast.Load())

        gast_names = [
            gast.Name(id=name_id,
                      ctx=gast.Load(),
                      annotation=None,
                      type_comment=None) for name_id in name_ids
        ]
        return gast.Tuple(elts=gast_names, ctx=gast.Load())

    if is_if_expr:
        true_args = gast.Tuple(elts=[], ctx=gast.Load())
        false_args = gast.Tuple(elts=[], ctx=gast.Load())
        # Parenthesize the branch source: `lambda` binds loosest of all
        # operators, so an unparenthesized body such as `a, b` would end the
        # lambda at the first comma and leak `b` into the next positional
        # argument of the generated convert_ifelse call.
        true_func_source = "lambda : ({})".format(
            ast_to_source_code(true_func).strip())
        false_func_source = "lambda : ({})".format(
            ast_to_source_code(false_func).strip())
    else:
        true_args = gast.Tuple(elts=true_func.args.args, ctx=gast.Load())
        false_args = gast.Tuple(elts=false_func.args.args, ctx=gast.Load())
        true_func_source = true_func.name
        false_func_source = false_func.name

    return_vars = create_name_nodes(return_name_ids)

    convert_ifelse_layer = gast.parse(
        '_jst.convert_ifelse('
        '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'
        .format(pred=ast_to_source_code(pred),
                true_fn=true_func_source,
                false_fn=false_func_source,
                true_args=ast_to_source_code(true_args),
                false_args=ast_to_source_code(false_args),
                return_vars=ast_to_source_code(return_vars))).body[0].value

    if return_name_ids:
        _, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer)
    else:
        # No variables can be returned if no assign statement in if.body.
        cond_node = gast.Expr(value=convert_ifelse_layer)

    return cond_node
def create_name_nodes(name_ids):
    """Build a loading ``gast.Tuple`` of ``gast.Name`` nodes.

    Args:
        name_ids: identifiers to wrap, in order; may be empty or None.

    Returns:
        A ``gast.Tuple`` with ``Load`` context whose elements are
        ``gast.Name`` nodes (``Load`` context, no annotation/type comment)
        for each entry of ``name_ids``. A falsy ``name_ids`` yields a tuple
        node with no elements.
    """
    elements = []
    if name_ids:
        elements = [
            gast.Name(id=identifier,
                      ctx=gast.Load(),
                      annotation=None,
                      type_comment=None) for identifier in name_ids
        ]
    return gast.Tuple(elts=elements, ctx=gast.Load())
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name, max_return_length, parent_node_of_return): assert max_return_length >= 0, "Input illegal max_return_length" i = index_in_list(stmt_list, return_node) if i == -1: return False assign_nodes = [] # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = True' node_str = "{} = _jst.create_bool_as_type({}, True)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) assign_true_node = gast.parse(node_str).body[0] assign_nodes.append(assign_true_node) cur_func_node = self.function_def[-1] return_length = get_return_size(return_node) if return_length < max_return_length: # In this case we should append RETURN_NO_VALUE placeholder # # max_return_length must be >= 1 here because return_length will be # 0 at least. if self.return_value_name[cur_func_node] is None: self.return_value_name[cur_func_node] = unique_name.generate( RETURN_VALUE_PREFIX) no_value_names = [ unique_name.generate(RETURN_NO_VALUE_VAR_NAME) for j in range(max_return_length - return_length) ] self.return_no_value_name[cur_func_node].extend(no_value_names) # Handle tuple/non-tuple case if max_return_length == 1: assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), annotation=None, type_comment=None) ], value=gast.Name(id=no_value_names[0], ctx=gast.Load(), annotation=None, type_comment=None))) else: # max_return_length > 1 which means we should assign tuple fill_tuple = [ gast.Name(id=n, ctx=gast.Load(), annotation=None, type_comment=None) for n in no_value_names ] if return_node.value is not None: if isinstance(return_node.value, gast.Tuple): fill_tuple[:0] = return_node.value.elts else: fill_tuple.insert(0, return_node.value) assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), 
annotation=None, type_comment=None) ], value=gast.Tuple(elts=fill_tuple, ctx=gast.Load()))) else: # In this case we should NOT append RETURN_NO_VALUE placeholder if return_node.value is not None: cur_func_node = self.function_def[-1] if self.return_value_name[cur_func_node] is None: self.return_value_name[ cur_func_node] = unique_name.generate( RETURN_VALUE_PREFIX) assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), annotation=None, type_comment=None) ], value=return_node.value)) stmt_list[i:] = assign_nodes return True
def visit_FunctionDef(self, node):
    """Transform ``node`` until it has at most one return, then finalize it.

    ``node`` is pushed on ``self.function_def`` for the duration of the
    visit, and its per-function bookkeeping entries are reset. The body is
    re-visited while ``ReturnAnalysisVisitor`` still reports more than one
    return. If the pre-transform maximum return length is non-zero and a
    return-value variable was allocated during the rewrite, then:

    * ``return <value_name>`` is appended to the body, and
    * the body is prefixed with fill-constant (0.0) initializations, one
      per return slot, followed by ``<value_name> = <init>`` (a tuple of
      inits when there is more than one slot).

    Finally, every recorded RETURN_NO_VALUE placeholder is initialized with
    RETURN_NO_VALUE_MAGIC_NUM at the very front of the body. The last
    recorded placeholder ends up first.

    Returns:
        ``node`` itself (modified in place).
    """
    self.function_def.append(node)
    self.return_value_name[node] = None
    self.return_name[node] = []
    self.return_no_value_name[node] = []

    self.pre_analysis = ReturnAnalysisVisitor(node)
    max_return_length = self.pre_analysis.get_func_max_return_length(node)
    # Keep rewriting until at most one return statement remains.
    while self.pre_analysis.get_func_return_count(node) > 1:
        self.generic_visit(node)
        self.pre_analysis = ReturnAnalysisVisitor(node)

    if max_return_length != 0:
        # Prepend initialization of final return and append final return statement
        value_name = self.return_value_name[node]
        if value_name is not None:
            node.body.append(
                gast.Return(value=gast.Name(id=value_name,
                                            ctx=gast.Load(),
                                            annotation=None,
                                            type_comment=None)))
            init_names = [
                unique_name.generate(RETURN_VALUE_INIT_NAME)
                for _ in range(max_return_length)
            ]
            zero_inits = [
                create_fill_constant_node(init_name, 0.0)
                for init_name in init_names
            ]
            init_loads = [
                gast.Name(id=init_name,
                          ctx=gast.Load(),
                          annotation=None,
                          type_comment=None) for init_name in init_names
            ]
            # A tuple is required for several slots because control flow
            # needs inputs/outputs with the same structure.
            if len(init_loads) == 1:
                initial_value = init_loads[0]
            else:
                initial_value = gast.Tuple(elts=init_loads, ctx=gast.Load())
            assign_initial = gast.Assign(targets=[
                gast.Name(id=value_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)
            ],
                                         value=initial_value)
            node.body[:0] = zero_inits + [assign_initial]

        # Prepend no value placeholders (reverse order, as repeated
        # insertion at index 0 would produce).
        placeholder_inits = [
            create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)
            for name in self.return_no_value_name[node]
        ]
        node.body[:0] = placeholder_inits[::-1]

    self.function_def.pop()
    return node