IN = [Leaf(token.NAME, "in", prefix=" ")] NOT_IN = NOT + IN IS = [Leaf(token.NAME, "is", prefix=" ")] IS_NOT = IS + NOT NONE = Leaf(token.NAME, "None", prefix=" ") GREATER = [Leaf(token.GREATER, ">", prefix=" ")] GREATER_EQUAL = [Leaf(token.GREATEREQUAL, ">=", prefix=" ")] LESS = [Leaf(token.LESS, "<", prefix=" ")] LESS_EQUAL = [Leaf(token.LESSEQUAL, "<=", prefix=" ")] TRUE = Name("True") FALSE = Name("False") PC = PatternCompiler() IN_PATTERN = PC.compile_pattern("comparison< a=any 'in' b=any >") NOTIN_PATTERN = PC.compile_pattern("comparison< a=any comp_op<'not' 'in'> b=any >") NUMBER_PATTERN = PC.compile_pattern("NUMBER | factor< ('+' | '-') NUMBER >") NOPAREN_PATTERN = PC.compile_pattern("power | atom") def make_operand(node): """Convert a node into something we can put in a statement. Adds parentheses if needed. """ if isinstance(node, Leaf) or NOPAREN_PATTERN.match(node): # No parentheses required in simple stuff result = [node.clone()] else: # Parentheses required in complex statements (eg. assertEqual(x + y, 17)) result = [LParen(), node.clone(), RParen()]
IN = [Leaf(token.NAME, "in", prefix=" ")] NOT_IN = NOT + IN IS = [Leaf(token.NAME, "is", prefix=" ")] IS_NOT = IS + NOT NONE = Leaf(token.NAME, "None", prefix=" ") GREATER = [Leaf(token.GREATER, ">", prefix=" ")] GREATER_EQUAL = [Leaf(token.GREATEREQUAL, ">=", prefix=" ")] LESS = [Leaf(token.LESS, "<", prefix=" ")] LESS_EQUAL = [Leaf(token.LESSEQUAL, "<=", prefix=" ")] TRUE = Name("True") FALSE = Name("False") PC = PatternCompiler() IN_PATTERN = PC.compile_pattern("comparison< a=any 'in' b=any >") NOTIN_PATTERN = PC.compile_pattern( "comparison< a=any comp_op<'not' 'in'> b=any >") NUMBER_PATTERN = PC.compile_pattern("NUMBER | factor< ('+' | '-') NUMBER >") NOPAREN_PATTERN = PC.compile_pattern("power | atom") def make_operand(node): """Convert a node into something we can put in a statement. Adds parentheses if needed. """ if isinstance(node, Leaf) or NOPAREN_PATTERN.match(node): # No parentheses required in simple stuff result = [node.clone()] else:
def _split_source_lines(text):
    """Split matched source text into lines, breaking on "\\n" only.

    ``str.splitlines()`` also breaks on form feeds (which appear in real
    Python source) and on other Unicode separators that the tokenizer does
    not treat as line breaks; that would make ``--print-lineno`` numbering
    drift.  Otherwise the result mirrors ``splitlines()``: a trailing newline
    does not produce an extra empty entry, and the "\\r" of a CRLF pair is
    dropped.
    """
    pieces = text.split("\n")
    if pieces[-1] == "":
        # Text was empty or ended with a newline; splitlines() yields no
        # trailing empty line in that case, so neither do we.
        pieces.pop()
    return [piece[:-1] if piece.endswith("\r") else piece for piece in pieces]


def _first_lineno(node):
    """Return the line number on which ``str(node)`` starts.

    ``str(node)`` begins with the prefix (whitespace and comments) of the
    node's leftmost leaf, and that prefix ends on the line where the leaf's
    own token starts.  Subtracting the newlines contained in the prefix thus
    gives the line of the first character of the printed text.
    """
    leaf = node
    while not isinstance(leaf, pytree.Leaf):
        leaf = leaf.children[0]
    return leaf.get_lineno() - leaf.prefix.count("\n")


def main():
    """Find and print every node of a source snippet that matches a pattern.

    Exactly one pattern option (``-pf``/``-ps``) and one source option
    (``-sf``/``-ss``) must be given.  The source is parsed into a concrete
    syntax tree and the pattern is compiled with ``PatternCompiler``.  Every
    node visited by ``tree.post_order()`` that matches is printed, followed
    by a separator line.  ``--print-lineno`` prefixes each printed line with
    its line number; ``--print-results`` also prints the match-results dict.
    """
    parser = argparse.ArgumentParser()
    g1 = parser.add_mutually_exclusive_group(required=True)
    g1.add_argument("-pf", "--pattern-file", dest="pattern_file", type=str,
                    help='Read pattern from the specified file')
    g1.add_argument("-ps", "--pattern-string", dest="pattern_string", type=str,
                    help='A pattern string')
    g2 = parser.add_mutually_exclusive_group(required=True)
    g2.add_argument("-sf", "--source-file", dest="source_file", type=str,
                    help="Read code snippet from the specified file")
    g2.add_argument("-ss", "--source-string", dest="source_string", type=str,
                    help="A code snippet string")
    parser.add_argument("--print-results", dest="print_results",
                        action='store_true', default=False,
                        help="Print match results")
    parser.add_argument("--print-lineno", dest="print_lineno",
                        action='store_true', default=False,
                        help="Print match code with line number")
    # Parse command line arguments
    args = parser.parse_args()

    # Parse the source snippet into a concrete syntax tree.  Test against
    # None rather than truthiness: an empty "-sf ''" must surface as a bad
    # path instead of falling through to the unset string option, where
    # None + "\n" would raise a confusing TypeError.
    driver_ = driver.Driver(python_grammar, convert=pytree.convert)
    if args.source_file is not None:
        tree = driver_.parse_file(args.source_file)
    else:
        tree = driver_.parse_stream(StringIO(args.source_string + "\n"))

    # Compile the pattern (same None test as above, for "-pf ''").
    if args.pattern_file is not None:
        with open(args.pattern_file, 'r') as f:
            pattern_text = f.read()
    else:
        pattern_text = args.pattern_string
    pattern = PatternCompiler().compile_pattern(pattern_text)

    for node in tree.post_order():
        results = {'node': node}
        if not pattern.match(node, results):
            continue
        match_node = results['node']
        src_lines = _split_source_lines(str(match_node))
        if args.print_lineno:
            # Number forward from the first line of the text.  Counting back
            # from the rightmost leaf is wrong whenever that leaf does not sit
            # on the last printed line: the empty DEDENT leaf closing a
            # compound statement is positioned at the *next* statement, and a
            # multi-line string leaf reports the line it starts on.
            first = _first_lineno(match_node)
            src_lines = [str(lineno) + ' ' + line
                         for lineno, line in enumerate(src_lines, start=first)]
        for line in src_lines:
            print(line)
        if args.print_results:
            print(results)
        print('-' * 20)