class FixStripMainCheck(DelayBindBaseFix):
    """Remove an ``if __name__ == '__main__':`` guard that only calls main().

    The guard is replaced by a blank line only when it has no ``elif`` or
    ``else`` clause and its indented body consists of exactly one statement,
    a bare ``main()`` call.  Every other guard is left untouched, so no code
    other than that single call is ever removed.
    """

    # Both quoting styles of the '__main__' literal are accepted.
    PATTERN = """
    if_stmt< 'if' comparison< '__name__' '==' '\\'__main__\\'' > ':' any* >
    |
    if_stmt< 'if' comparison< '__name__' '==' '"__main__"' > ':' any* >
    """

    # A call to main() with no arguments.
    PATTERN_MAIN = PatternCompiler().compile_pattern(
        "power< 'main' trailer< '(' ')' > >")

    def transform(self, node, result):
        """Return a BlankLine to replace *node*, or None to leave it as is."""
        # A guard without elif/else has exactly four children:
        # 'if', the comparison, ':' and the body.  Stripping a guard that has
        # further clauses would silently delete those branches.
        if len(node.children) != 4:
            return None
        suite = node.children[3]
        # The one-line form (``if ...: main()``) has a simple_stmt body
        # instead of a suite; it is not handled.
        if suite.type != python_symbols.suite:
            return None
        # A suite's children are NEWLINE, INDENT, the statements and DEDENT.
        # Leaves have an empty ``children`` tuple, so this keeps every
        # statement node, compound ones (if/for/with/...) included.  Counting
        # only simple_stmt nodes would let a trailing compound statement be
        # stripped together with the main() call.
        statements = [child for child in suite.children if child.children]
        if len(statements) != 1:
            return None
        statement = statements[0]
        if statement.type != python_symbols.simple_stmt:
            return None
        # A simple_stmt holds its ';'-separated small statements plus a
        # trailing NEWLINE.  Two children means exactly one small statement,
        # so ``main(); cleanup()`` is not stripped.
        if len(statement.children) != 2:
            return None
        if not self.PATTERN_MAIN.match(statement.children[0], {}):
            return None
        return BlankLine()
def match(self, node):
    """Match *node* against the import pattern, then the learned usage patterns.

    When ``self.pattern`` matches and binds one of ``constantname``,
    ``importname`` or ``modulename``, a pattern for the name under which the
    constant is now reachable is compiled and appended to
    ``self.usage_patterns``, and the match results are returned.

    Otherwise each pattern in ``self.usage_patterns`` is tried in order, and
    the results of the first one that matches are returned.  Returns ``None``
    when nothing matches.
    """
    results = {"node": node}
    usage_source = None
    if self.pattern.match(node, results):
        if 'constantname' in results:
            # "from import as": usages are the bare alias name.
            usage_source = "constant='%s'" % results['constantname'].value
        elif 'importname' in results:
            # "from import" without "as": usages are the standard name.
            usage_source = "constant='CONSTANT'"
        elif 'modulename' in results:
            # "import as": usages are attribute accesses on the alias.
            usage_source = (
                "power< '%s' trailer< '.' attribute='CONSTANT' > >"
                % results['modulename'].value)
    if usage_source is not None:
        self.usage_patterns.append(
            PatternCompiler().compile_pattern(usage_source))
        return results
    # Not a recognized import binding: try the usage patterns learned so far.
    for usage in self.usage_patterns:
        if usage.match(node, results):
            return results
    return None
# Comparison-operator leaves used when building new comparison nodes.
# Each operator is a list of leaves so that multi-word operators can be
# built by concatenation (NOT_IN, IS_NOT).  NOT is defined earlier in the
# module (presumably also a list, since it is concatenated with IN / IS).
# Every leaf carries a single-space prefix.
NOTEQUALS = [Leaf(token.NOTEQUAL, "!=", prefix=" ")]
IN = [Leaf(token.NAME, "in", prefix=" ")]
NOT_IN = NOT + IN
IS = [Leaf(token.NAME, "is", prefix=" ")]
IS_NOT = IS + NOT
# Unlike the operator lists, NONE is a single Leaf, not a list.
NONE = Leaf(token.NAME, "None", prefix=" ")
GREATER = [Leaf(token.GREATER, ">", prefix=" ")]
GREATER_EQUAL = [Leaf(token.GREATEREQUAL, ">=", prefix=" ")]
LESS = [Leaf(token.LESS, "<", prefix=" ")]
LESS_EQUAL = [Leaf(token.LESSEQUAL, "<=", prefix=" ")]
TRUE = Name("True")
FALSE = Name("False")
# NOTE(review): the leaves above are single shared module-level objects;
# presumably callers clone them before inserting them into a tree — confirm.

# One compiler instance shared by the module-level patterns below.
PC = PatternCompiler()
# `a in b`, binding the operands as "a" and "b".
IN_PATTERN = PC.compile_pattern("comparison< a=any 'in' b=any >")
# `a not in b`, binding the operands as "a" and "b".
NOTIN_PATTERN = PC.compile_pattern(
    "comparison< a=any comp_op<'not' 'in'> b=any >")
# A NUMBER literal, optionally preceded by a unary '+' or '-'.
NUMBER_PATTERN = PC.compile_pattern("NUMBER | factor< ('+' | '-') NUMBER >")
# power / atom nodes, which make_operand treats as not needing parentheses.
NOPAREN_PATTERN = PC.compile_pattern("power | atom")


def make_operand(node):
    """Convert a node into something that can be placed in a statement.

    Adds parentheses if needed.  Leaves and power/atom nodes are cloned
    without parentheses.
    """
    if isinstance(node, Leaf) or NOPAREN_PATTERN.match(node):
        # No parentheses required in simple stuff
        result = [node.clone()]
    # NOTE(review): as written, ``result`` is assigned only for leaves and
    # power/atom nodes and there is no return statement, so the function
    # returns None; the parenthesizing branch and the return appear to be
    # missing here — confirm.
def _first_line_number(node):
    """Return the 1-based line number on which ``str(node)`` starts.

    ``str(node)`` begins with the node's prefix (the whitespace and comments
    preceding its first token), and the prefix has no line number of its own.
    The first leaf's line number is the line on which that leaf's value
    starts, so subtracting the newlines contained in the prefix gives the
    line of the first printed character.

    Returns None when ``get_lineno()`` finds no leaf to take a number from.
    """
    lineno = node.get_lineno()
    if lineno is None:
        return None
    return lineno - node.prefix.count("\n")


def main():
    """Command-line entry point: print every node of a snippet matching a pattern.

    The pattern comes from ``--pattern-file`` or ``--pattern-string``, and the
    code comes from ``--source-file`` or ``--source-string``.  Each pair is
    mutually exclusive and one of each is required.  The parsed tree is walked
    in post-order.  For every node the pattern matches, the source text of the
    ``'node'`` entry of the match results is printed (this is the visited node
    unless the pattern rebinds that name).  With ``--print-lineno`` each line
    is prefixed by its line number, and with ``--print-results`` the results
    dict is printed too.  A separator line follows each match.
    """
    parser = argparse.ArgumentParser()
    g1 = parser.add_mutually_exclusive_group(required=True)
    g1.add_argument("-pf", "--pattern-file", dest="pattern_file", type=str,
                    help='Read pattern from the specified file')
    g1.add_argument("-ps", "--pattern-string", dest="pattern_string",
                    type=str, help='A pattern string')
    g2 = parser.add_mutually_exclusive_group(required=True)
    g2.add_argument("-sf", "--source-file", dest="source_file", type=str,
                    help="Read code snippet from the specified file")
    g2.add_argument("-ss", "--source-string", dest="source_string",
                    type=str, help="A code snippet string")
    parser.add_argument("--print-results", dest="print_results",
                        action='store_true', default=False,
                        help="Print match results")
    parser.add_argument("--print-lineno", dest="print_lineno",
                        action='store_true', default=False,
                        help="Print match code with line number")
    args = parser.parse_args()

    # Parse the source snippet into a concrete syntax tree.
    driver_ = driver.Driver(python_grammar, convert=pytree.convert)
    if args.source_file:
        tree = driver_.parse_file(args.source_file)
    else:
        tree = driver_.parse_stream(StringIO(args.source_string + "\n"))

    # Compile the pattern.
    if args.pattern_file:
        with open(args.pattern_file, 'r') as f:
            pattern_text = f.read()
    else:
        pattern_text = args.pattern_string
    pattern = PatternCompiler().compile_pattern(pattern_text)

    for node in tree.post_order():
        results = {'node': node}
        if not pattern.match(node, results):
            continue
        match_node = results['node']
        src_lines = str(match_node).splitlines()
        if args.print_lineno:
            # Number from the *first* leaf.  Counting back from the right-most
            # leaf is wrong when that leaf is a DEDENT (its line number is
            # that of the following statement) or a multi-line string (its
            # line number is the string's first line).
            first_lineno = _first_line_number(match_node)
            if first_lineno is not None:
                src_lines = [
                    str(lineno) + ' ' + line
                    for lineno, line in enumerate(src_lines, first_lineno)
                ]
        for line in src_lines:
            print(line)
        if args.print_results:
            print(results)
        print('-' * 20)