class FixStripMainCheck(DelayBindBaseFix):

    PATTERN = """
    if_stmt<
      'if' comparison< '__name__' '==' '\\'__main__\\'' > ':' any*
    >
    |
    if_stmt<
      'if' comparison< '__name__' '==' '"__main__"' > ':' any*
    >
  """

    PATTERN_MAIN = PatternCompiler().compile_pattern('''
    power< 'main' trailer< '(' ')' > >
  ''')

    def transform(self, node, result):
        suite = next(
            (x for x in node.children if x.type == python_symbols.suite), None)
        if not suite: return
        statements = [
            x for x in suite.children if x.type == python_symbols.simple_stmt
        ]
        if len(statements) != 1: return
        if not self.PATTERN_MAIN.match(statements[0].children[0], {}): return
        return BlankLine()
Example #2
0
 def match(self, node):
     """Match constant imports and, afterwards, usages of that constant.

     When ``self.pattern`` recognizes one of the import forms, a usage pattern
     is compiled for the name under which the constant is now visible and
     appended to ``self.usage_patterns``; the results dict is then returned.
     Otherwise each registered usage pattern is tried in order.

     Returns the results dict on a match, ``None`` otherwise.
     """
     results = {"node": node}
     if self.pattern.match(node, results):
         if 'constantname' in results:
             # `from ... import CONSTANT as alias`: follow the alias.
             usage_source = "constant='%s'" % results['constantname'].value
         elif 'importname' in results:
             # `from ... import CONSTANT`: follow the constant's own name.
             usage_source = "constant='CONSTANT'"
         elif 'modulename' in results:
             # `import module as alias`: follow `alias.CONSTANT` attribute use.
             usage_source = ("power< '%s' trailer< '.' "
                             "attribute='CONSTANT' > >"
                             % results['modulename'].value)
         else:
             usage_source = None
         if usage_source is not None:
             self.usage_patterns.append(
                 PatternCompiler().compile_pattern(usage_source))
             return results

     # Not an import: look for a usage registered by an earlier import match.
     if any(usage.match(node, results) for usage in self.usage_patterns):
         return results
     return None
Example #3
0
 def compile_pattern(self):
     """Build ``self.named_import_pattern`` from the class's IMPORT_PATTERN.

     IMPORT_PATTERN is %-formatted with FUNCTION_NAME bound to the
     ``function_name`` key, then compiled.
     """
     compiler = PatternCompiler()
     import_source = self.IMPORT_PATTERN % {'function_name': self.FUNCTION_NAME}
     self.named_import_pattern = compiler.compile_pattern(import_source)
Example #4
0
class Function2DecoratorBase(BaseFix):
    """Base fixer turning a class-body ``FUNCTION_NAME(IFoo)`` call into a
    ``@DECORATOR_NAME(IFoo)`` class decorator.

    The class reads ``self.FUNCTION_NAME`` and ``self.DECORATOR_NAME`` but
    does not define them; presumably concrete subclasses supply both (TODO
    confirm).  Besides converting the class-body call, ``transform`` renames
    matched ``from zope.interface import ...`` names to the decorator name.
    It also registers extra call-target spellings for ``import ... as``
    aliases it encounters.
    """

    # Import forms recognized by this fixer; `%(function_name)s` is filled in
    # by compile_pattern().  Captures: `name` (the imported symbol), and
    # `rename` / `interface_rename` (an `as` alias).
    # NOTE(review): `dotted_name< 'zope' >` has a single child.  If the tree
    # is built with a converter that collapses single-child nodes, those
    # alternatives may never match; confirm before relying on them.
    IMPORT_PATTERN = """
    import_from< 'from' dotted_name< 'zope' '.' 'interface' > 'import' import_as_names< any* (name='%(function_name)s') any* > >
    |
    import_from< 'from' dotted_name< 'zope' '.' 'interface' > 'import' name='%(function_name)s' any* >
    |
    import_from< 'from' dotted_name< 'zope' > 'import' name='interface' any* >
    |
    import_from< 'from' dotted_name< 'zope' '.' 'interface' > 'import' import_as_name< name='%(function_name)s' 'as' rename=(any) any*> >
    |
    import_from< 'from' dotted_name< 'zope' > 'import' import_as_name< name='interface' 'as' rename=(any) any*> >
    |
    import_from< 'from' 'zope' 'import' import_as_name< 'interface' 'as' interface_rename=(any) > >
    """

    # A class (plain or already decorated) whose suite holds a simple
    # statement `<match>(<interface>)`; `%(match)s` is the call-target
    # fragment given to _add_pattern().  Captures `statement` (the callee) and
    # `interface` (the single call argument).
    CLASS_PATTERN = """
    decorated< decorator <any* > classdef< 'class' any* ':' suite< any* simple_stmt< power< statement=(%(match)s) trailer < '(' interface=any ')' > any* > any* > any* > > >
    |
    classdef< 'class' any* ':' suite< any* simple_stmt< power< statement=(%(match)s) trailer < '(' interface=any ')' > any* > any* > any* > >
    """

    # A standalone call statement to the call target (`%s`); used to find
    # and delete the original call once the decorator has been built.
    FUNCTION_PATTERN = """
    simple_stmt< power< old_statement=(%s) trailer < '(' any* ')' > > any* >
    """

    def should_skip(self, node):
        """Return True unless the module text contains both 'zope' and
        'interface' (a plain substring check on the rendered module)."""
        module = str(node)
        return not ('zope' in module and 'interface' in module)

    def compile_pattern(self):
        """Compile IMPORT_PATTERN, formatted with FUNCTION_NAME, into
        ``self.named_import_pattern``."""
        # Compile the import pattern.
        self.named_import_pattern = PatternCompiler().compile_pattern(
            self.IMPORT_PATTERN % {'function_name': self.FUNCTION_NAME})

    def start_tree(self, tree, filename):
        """Reset per-tree state and register the three default call-target
        spellings: ``FUNCTION_NAME``, ``interface.FUNCTION_NAME`` and
        ``zope.interface.FUNCTION_NAME``."""
        # Compile the basic class/function matches. This is done per tree,
        # as further matches (based on what imports there are) also are done
        # per tree.
        self.class_patterns = []
        self.function_patterns = []
        self.fixups = []

        self._add_pattern("'%s'" % self.FUNCTION_NAME)
        self._add_pattern("'interface' trailer< '.' '%s' >" % self.FUNCTION_NAME)
        self._add_pattern("'zope' trailer< '.' 'interface' > trailer< '.' '%s' >" % self.FUNCTION_NAME)

    def _add_pattern(self, match):
            """Compile CLASS_PATTERN and FUNCTION_PATTERN for one call-target
            fragment ``match`` and append them to ``self.class_patterns`` and
            ``self.function_patterns`` (both created in start_tree)."""
            self.class_patterns.append(PatternCompiler().compile_pattern(
                self.CLASS_PATTERN % {'match': match}))
            self.function_patterns.append(PatternCompiler().compile_pattern(
                self.FUNCTION_PATTERN % match))

    def match(self, node):
        """Return a results dict if ``node`` is a recognized import or a
        class containing the call; otherwise return None (implicitly).

        Import matches are tried before class matches.
        """
        # Matches up the imports
        results = {"node": node}
        if self.named_import_pattern.match(node, results):
            return results

        # Now match classes on all import variants found:
        for pattern in self.class_patterns:
            if pattern.match(node, results):
                return results

    def transform(self, node, results):
        """Apply the rewrite for whatever ``match`` captured.

        Import matches are edited in place (and may register new call-target
        spellings), returning None.  For a class match, a new ``decorated``
        node is returned: the decorator built from the captured call, followed
        by a clone of the class with the original call statement removed.
        """
        if 'name' in results:
            # This matched an import statement. Fix that up:
            name = results["name"]
            name.replace(Name(self.DECORATOR_NAME, prefix=name.prefix))
        if 'rename' in results:
            # The import statement uses `import ... as`: also match the alias.
            self._add_pattern("'%s'" % results['rename'].value)
        if 'interface_rename' in results:
            # `from zope import interface as X`: also match `X.FUNCTION_NAME`.
            self._add_pattern("'%s' trailer< '.' '%s' > " % (
                results['interface_rename'].value, self.FUNCTION_NAME))
        if 'statement' in results:
            # This matched a class that has an <FUNCTION_NAME>(IFoo) statement.
            # We must convert that statement to a class decorator
            # and put it before the class definition.

            statement = results['statement']
            if not isinstance(statement, list):
                statement = [statement]
            # Make a copy for insertion before the class:
            statement = [x.clone() for x in statement]
            # Get rid of leading whitespace:
            statement[0].prefix = ''
            # Rename function to decorator.  For dotted spellings the last
            # captured node is a trailer whose last child is the name leaf.
            if statement[-1].children:
                func = statement[-1].children[-1]
            else:
                func = statement[-1]
            if func.value == self.FUNCTION_NAME:
                func.value = self.DECORATOR_NAME

            interface = results['interface']
            if not isinstance(interface, list):
                interface = [interface]
            interface = [x.clone() for x in interface]

            # Create the decorator.
            # NOTE(review): 50/7/8 are raw token numbers, presumably AT, LPAR
            # and RPAR in the lib2to3 token module; confirm.
            decorator = Node(syms.decorator, [Leaf(50, '@'),] + statement +
                             [Leaf(7, '(')] + interface + [Leaf(8, ')')])

            # Take the current class definition's prefix, and stick it into
            # the decorator, to set the decorator's indentation.
            nodeprefix = node.prefix
            decorator.prefix = nodeprefix
            # Preserve only the indent:
            if '\n' in nodeprefix:
                nodeprefix = nodeprefix[nodeprefix.rfind('\n')+1:]

            # Then find the last line of the previous node and use that as
            # indentation, and add that to the class definition's prefix.

            previous = node.prev_sibling
            if previous is None:
                prefix = ''
            else:
                prefix = str(previous)
            if '\n' in prefix:
                prefix = prefix[prefix.rfind('\n')+1:]
            prefix = prefix + nodeprefix

            # Ensure the class line starts on its own line after the decorator.
            if not prefix or prefix[0] != '\n':
                prefix = '\n' + prefix
            node.prefix = prefix
            new_node = Node(syms.decorated, [decorator, node.clone()])
            # Look for the actual function calls in the new node and remove it.
            # NOTE(review): `node` is rebound here, shadowing the parameter
            # (it is not used after the loop).  Children are also removed and
            # inserted while post_order() is still iterating new_node.  If
            # post_order() walks the live children lists, the sibling after
            # each removed statement may be skipped.  Do not change this
            # without tests covering consecutive calls.
            for node in new_node.post_order():
                for pattern in self.function_patterns:
                    if pattern.match(node, results):
                        parent = node.parent
                        previous = node.prev_sibling
                        # Remove the node
                        node.remove()
                        if not str(parent).strip():
                            # This is an empty class. Stick in a pass
                            # NOTE(review): Leaf type 0 is used for synthesized
                            # text; presumably only its string form matters.
                            if (len(parent.children) < 3 or
                                ' ' in parent.children[2].value):
                                # This class had no body whitespace.
                                parent.insert_child(2, Leaf(0, '    pass'))
                            else:
                                # This class had body whitespace already.
                                parent.insert_child(2, Leaf(0, 'pass'))
                            parent.insert_child(3, Leaf(0, '\n'))
                        elif (prefix and isinstance(previous, Leaf) and
                            '\n' not in previous.value and
                            previous.value.strip() == ''):
                            # This is just whitespace, remove it:
                            previous.remove()

            return new_node
Example #5
0
# Pre-built token lists for assembling comparison expressions; every leaf
# carries a single-space prefix.  Combined lists such as NOT_IN and IS_NOT
# hold the very same Leaf objects as their parts (list `+` copies references),
# so callers splicing them into a tree presumably clone them first (TODO
# confirm).  `NOT` is used here but not defined in this view.  It must be a
# list for `NOT + IN` and `IS + NOT` to work.
NOTEQUALS = [Leaf(token.NOTEQUAL, "!=", prefix=" ")]
IN = [Leaf(token.NAME, "in", prefix=" ")]
NOT_IN = NOT + IN
IS = [Leaf(token.NAME, "is", prefix=" ")]
IS_NOT = IS + NOT
NONE = Leaf(token.NAME, "None", prefix=" ")

GREATER = [Leaf(token.GREATER, ">", prefix=" ")]
GREATER_EQUAL = [Leaf(token.GREATEREQUAL, ">=", prefix=" ")]
LESS = [Leaf(token.LESS, "<", prefix=" ")]
LESS_EQUAL = [Leaf(token.LESSEQUAL, "<=", prefix=" ")]

TRUE = Name("True")
FALSE = Name("False")

# Shared pattern compiler and pre-compiled patterns:
#   IN_PATTERN      `a in b`, binding `a` and `b`
#   NOTIN_PATTERN   `a not in b`, binding `a` and `b`
#   NUMBER_PATTERN  a number literal, optionally with a unary `+`/`-`
#   NOPAREN_PATTERN any `power` or `atom` node
PC = PatternCompiler()
IN_PATTERN = PC.compile_pattern("comparison< a=any 'in' b=any >")
NOTIN_PATTERN = PC.compile_pattern(
    "comparison< a=any comp_op<'not' 'in'> b=any >")
NUMBER_PATTERN = PC.compile_pattern("NUMBER | factor< ('+' | '-') NUMBER >")
NOPAREN_PATTERN = PC.compile_pattern("power | atom")


def make_operand(node):
    """Convert a node into something we can put in a statement.

    Adds parentheses if needed.

    NOTE(review): as written here the function only assigns ``result`` (a
    one-element list holding a clone) for leaves and power/atom nodes.  It
    never returns that value, so every call returns None and no parentheses
    are ever added.  The rest of the body appears to be missing; TODO restore
    it before relying on this function.
    """
    if isinstance(node, Leaf) or NOPAREN_PATTERN.match(node):
        # No parentheses required in simple stuff
        result = [node.clone()]
Example #6
0
# Pre-built token lists for assembling comparison expressions; every leaf
# carries a single-space prefix.  Combined lists such as NOT_IN and IS_NOT
# hold the very same Leaf objects as their parts (list `+` copies references),
# so callers splicing them into a tree presumably clone them first (TODO
# confirm).  `NOT` is used here but not defined in this view.  It must be a
# list for `NOT + IN` and `IS + NOT` to work.
NOTEQUALS = [Leaf(token.NOTEQUAL, "!=", prefix=" ")]
IN = [Leaf(token.NAME, "in", prefix=" ")]
NOT_IN = NOT + IN
IS = [Leaf(token.NAME, "is", prefix=" ")]
IS_NOT = IS + NOT
NONE = Leaf(token.NAME, "None", prefix=" ")

GREATER = [Leaf(token.GREATER, ">", prefix=" ")]
GREATER_EQUAL = [Leaf(token.GREATEREQUAL, ">=", prefix=" ")]
LESS = [Leaf(token.LESS, "<", prefix=" ")]
LESS_EQUAL = [Leaf(token.LESSEQUAL, "<=", prefix=" ")]

TRUE = Name("True")
FALSE = Name("False")

# Shared pattern compiler and pre-compiled patterns:
#   IN_PATTERN      `a in b`, binding `a` and `b`
#   NOTIN_PATTERN   `a not in b`, binding `a` and `b`
#   NUMBER_PATTERN  a number literal, optionally with a unary `+`/`-`
#   NOPAREN_PATTERN any `power` or `atom` node
PC = PatternCompiler()
IN_PATTERN = PC.compile_pattern("comparison< a=any 'in' b=any >")
NOTIN_PATTERN = PC.compile_pattern("comparison< a=any comp_op<'not' 'in'> b=any >")
NUMBER_PATTERN = PC.compile_pattern("NUMBER | factor< ('+' | '-') NUMBER >")
NOPAREN_PATTERN = PC.compile_pattern("power | atom")

def make_operand(node):
    """Convert a node into something we can put in a statement.

    Adds parentheses if needed.
    """
    if isinstance(node, Leaf) or NOPAREN_PATTERN.match(node):
        # No parentheses required in simple stuff
        result = [node.clone()]
    else:
        # Parentheses required in complex statements (eg. assertEqual(x + y, 17))
Example #7
0
def _build_arg_parser():
    """Return the command-line parser: exactly one pattern source, exactly
    one code source, and two optional output flags."""
    parser = argparse.ArgumentParser()
    g1 = parser.add_mutually_exclusive_group(required=True)
    g1.add_argument("-pf",
                    "--pattern-file",
                    dest="pattern_file",
                    type=str,
                    help='Read pattern from the specified file')
    g1.add_argument("-ps",
                    "--pattern-string",
                    dest="pattern_string",
                    type=str,
                    help='A pattern string')
    g2 = parser.add_mutually_exclusive_group(required=True)
    g2.add_argument("-sf",
                    "--source-file",
                    dest="source_file",
                    type=str,
                    help="Read code snippet from the specified file")
    g2.add_argument("-ss",
                    "--source-string",
                    dest="source_string",
                    type=str,
                    help="A code snippet string")
    parser.add_argument("--print-results",
                        dest="print_results",
                        action='store_true',
                        default=False,
                        help="Print match results")
    parser.add_argument("--print-lineno",
                        dest="print_lineno",
                        action='store_true',
                        default=False,
                        help="Print match code with line number")
    return parser


def _first_line_number(node):
    """Return the source line number on which ``str(node)`` starts.

    ``str(node)`` begins with the prefix (blank lines, comments, indentation)
    of the node's left-most leaf.  Its first line therefore lies as many lines
    above that leaf's own line as the prefix contains newlines.

    Anchoring on the left-most leaf avoids two problems with the right-most
    leaf.  A trailing DEDENT is positioned on the following statement's line,
    and a trailing multi-line token (e.g. a triple-quoted string) reports only
    its starting line.
    """
    leaf = node
    while not isinstance(leaf, pytree.Leaf):
        leaf = leaf.children[0]
    return leaf.get_lineno() - leaf.prefix.count('\n')


def main():
    """Print every node of a code snippet that matches a PatternCompiler pattern.

    The pattern comes from ``--pattern-file`` or ``--pattern-string`` and the
    code from ``--source-file`` or ``--source-string`` (exactly one of each).
    Each match is printed, optionally with line numbers and followed by the
    raw results dict, and then a separator line.
    """
    # Parse command line arguments
    args = _build_arg_parser().parse_args()

    # Parse the source snippet into a CST.  The exclusive groups guarantee
    # exactly one option is set.  Compare with None so an explicit empty
    # string is not routed to the other (unset) option.
    driver_ = driver.Driver(python_grammar, convert=pytree.convert)
    if args.source_file is not None:
        tree = driver_.parse_file(args.source_file)
    else:
        tree = driver_.parse_stream(StringIO(args.source_string + "\n"))

    # Load and compile the pattern.
    if args.pattern_file is not None:
        with open(args.pattern_file, 'r') as f:
            pattern_source = f.read()
    else:
        pattern_source = args.pattern_string
    matcher = PatternCompiler().compile_pattern(pattern_source)

    for node in tree.post_order():
        results = {'node': node}
        if not matcher.match(node, results):
            continue
        match_node = results['node']
        src_lines = str(match_node).splitlines()
        if args.print_lineno:
            first_lineno = _first_line_number(match_node)
            src_lines = [
                str(lineno) + ' ' + line
                for lineno, line in enumerate(src_lines, start=first_lineno)
            ]
        for line in src_lines:
            print(line)
        if args.print_results:
            print(results)
        print('-' * 20)