def setUp(self):
    # NOTE(review): as it appears in this file, the snippet below is collapsed
    # onto a single line and is not valid Python as written (e.g.
    # ``z = 1 if x > y: ...``); confirm it matches the intended multi-line
    # ``test_fn`` source.
    self.source = """ def test_fn(x, y): z = 1 if x > y: z = x * x z = z + y return z """

    def ctxs(*ctx_types):
        # One fresh context instance per expected occurrence, in order.
        return [ctx_type() for ctx_type in ctx_types]

    # Expected expression context of every occurrence of each name, listed in
    # the order the names appear in the snippet above.
    self.all_name_ids = {
        'x': ctxs(gast.Param, gast.Load, gast.Load, gast.Load),
        'y': ctxs(gast.Param, gast.Load, gast.Load),
        'z': ctxs(gast.Store, gast.Store, gast.Load, gast.Store, gast.Load),
    }
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, return_name, parent_node_of_return): i = index_in_list(stmt_list, node) if i < 0 or i >= len(stmt_list): return False if i == len(stmt_list) - 1: # No need to add, we consider this as added successfully return True if_stmt = gast.If(test=gast.UnaryOp(op=gast.Not(), operand=gast.Name( id=return_name, ctx=gast.Store(), annotation=None, type_comment=None)), body=stmt_list[i + 1:], orelse=[]) stmt_list[i + 1:] = [if_stmt] # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = False' node_str = "{} = _jst.create_bool_as_type({}, False)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) assign_false_node = gast.parse(node_str).body[0] stmt_list[i:i] = [assign_false_node] return True
def setUp(self):
    # NOTE(review): as it appears in this file, the snippet below is collapsed
    # onto a single line and is not valid Python as written (e.g.
    # ``a = 1 x = y + a``); confirm it matches the intended multi-line
    # ``test_fn`` source.
    self.source = """ def test_fn(x, y): a = 1 x = y + a if x > y: z = x * x z = z + a else: z = y * y return z """

    def ctxs(*ctx_types):
        # One fresh context instance per expected occurrence, in order.
        return [ctx_type() for ctx_type in ctx_types]

    # Expected expression context of every occurrence of each name, listed in
    # the order the names appear in the snippet above.
    self.all_name_ids = {
        'x': ctxs(gast.Param, gast.Store, gast.Load, gast.Load, gast.Load),
        'a': ctxs(gast.Store, gast.Load, gast.Load),
        'y': ctxs(gast.Param, gast.Load, gast.Load, gast.Load, gast.Load),
        'z': ctxs(gast.Store, gast.Load, gast.Store, gast.Store, gast.Load),
    }
def visit_Name(self, node):
    """Record a ``Name`` node against every loop currently being visited.

    Nodes for which ``self._is_call_func_name_node`` is true, or whose id is
    listed in ``self.blacklist_names``, are only traversed. Any other node is
    added to ``self.current_seen_vars`` and, for each loop in
    ``self.current_loop``, appended to ``self.in_loop_vars``. It is also added
    to ``self.write_in_loop`` when its context type is exactly Store, AugStore
    or Del, and to ``self.condition_vars`` while ``self.in_condition`` is set.
    """
    skipped = (self._is_call_func_name_node(node)
               or node.id in self.blacklist_names)
    if skipped:
        self.generic_visit(node)
        return

    self.current_seen_vars.add(node)
    writing_ctx_types = {
        type(gast.Store()),
        type(gast.AugStore()),
        type(gast.Del()),
    }
    # Exact type match (not isinstance); the context does not change per loop.
    is_write = type(node.ctx) in writing_ctx_types

    for enclosing_loop in self.current_loop:
        self.in_loop_vars[enclosing_loop].append(node)
        if is_write:
            self.write_in_loop[enclosing_loop].add(node)
        if self.in_condition:
            self.condition_vars[enclosing_loop].add(node)
    self.generic_visit(node)
def visit_If(self, node):
    """Collect name ids for an ``if`` statement while respecting branch scope.

    For nested ``if/else``, variables created inside a branch are not always
    visible to the enclosing node. Variables created in ``if.body`` are not
    visible in ``if.orelse``.

    Case 1:
        x = 1
        if m > 1:
            res = new_tensor
        res = res + 1   # Error, `res` is not visible here.

    Case 2:
        if x_tensor > 0:
            res = new_tensor
        else:
            res = res + 1   # Error, `res` is not visible here.

    In both cases the scope of the variables has to be managed so that the
    arguments and the returned variables are parsed correctly.
    """
    if not self._in_range or not self.end_node:
        self.generic_visit(node)
        return

    saved_name_ids = copy.deepcopy(self.name_ids)

    body_name_ids = self._visit_child(node.body)
    if not self._in_range:
        # Traversal stopped early inside `if.body`: keep what was seen so far.
        self._update_name_ids(saved_name_ids)
        return

    else_name_ids = self._visit_child(node.orelse)
    if not self._in_range:
        # Traversal stopped early inside `if.orelse`: keep what was seen so far.
        self._update_name_ids(saved_name_ids)
        return

    # Hide the branch-local vars; only the vars created in both the `if` and
    # the `else` branch are recorded (as a Store) in the outer name_ids.
    for created_name in self._find_new_name_ids(body_name_ids, else_name_ids):
        saved_name_ids[created_name].append(gast.Store())
    self.name_ids = saved_name_ids
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
                                           break_continue_name):
    """Guard the statements after ``node`` with ``if not <break_continue_name>:``.

    Everything after ``node`` in ``stmt_list`` is moved into the body of a new
    ``gast.If`` whose test is ``not <break_continue_name>``.

    Args:
        stmt_list: statement list that contains ``node``; modified in place.
        node: statement after which the remaining statements are guarded.
        break_continue_name: name of the boolean flag variable tested by the
            guard.

    Returns:
        False if ``node`` is not found in ``stmt_list``; True otherwise. When
        ``node`` is already the last statement nothing is changed and True is
        returned.
    """
    i = index_in_list(stmt_list, node)
    if i < 0 or i >= len(stmt_list):
        return False
    if i == len(stmt_list) - 1:
        # No need to add, we consider this as added successfully
        return True

    if_stmt = gast.If(
        test=gast.UnaryOp(
            op=gast.Not(),
            # The flag is *read* in the test, so the Name needs a Load
            # context; a Store context here is rejected by compile().
            operand=gast.Name(id=break_continue_name,
                              ctx=gast.Load(),
                              annotation=None,
                              type_comment=None)),
        body=stmt_list[i + 1:],
        orelse=[])
    stmt_list[i + 1:] = [if_stmt]
    return True
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name, max_return_length, parent_node_of_return): assert max_return_length >= 0, "Input illegal max_return_length" i = index_in_list(stmt_list, return_node) if i == -1: return False assign_nodes = [] # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = True' node_str = "{} = _jst.create_bool_as_type({}, True)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) assign_true_node = gast.parse(node_str).body[0] assign_nodes.append(assign_true_node) cur_func_node = self.function_def[-1] return_length = get_return_size(return_node) if return_length < max_return_length: # In this case we should append RETURN_NO_VALUE placeholder # # max_return_length must be >= 1 here because return_length will be # 0 at least. if self.return_value_name[cur_func_node] is None: self.return_value_name[cur_func_node] = unique_name.generate( RETURN_VALUE_PREFIX) no_value_names = [ unique_name.generate(RETURN_NO_VALUE_VAR_NAME) for j in range(max_return_length - return_length) ] self.return_no_value_name[cur_func_node].extend(no_value_names) # Handle tuple/non-tuple case if max_return_length == 1: assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), annotation=None, type_comment=None) ], value=gast.Name(id=no_value_names[0], ctx=gast.Load(), annotation=None, type_comment=None))) else: # max_return_length > 1 which means we should assign tuple fill_tuple = [ gast.Name(id=n, ctx=gast.Load(), annotation=None, type_comment=None) for n in no_value_names ] if return_node.value is not None: if isinstance(return_node.value, gast.Tuple): fill_tuple[:0] = return_node.value.elts else: fill_tuple.insert(0, return_node.value) assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), 
annotation=None, type_comment=None) ], value=gast.Tuple(elts=fill_tuple, ctx=gast.Load()))) else: # In this case we should NOT append RETURN_NO_VALUE placeholder if return_node.value is not None: cur_func_node = self.function_def[-1] if self.return_value_name[cur_func_node] is None: self.return_value_name[ cur_func_node] = unique_name.generate( RETURN_VALUE_PREFIX) assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), annotation=None, type_comment=None) ], value=return_node.value)) stmt_list[i:] = assign_nodes return True
def visit_FunctionDef(self, node):
    """Transform the returns of a function and add a single trailing return.

    Per-function bookkeeping is reset. ``generic_visit`` is then run
    repeatedly while ``ReturnAnalysisVisitor`` still reports more than one
    return in ``node``. If the function's maximum return length is non-zero:

    * when a return-value variable was allocated, ``return <var>`` is appended.
      The variable is initialised at the top of the body from zero-filled
      constants (a tuple of them when several values are returned);
    * the RETURN_NO_VALUE placeholder variables recorded for ``node`` are
      initialised at the very top of the body.

    Returns:
        The (mutated) ``node``.
    """
    self.function_def.append(node)
    self.return_value_name[node] = None
    self.return_name[node] = []
    self.return_no_value_name[node] = []

    self.pre_analysis = ReturnAnalysisVisitor(node)
    max_return_length = self.pre_analysis.get_func_max_return_length(node)
    # NOTE(review): this loop terminates only if each generic_visit pass
    # reduces the return count reported by the analysis — confirm.
    while self.pre_analysis.get_func_return_count(node) > 1:
        self.generic_visit(node)
        self.pre_analysis = ReturnAnalysisVisitor(node)

    if max_return_length != 0:

        def load(name_id):
            return gast.Name(id=name_id,
                             ctx=gast.Load(),
                             annotation=None,
                             type_comment=None)

        # Prepend initialization of final return and append final return statement
        value_name = self.return_value_name[node]
        if value_name is not None:
            node.body.append(gast.Return(value=load(value_name)))
            init_names = [
                unique_name.generate(RETURN_VALUE_INIT_NAME)
                for _ in range(max_return_length)
            ]
            zero_fills = [
                create_fill_constant_node(init_name, 0.0)
                for init_name in init_names
            ]
            if len(init_names) == 1:
                init_value = load(init_names[0])
            else:
                # Control flow requires some inputs or outputs to have the
                # same structure, so several values start out as a tuple.
                init_value = gast.Tuple(
                    elts=[load(init_name) for init_name in init_names],
                    ctx=gast.Load())
            init_assign = gast.Assign(targets=[
                gast.Name(id=value_name,
                          ctx=gast.Store(),
                          annotation=None,
                          type_comment=None)
            ],
                                      value=init_value)
            # Body now starts: zero fills, then the return-value init.
            node.body[:0] = zero_fills + [init_assign]

        # Prepend no value placeholders. Each one ends up ahead of the previous
        # one, i.e. in reverse order of self.return_no_value_name[node].
        placeholder_inits = [
            create_fill_constant_node(name, RETURN_NO_VALUE_MAGIC_NUM)
            for name in self.return_no_value_name[node]
        ]
        placeholder_inits.reverse()
        node.body[:0] = placeholder_inits

    self.function_def.pop()
    return node