def test_index_in_list(self):
    """Check ``index_in_list`` for present and absent elements of a list."""
    list_to_test = [1, 2, 3, 4, 5]
    # (element to look up, expected result); -1 is expected for the two
    # elements that are not in the list.
    expectations = ((4, 3), (1, 0), (5, 4), (0, -1), (6, -1))
    for element, expected in expectations:
        self.assertEqual(index_in_list(list_to_test, element), expected)
def visit_Return(self, node):
    """Rewrite a ``return`` statement into flag-based control flow.

    A fresh flag name is generated and appended to
    ``self.return_name[<current function>]``. The ancestors recorded in
    ``self.ancestor_nodes`` are then walked from the parent of ``node``
    outwards. At each level:

    * if the child on the path is found in the ancestor's ``body`` (or,
      failing that, its ``orelse``), the return itself is replaced via
      ``_replace_return_in_stmt_list`` (only at the level where the child
      is ``node``) and the statements after the child are guarded via
      ``_replace_after_node_to_if_in_stmt_list``;
    * a ``gast.While`` ancestor gets ``and not <flag>`` added to its test;
    * a ``gast.For`` ancestor is handed to ``ForToWhileTransformer``
      together with the ``not <flag>`` expression.

    The walk stops once the current function definition has been handled.

    Args:
        node: the ``return`` node being visited.

    Returns:
        None (there is no explicit return statement).
    """
    func_node = self.function_def[-1]
    return_name = unique_name.generate(RETURN_PREFIX)
    self.return_name[func_node].append(return_name)
    max_return_length = self.pre_analysis.get_func_max_return_length(
        func_node)
    parent_node_of_return = self.ancestor_nodes[-2]

    def make_not_returned():
        # Build a new `not <return_name>` expression on every call so that
        # each loop receives its own node objects.
        flag = gast.Name(id=return_name,
                         ctx=gast.Load(),
                         annotation=None,
                         type_comment=None)
        return gast.UnaryOp(op=gast.Not(), operand=flag)

    # Pair every ancestor with its child on the path, innermost pair first.
    for ancestor_index in range(len(self.ancestor_nodes) - 2, -1, -1):
        ancestor = self.ancestor_nodes[ancestor_index]
        cur_node = self.ancestor_nodes[ancestor_index + 1]

        # `body` takes precedence; `orelse` is only consulted when the
        # child was not found in `body`.
        for field in ("body", "orelse"):
            if not hasattr(ancestor, field):
                continue
            stmts = getattr(ancestor, field)
            if index_in_list(stmts, cur_node) == -1:
                continue
            if cur_node == node:
                self._replace_return_in_stmt_list(stmts, cur_node,
                                                  return_name,
                                                  max_return_length,
                                                  parent_node_of_return)
            self._replace_after_node_to_if_in_stmt_list(
                stmts, cur_node, return_name, parent_node_of_return)
            break

        # Return inside a while loop: add `not return_name` to the loop test.
        if isinstance(ancestor, gast.While):
            ancestor.test = gast.BoolOp(
                op=gast.And(), values=[ancestor.test, make_not_returned()])
            continue

        # Return inside a for loop: pass `not return_name` to the
        # for-to-while conversion.
        if isinstance(ancestor, gast.For):
            cond_var_node = make_not_returned()
            parent_node = self.ancestor_nodes[ancestor_index - 1]
            converter = ForToWhileTransformer(parent_node, ancestor,
                                              cond_var_node)
            new_stmts = converter.transform()
            # The last returned statement takes the For's place among the
            # recorded ancestors (presumably the generated while loop).
            self.ancestor_nodes[ancestor_index] = new_stmts[-1]

        if ancestor == func_node:
            break
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, return_name, parent_node_of_return): i = index_in_list(stmt_list, node) if i < 0 or i >= len(stmt_list): return False if i == len(stmt_list) - 1: # No need to add, we consider this as added successfully return True if_stmt = gast.If(test=gast.UnaryOp(op=gast.Not(), operand=gast.Name( id=return_name, ctx=gast.Store(), annotation=None, type_comment=None)), body=stmt_list[i + 1:], orelse=[]) stmt_list[i + 1:] = [if_stmt] # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = False' node_str = "{} = _jst.create_bool_as_type({}, False)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) assign_false_node = gast.parse(node_str).body[0] stmt_list[i:i] = [assign_false_node] return True
def _replace_break_continue_in_stmt_list(
        self, stmt_list, break_continue_node, break_continue_name):
    """Replace a break/continue node and everything after it with one node.

    The node located in ``stmt_list`` and all statements following it are
    removed, and the result of
    ``create_fill_constant_node(break_continue_name, True)`` is put in
    their place (presumably an assignment of True to that name).
    ``stmt_list`` is modified in place.

    Returns:
        True if ``break_continue_node`` was found in ``stmt_list``, else
        False (the list is then left untouched).
    """
    position = index_in_list(stmt_list, break_continue_node)
    found = position != -1
    if found:
        # Build the replacement before touching the list.
        flag_node = create_fill_constant_node(break_continue_name, True)
        del stmt_list[position:]
        stmt_list.append(flag_node)
    return found
def transform(self):
    """Replace ``self.loop_node`` inside ``self.parent_node`` in place.

    The parent's ``body`` is searched first, then its ``orelse``. In the
    first list that contains the loop node, that single node is replaced
    by the statements returned from ``self.get_for_stmt_nodes``.

    Returns:
        The list of statements returned by ``self.get_for_stmt_nodes``.

    Raises:
        ValueError: if neither ``body`` nor ``orelse`` of the parent
            contains the loop node.
    """
    for field in ('body', 'orelse'):
        if not hasattr(self.parent_node, field):
            continue
        stmt_list = getattr(self.parent_node, field)
        i = index_in_list(stmt_list, self.loop_node)
        if i != -1:
            new_stmts = self.get_for_stmt_nodes(stmt_list[i])
            # Splice the new statements in where the loop node used to be.
            stmt_list[i:i + 1] = new_stmts
            return new_stmts
    raise ValueError(
        "parent_node doesn't contain the loop_node in ForToWhileTransformer")
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node, return_name):
    """Guard the statements that follow ``node`` with ``if not <return_name>``.

    Everything after ``node`` in ``stmt_list`` is moved into the body of a
    new ``gast.If`` whose test is ``not <return_name>``. ``stmt_list`` is
    modified in place.

    Args:
        stmt_list: statement list expected to contain ``node``.
        node: the statement after which the guard is placed.
        return_name: name of the return flag variable.

    Returns:
        False if ``node`` is not located in ``stmt_list``; True otherwise,
        including when ``node`` is the last statement and nothing needs to
        be guarded.
    """
    i = index_in_list(stmt_list, node)
    if i < 0 or i >= len(stmt_list):
        return False
    if i == len(stmt_list) - 1:
        # No need to add, we consider this as added successfully
        return True

    # The flag is only read in the test, so its Name node carries a Load
    # context (a Store context would mark the read as an assignment).
    not_returned = gast.UnaryOp(
        op=gast.Not(),
        operand=gast.Name(
            id=return_name,
            ctx=gast.Load(),
            annotation=None,
            type_comment=None))
    if_stmt = gast.If(test=not_returned, body=stmt_list[i + 1:], orelse=[])
    stmt_list[i + 1:] = [if_stmt]
    return True
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name, max_return_length, parent_node_of_return): assert max_return_length >= 0, "Input illegal max_return_length" i = index_in_list(stmt_list, return_node) if i == -1: return False assign_nodes = [] # Here assume that the parent node of return is gast.If if isinstance(parent_node_of_return, gast.If): # Prepend control flow boolean nodes such as '__return@1 = True' node_str = "{} = _jst.create_bool_as_type({}, True)".format( return_name, ast_to_source_code(parent_node_of_return.test).strip()) assign_true_node = gast.parse(node_str).body[0] assign_nodes.append(assign_true_node) cur_func_node = self.function_def[-1] return_length = get_return_size(return_node) if return_length < max_return_length: # In this case we should append RETURN_NO_VALUE placeholder # # max_return_length must be >= 1 here because return_length will be # 0 at least. if self.return_value_name[cur_func_node] is None: self.return_value_name[cur_func_node] = unique_name.generate( RETURN_VALUE_PREFIX) no_value_names = [ unique_name.generate(RETURN_NO_VALUE_VAR_NAME) for j in range(max_return_length - return_length) ] self.return_no_value_name[cur_func_node].extend(no_value_names) # Handle tuple/non-tuple case if max_return_length == 1: assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), annotation=None, type_comment=None) ], value=gast.Name(id=no_value_names[0], ctx=gast.Load(), annotation=None, type_comment=None))) else: # max_return_length > 1 which means we should assign tuple fill_tuple = [ gast.Name(id=n, ctx=gast.Load(), annotation=None, type_comment=None) for n in no_value_names ] if return_node.value is not None: if isinstance(return_node.value, gast.Tuple): fill_tuple[:0] = return_node.value.elts else: fill_tuple.insert(0, return_node.value) assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), 
annotation=None, type_comment=None) ], value=gast.Tuple(elts=fill_tuple, ctx=gast.Load()))) else: # In this case we should NOT append RETURN_NO_VALUE placeholder if return_node.value is not None: cur_func_node = self.function_def[-1] if self.return_value_name[cur_func_node] is None: self.return_value_name[ cur_func_node] = unique_name.generate( RETURN_VALUE_PREFIX) assign_nodes.append( gast.Assign(targets=[ gast.Name(id=self.return_value_name[cur_func_node], ctx=gast.Store(), annotation=None, type_comment=None) ], value=return_node.value)) stmt_list[i:] = assign_nodes return True
def _add_stmt_into_list_before_node(self, stmt_list, node, stmt_node):
    """Insert ``stmt_node`` into ``stmt_list`` immediately before ``node``.

    Returns:
        True if ``node`` was found and ``stmt_node`` inserted (in place),
        False if ``node`` is not in ``stmt_list``.
    """
    position = index_in_list(stmt_list, node)
    if position != -1:
        stmt_list[position:position] = [stmt_node]
        return True
    return False
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
                                 max_return_length):
    """Replace ``return_node`` in ``stmt_list`` with assignment statements.

    ``return_node`` and every statement after it are removed from
    ``stmt_list`` (in place) and replaced by:

    * the node from ``create_fill_constant_node(return_name, True)``
      (presumably setting the return flag to True);
    * an assignment of the returned value to the function's return-value
      variable. If the return yields fewer values than
      ``max_return_length``, the missing positions are filled with newly
      generated placeholder names, which are also recorded in
      ``self.return_no_value_name``.

    Args:
        stmt_list: statement list expected to contain ``return_node``.
        return_node: the return statement to replace.
        return_name: name of the return flag variable.
        max_return_length: largest number of values returned by the
            current function; must be >= 0.

    Returns:
        False if ``return_node`` is not found in ``stmt_list``, else True.
    """
    assert max_return_length >= 0, "Input illegal max_return_length"
    position = index_in_list(stmt_list, return_node)
    if position == -1:
        return False

    replacement = [create_fill_constant_node(return_name, True)]
    func_node = self.function_def[-1]
    return_length = get_return_size(return_node)
    returned = return_node.value

    def ensure_value_name():
        # Lazily create the variable name that holds the return value.
        if self.return_value_name[func_node] is None:
            self.return_value_name[func_node] = unique_name.generate(
                RETURN_VALUE_PREFIX)

    def load_name(identifier):
        return gast.Name(
            id=identifier,
            ctx=gast.Load(),
            annotation=None,
            type_comment=None)

    def assign_to_value_name(value):
        target = gast.Name(
            id=self.return_value_name[func_node],
            ctx=gast.Store(),
            annotation=None,
            type_comment=None)
        return gast.Assign(targets=[target], value=value)

    if return_length < max_return_length:
        # Fewer values than the widest return: pad with RETURN_NO_VALUE
        # placeholders. max_return_length is >= 1 here because
        # return_length is at least 0.
        ensure_value_name()
        placeholders = [
            unique_name.generate(RETURN_NO_VALUE_VAR_NAME)
            for _ in range(max_return_length - return_length)
        ]
        self.return_no_value_name[func_node].extend(placeholders)

        if max_return_length == 1:
            # Non-tuple case: the single slot is a placeholder.
            value = load_name(placeholders[0])
        else:
            # Tuple case: the real values come first, then the placeholders.
            elements = []
            if returned is not None:
                if isinstance(returned, gast.Tuple):
                    elements.extend(returned.elts)
                else:
                    elements.append(returned)
            elements.extend(load_name(name) for name in placeholders)
            value = gast.Tuple(elts=elements, ctx=gast.Load())
        replacement.append(assign_to_value_name(value))
    elif returned is not None:
        # No padding needed: assign the returned expression as it is.
        ensure_value_name()
        replacement.append(assign_to_value_name(returned))

    stmt_list[position:] = replacement
    return True