class FixStripMainCheck(DelayBindBaseFix): PATTERN = """ if_stmt< 'if' comparison< '__name__' '==' '\\'__main__\\'' > ':' any* > | if_stmt< 'if' comparison< '__name__' '==' '"__main__"' > ':' any* > """ PATTERN_MAIN = PatternCompiler().compile_pattern(''' power< 'main' trailer< '(' ')' > > ''') def transform(self, node, result): suite = next( (x for x in node.children if x.type == python_symbols.suite), None) if not suite: return statements = [ x for x in suite.children if x.type == python_symbols.simple_stmt ] if len(statements) != 1: return if not self.PATTERN_MAIN.match(statements[0].children[0], {}): return return BlankLine()
def match(self, node):
    """Match a CONSTANT import, or a later usage of an imported CONSTANT.

    If *node* matches ``self.pattern`` and one of the captures
    'constantname', 'importname' or 'modulename' is present (checked in
    that order), a pattern describing how the constant will be used under
    that import form is compiled and appended to ``self.usage_patterns``,
    and the results are returned.  Otherwise *node* is tried against every
    usage pattern collected so far.

    Returns the results dict (it always contains "node") on a match, else
    None.
    """
    results = {"node": node}

    if self.pattern.match(node, results):
        if 'constantname' in results:
            # "from ... import ... as <alias>": usages appear under the alias.
            usage_source = "constant='%s'" % results['constantname'].value
        elif 'importname' in results:
            # "from ... import" without "as": usages keep the standard name.
            usage_source = "constant='CONSTANT'"
        elif 'modulename' in results:
            # "import ... as <alias>": usages are attribute lookups on it.
            usage_source = ("power< '%s' trailer< '.' "
                            "attribute='CONSTANT' > >"
                            % results['modulename'].value)
        else:
            usage_source = None

        if usage_source is not None:
            self.usage_patterns.append(
                PatternCompiler().compile_pattern(usage_source))
            return results

    # Not a recognised import: look for a usage of an imported constant.
    for usage_pattern in self.usage_patterns:
        if usage_pattern.match(node, results):
            return results
    return None
def compile_pattern(self):
    """Compile IMPORT_PATTERN into ``self.named_import_pattern``.

    The ``function_name`` key of IMPORT_PATTERN's %-formatting is filled
    with FUNCTION_NAME before the pattern text is compiled.
    """
    compiler = PatternCompiler()
    import_source = self.IMPORT_PATTERN % {'function_name': self.FUNCTION_NAME}
    self.named_import_pattern = compiler.compile_pattern(import_source)
class Function2DecoratorBase(BaseFix):
    """Base fixer that turns a class-body call into a class decorator.

    A statement such as ``<FUNCTION_NAME>(IFoo)`` inside a class body is
    removed from the body and re-emitted as ``@<DECORATOR_NAME>(IFoo)`` in
    front of the class.  Matching ``zope.interface`` imports of
    FUNCTION_NAME are renamed to DECORATOR_NAME.  FUNCTION_NAME and
    DECORATOR_NAME are read but not defined here -- presumably subclasses
    provide them (confirm).
    """

    # Import forms of zope.interface / FUNCTION_NAME to recognise.  The
    # '%(function_name)s' placeholders are filled in by compile_pattern().
    # Captures: 'name' (imported name to rename), 'rename' (its "as" alias),
    # 'interface_rename' (alias given to the ``interface`` module itself).
    IMPORT_PATTERN = """ import_from< 'from' dotted_name< 'zope' '.' 'interface' > 'import' import_as_names< any* (name='%(function_name)s') any* > > | import_from< 'from' dotted_name< 'zope' '.' 'interface' > 'import' name='%(function_name)s' any* > | import_from< 'from' dotted_name< 'zope' > 'import' name='interface' any* > | import_from< 'from' dotted_name< 'zope' '.' 'interface' > 'import' import_as_name< name='%(function_name)s' 'as' rename=(any) any*> > | import_from< 'from' dotted_name< 'zope' > 'import' import_as_name< name='interface' 'as' rename=(any) any*> > | import_from< 'from' 'zope' 'import' import_as_name< 'interface' 'as' interface_rename=(any) > > """

    # A class (optionally already decorated) whose body contains a
    # ``<match>(<interface>)`` statement.  '%(match)s' is one spelling of the
    # function reference, supplied by _add_pattern().  Captures 'statement'
    # (the callee) and 'interface' (the call argument).
    CLASS_PATTERN = """ decorated< decorator <any* > classdef< 'class' any* ':' suite< any* simple_stmt< power< statement=(%(match)s) trailer < '(' interface=any ')' > any* > any* > any* > > > | classdef< 'class' any* ':' suite< any* simple_stmt< power< statement=(%(match)s) trailer < '(' interface=any ')' > any* > any* > any* > > """

    # The call statement itself; used in transform() to find and delete it.
    FUNCTION_PATTERN = """ simple_stmt< power< old_statement=(%s) trailer < '(' any* ')' > > any* > """

    def should_skip(self, node):
        # Return True (skip) when the source text of *node* does not contain
        # both 'zope' and 'interface'.
        module = str(node)
        return not ('zope' in module and 'interface' in module)

    def compile_pattern(self):
        # Compile the import pattern.  Only named_import_pattern is set here;
        # the class/function patterns are built per tree in start_tree().
        self.named_import_pattern = PatternCompiler().compile_pattern(
            self.IMPORT_PATTERN % {'function_name': self.FUNCTION_NAME})

    def start_tree(self, tree, filename):
        # Compile the basic class/function matches. This is done per tree,
        # as further matches (based on what imports there are) also are done
        # per tree.
        self.class_patterns = []
        self.function_patterns = []
        # NOTE(review): fixups is initialised here but not used elsewhere in
        # this class.
        self.fixups = []

        # Three spellings: bare name, interface.<name>, and
        # zope.interface.<name>.
        self._add_pattern("'%s'" % self.FUNCTION_NAME)
        self._add_pattern("'interface' trailer< '.' '%s' >" % self.FUNCTION_NAME)
        self._add_pattern("'zope' trailer< '.' 'interface' > trailer< '.' '%s' >" % self.FUNCTION_NAME)

    def _add_pattern(self, match):
        # Register a class pattern and a matching call-statement pattern for
        # one spelling (*match*, pattern syntax) of the function reference.
        self.class_patterns.append(PatternCompiler().compile_pattern(
            self.CLASS_PATTERN % {'match': match}))
        self.function_patterns.append(PatternCompiler().compile_pattern(
            self.FUNCTION_PATTERN % match))

    def match(self, node):
        # Return the results dict when *node* is a recognised import or a
        # class matched by any registered class pattern; otherwise None.
        # Matches up the imports
        results = {"node": node}
        if self.named_import_pattern.match(node, results):
            return results

        # Now match classes on all import variants found:
        for pattern in self.class_patterns:
            if pattern.match(node, results):
                return results

    def transform(self, node, results):
        """Rewrite a node accepted by match().

        Import matches are edited in place and may register extra patterns
        for aliases; nothing is returned for them.  Class matches return a
        new ``decorated`` node that holds the generated decorator and a
        clone of the class with the original call statement removed.
        """
        if 'name' in results:
            # This matched an import statement. Fix that up:
            name = results["name"]
            name.replace(Name(self.DECORATOR_NAME, prefix=name.prefix))
        if 'rename' in results:
            # The import uses "as": also match calls made through the alias.
            self._add_pattern("'%s'" % results['rename'].value)
        if 'interface_rename' in results:
            # ``interface`` itself was aliased: match <alias>.<FUNCTION_NAME>.
            self._add_pattern("'%s' trailer< '.' '%s' > " % (
                results['interface_rename'].value, self.FUNCTION_NAME))
        if 'statement' in results:
            # This matched a class that has an <FUNCTION_NAME>(IFoo) statement.
            # We must convert that statement to a class decorator
            # and put it before the class definition.

            statement = results['statement']
            if not isinstance(statement, list):
                statement = [statement]
            # Make a copy for insertion before the class:
            statement = [x.clone() for x in statement]
            # Get rid of leading whitespace:
            statement[0].prefix = ''
            # Rename function to decorator.  The callee name is the last leaf:
            # either the last child of a trailing node (e.g. ``'.' name``) or
            # the element itself when it has no children (a leaf).
            if statement[-1].children:
                func = statement[-1].children[-1]
            else:
                func = statement[-1]
            # Aliased spellings keep their alias; only the real name changes.
            if func.value == self.FUNCTION_NAME:
                func.value = self.DECORATOR_NAME

            interface = results['interface']
            if not isinstance(interface, list):
                interface = [interface]
            interface = [x.clone() for x in interface]

            # Create the decorator.  50/7/8 are raw token numbers, presumably
            # token.AT / LPAR / RPAR in lib2to3's token module -- confirm.
            decorator = Node(syms.decorator, [Leaf(50, '@'),] + statement +
                             [Leaf(7, '(')] + interface + [Leaf(8, ')')])

            # Take the current class constructor prefix, and stick it into
            # the decorator, to set the decorators indentation.
            nodeprefix = node.prefix
            decorator.prefix = nodeprefix
            # Preserve only the indent:
            if '\n' in nodeprefix:
                nodeprefix = nodeprefix[nodeprefix.rfind('\n')+1:]

            # Then find the last line of the previous node and use that as
            # indentation, and add that to the class constructors prefix.

            previous = node.prev_sibling
            if previous is None:
                prefix = ''
            else:
                prefix = str(previous)
            if '\n' in prefix:
                prefix = prefix[prefix.rfind('\n')+1:]
            prefix = prefix + nodeprefix
            # The class must start on a new line after the decorator.
            if not prefix or prefix[0] != '\n':
                prefix = '\n' + prefix
            node.prefix = prefix
            new_node = Node(syms.decorated, [decorator, node.clone()])
            # Look for the actual function calls in the new node and remove it.
            # NOTE: this rebinds ``node``; the class node is not needed again.
            # NOTE(review): nodes are removed/inserted while post_order() is
            # still iterating, which may skip the following sibling.
            for node in new_node.post_order():
                for pattern in self.function_patterns:
                    if pattern.match(node, results):
                        parent = node.parent
                        previous = node.prev_sibling
                        # Remove the node
                        node.remove()
                        if not str(parent).strip():
                            # This is an empty class. Stick in a pass.
                            # Leaf type 0 is used here as a raw-text leaf.
                            if (len(parent.children) < 3 or
                                    ' ' in parent.children[2].value):
                                # This class had no body whitespace.
                                parent.insert_child(2, Leaf(0, ' pass'))
                            else:
                                # This class had body whitespace already.
                                parent.insert_child(2, Leaf(0, 'pass'))
                            # NOTE(review): placed after both branches --
                            # confirm against the original indentation.
                            parent.insert_child(3, Leaf(0, '\n'))
                        elif (prefix and isinstance(previous, Leaf) and
                              '\n' not in previous.value and
                              previous.value.strip() == ''):
                            # This is just whitespace, remove it:
                            previous.remove()

            return new_node
NOTEQUALS = [Leaf(token.NOTEQUAL, "!=", prefix=" ")] IN = [Leaf(token.NAME, "in", prefix=" ")] NOT_IN = NOT + IN IS = [Leaf(token.NAME, "is", prefix=" ")] IS_NOT = IS + NOT NONE = Leaf(token.NAME, "None", prefix=" ") GREATER = [Leaf(token.GREATER, ">", prefix=" ")] GREATER_EQUAL = [Leaf(token.GREATEREQUAL, ">=", prefix=" ")] LESS = [Leaf(token.LESS, "<", prefix=" ")] LESS_EQUAL = [Leaf(token.LESSEQUAL, "<=", prefix=" ")] TRUE = Name("True") FALSE = Name("False") PC = PatternCompiler() IN_PATTERN = PC.compile_pattern("comparison< a=any 'in' b=any >") NOTIN_PATTERN = PC.compile_pattern( "comparison< a=any comp_op<'not' 'in'> b=any >") NUMBER_PATTERN = PC.compile_pattern("NUMBER | factor< ('+' | '-') NUMBER >") NOPAREN_PATTERN = PC.compile_pattern("power | atom") def make_operand(node): """Convert a node into something we can put in a statement. Adds parentheses if needed. """ if isinstance(node, Leaf) or NOPAREN_PATTERN.match(node): # No parentheses required in simple stuff result = [node.clone()]
NOTEQUALS = [Leaf(token.NOTEQUAL, "!=", prefix=" ")] IN = [Leaf(token.NAME, "in", prefix=" ")] NOT_IN = NOT + IN IS = [Leaf(token.NAME, "is", prefix=" ")] IS_NOT = IS + NOT NONE = Leaf(token.NAME, "None", prefix=" ") GREATER = [Leaf(token.GREATER, ">", prefix=" ")] GREATER_EQUAL = [Leaf(token.GREATEREQUAL, ">=", prefix=" ")] LESS = [Leaf(token.LESS, "<", prefix=" ")] LESS_EQUAL = [Leaf(token.LESSEQUAL, "<=", prefix=" ")] TRUE = Name("True") FALSE = Name("False") PC = PatternCompiler() IN_PATTERN = PC.compile_pattern("comparison< a=any 'in' b=any >") NOTIN_PATTERN = PC.compile_pattern("comparison< a=any comp_op<'not' 'in'> b=any >") NUMBER_PATTERN = PC.compile_pattern("NUMBER | factor< ('+' | '-') NUMBER >") NOPAREN_PATTERN = PC.compile_pattern("power | atom") def make_operand(node): """Convert a node into something we can put in a statement. Adds parentheses if needed. """ if isinstance(node, Leaf) or NOPAREN_PATTERN.match(node): # No parentheses required in simple stuff result = [node.clone()] else: # Parentheses required in complex statements (eg. assertEqual(x + y, 17))
def main():
    """Print every node of a code snippet that matches a lib2to3 pattern.

    Exactly one pattern source (``-pf``/``-ps``) and exactly one code source
    (``-sf``/``-ss``) must be given; argparse enforces both groups.  Each
    match is printed as its source text, optionally prefixed with line
    numbers (``--print-lineno``) and followed by the match-results dict
    (``--print-results``), then a separator of 20 dashes.
    """
    parser = argparse.ArgumentParser()
    g1 = parser.add_mutually_exclusive_group(required=True)
    g1.add_argument("-pf", "--pattern-file", dest="pattern_file", type=str,
                    help='Read pattern from the specified file')
    g1.add_argument("-ps", "--pattern-string", dest="pattern_string", type=str,
                    help='A pattern string')
    g2 = parser.add_mutually_exclusive_group(required=True)
    g2.add_argument("-sf", "--source-file", dest="source_file", type=str,
                    help="Read code snippet from the specified file")
    g2.add_argument("-ss", "--source-string", dest="source_string", type=str,
                    help="A code snippet string")
    parser.add_argument("--print-results", dest="print_results",
                        action='store_true', default=False,
                        help="Print match results")
    parser.add_argument("--print-lineno", dest="print_lineno",
                        action='store_true', default=False,
                        help="Print match code with line number")
    # Parse command line arguments
    args = parser.parse_args()

    # Parse the source snippet into a concrete syntax tree.
    driver_ = driver.Driver(python_grammar, convert=pytree.convert)
    if args.source_file:
        tree = driver_.parse_file(args.source_file)
    else:
        tree = driver_.parse_stream(StringIO(args.source_string + "\n"))

    # Load and compile the pattern.
    if args.pattern_file:
        with open(args.pattern_file, 'r') as f:
            pattern_text = f.read()
    else:
        pattern_text = args.pattern_string
    pattern = PatternCompiler().compile_pattern(pattern_text)

    for node in tree.post_order():
        results = {'node': node}
        if not pattern.match(node, results):
            continue
        match_node = results['node']
        src_lines = str(match_node).splitlines()
        if args.print_lineno:
            # str(match_node) begins with the prefix (whitespace/comments) of
            # its first leaf, so its first line is that leaf's line number
            # minus the newlines in the prefix.  Anchoring on the last leaf
            # instead is wrong when that leaf is a DEDENT (positioned at the
            # next statement, so every compound statement was shifted) or
            # spans several lines (e.g. a triple-quoted string).
            first_lineno = (match_node.get_lineno()
                            - match_node.prefix.count('\n'))
            src_lines = [
                str(lineno) + ' ' + line
                for lineno, line in enumerate(src_lines, start=first_lineno)
            ]
        for line in src_lines:
            print(line)
        if args.print_results:
            print(results)
        print('-' * 20)