class NamedFunctionDefinition(matcher.Matcher):
  """Matches the definition of a named function, regular or async.

  On Python 3 a candidate matches if it is matched by either the
  ``AsyncFunctionDef`` or the ``FunctionDef`` AST matcher; otherwise only
  ``FunctionDef`` is tried.

  Args:
    body: The matcher for the function body.
    returns: The matcher for the return type annotation.
  """

  _body = matcher.submatcher_attrib(default=base_matchers.Anything())
  _returns = matcher.submatcher_attrib(default=base_matchers.Anything())

  @cached_property.cached_property
  def _matcher(self):
    """The underlying AST matcher that all matching is delegated to."""
    field_matchers = {'body': self._body}
    # `returns` is forwarded only if FunctionDef declares such a field. The
    # field is probed (rather than the Python version) to support backports of
    # the type annotation syntax to Python 2.
    if 'returns' in attr.fields_dict(ast_matchers.FunctionDef):
      field_matchers['returns'] = self._returns

    sync_def = ast_matchers.FunctionDef(**field_matchers)
    if not six.PY3:
      return sync_def
    return base_matchers.AnyOf(
        ast_matchers.AsyncFunctionDef(**field_matchers),
        sync_def,
    )

  def _match(self, context, candidate):
    """Returns the result of the cached underlying matcher."""
    return self._matcher.match(context, candidate)

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the underlying matcher."""
    return self._matcher.type_filter
class HasFirstAncestor(matcher.Matcher):
  """The first ancestor to match ``first_ancestor`` also matches ``also_matches``.

  For example, "the function that I am currently in is a generator function" is
  a matcher that one might want to create, and can be created using
  ``HasFirstAncestor``.

  Args:
    first_ancestor: Matcher used to pick the ancestor. Parents are tried from
      the nearest outwards and the walk stops at the first one that matches.
    also_matches: Matcher that must additionally match whatever
      ``first_ancestor`` matched.
  """

  _first_ancestor = matcher.submatcher_attrib()
  _also_matches = matcher.submatcher_attrib()

  def _match(self, context, candidate):
    """Matches ``candidate`` itself, carrying both submatchers' bindings.

    Returns ``None`` if no ancestor matches ``first_ancestor``, or if the
    chosen ancestor does not match ``also_matches``.
    """
    nav = context.parsed_file.nav

    # Climb one parent at a time until first_ancestor matches or we run out.
    ancestor_info = None
    node = nav.get_parent(candidate)
    while node is not None:
      ancestor_info = self._first_ancestor.match(context, node)
      if ancestor_info is not None:
        break
      node = nav.get_parent(node)
    if ancestor_info is None:
      return None

    # The second matcher is run on what the first one reported as matched.
    extra_info = self._also_matches.match(context,
                                          ancestor_info.match.matched)
    if extra_info is None:
      return None

    return matcher.MatchInfo(
        matcher.create_match(context.parsed_file, candidate),
        matcher.merge_bindings(ancestor_info.bindings, extra_info.bindings))
class WithReplacements(matcher.Matcher):
  """Adds replacement templates to the matches of ``submatcher``.

  Args:
    submatcher: The matcher whose results are extended.
    replacements: Mapping from label to the template to merge into the
      replacements of each successful match.

  Raises:
    ValueError: At construction time, if the templates reference a variable
      that is not in ``bind_variables``. Labels starting with ``__`` are
      exempt from this check.
  """

  submatcher = matcher.submatcher_attrib()
  replacements = attr.ib(type=Dict[str, formatting.Template])

  def __attrs_post_init__(self):
    """Validates that every template variable is bound by the matcher."""
    referenced = formatting.template_variables(self.replacements)
    # System labels (double-underscore prefix) don't count.
    unbound = sorted(
        label for label in referenced - self.bind_variables
        if not label.startswith('__'))
    if not unbound:
      return
    raise ValueError(
        ('The substitution template(s) referenced variables not matched in '
         'the Python matcher: {variables}').format(
             variables=', '.join('`{}`'.format(v) for v in unbound)))

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the wrapped submatcher."""
    return self.submatcher.type_filter

  def _match(self, context, candidate):
    """Returns the submatcher's result with ``replacements`` merged in."""
    info = self.submatcher.match(context, candidate)
    if info is None:
      return None
    merged = matcher.merge_replacements(info.replacements, self.replacements)
    return attr.evolve(info, replacements=merged)
class Once(matcher.Matcher):
  """Runs the submatcher at most once successfully.

  Matches if the submatcher has ever matched, including in this run. Fails if
  the matcher has not ever matched.

  If ``key`` is provided, then any other ``Once()`` with the same key shares
  state, and is considered equivalent for the sake of the above.

  Args:
    submatcher: The matcher to run until it first succeeds.
    key: Hashable token under which the "has run" flag is recorded on the
      match context. Defaults to this matcher instance.
  """

  _submatcher = matcher.submatcher_attrib()
  _key = attr.ib(type=Hashable)

  @_key.default
  def _key_default(self):
    # With no explicit key, the matcher keys its state on itself.
    return self

  @matcher.accumulating_matcher
  def _match(self, context, candidate):
    # Once the key is flagged on the context, the submatcher is never run
    # again and nothing is yielded.
    if not context.has_run(self._key):
      result = self._submatcher.match(context, candidate)
      if result is not None:
        context.set_has_run(self._key)
      # The result is yielded even when it is None; interpreting it is left to
      # matcher.accumulating_matcher.
      yield result

  type_filter = None
class Str(matcher.Matcher):
  """Matcher for ``str`` constants, restricted to ``ast.Constant`` candidates.

  Args:
    s: Submatcher handed to ``_constant_match`` together with the candidate.
      Defaults to ``Anything()``.
  """

  s = matcher.submatcher_attrib(default=base_matchers.Anything())

  type_filter = frozenset({ast.Constant})

  def _match(self, context, candidate):
    """Delegates to ``_constant_match``, accepting only ``str`` values."""
    accepted_types = (str,)
    return _constant_match(context, candidate, self.s, accepted_types)
class Num(matcher.Matcher):
  """Matcher for numeric constants, restricted to ``ast.Constant`` candidates.

  Args:
    n: Submatcher handed to ``_constant_match`` together with the candidate.
      Defaults to ``Anything()``.
  """

  n = matcher.submatcher_attrib(default=base_matchers.Anything())

  type_filter = frozenset({ast.Constant})

  def _match(self, context, candidate):
    """Delegates to ``_constant_match`` for int, float and complex values."""
    accepted_types = (int, float, complex)
    return _constant_match(context, candidate, self.n, accepted_types)
class NameConstant(matcher.Matcher):
  """Matcher for ``bool`` and ``None`` constants on ``ast.Constant`` candidates.

  Args:
    value: Submatcher handed to ``_constant_match`` together with the
      candidate. Defaults to ``Anything()``.
  """

  value = matcher.submatcher_attrib(default=base_matchers.Anything())

  type_filter = frozenset({ast.Constant})

  def _match(self, context, candidate):
    """Delegates to ``_constant_match`` for ``bool`` and ``NoneType`` values."""
    accepted_types = (bool, type(None))
    return _constant_match(context, candidate, self.value, accepted_types)
class MatchesRegex(matcher.Matcher):
  """Matches a candidate iff it matches the ``regex``.

  The match must be complete -- the regex must match the full AST, not just a
  substring of it. (i.e. this has ``re.fullmatch`` semantics.)

  Any named groups are added to the bindings -- e.g. ``(xyz)`` does not add
  anything to the bindings, but ``(?P<name>xyz)`` will bind ``name`` to the
  subspan ``'xyz'``.

  The bound matches are neither lexical nor syntactic, but purely on codepoint
  spans.

  Args:
    regex: The regular expression, as a pattern string.
    subpattern: A matcher that must also match the candidate. The regex is
      applied to the span of that matcher's match. Defaults to ``Anything()``.
  """

  _regex = attr.ib(type=str)
  _subpattern = matcher.submatcher_attrib(default=Anything(),
                                          type=matcher.Matcher)

  @cached_property.cached_property
  def _wrapped_regex(self):
    """Wrapped regex with fullmatch semantics on match()."""
    # fullmatch is anchored to both the start and end of the attempted span.
    # Since match() is already anchored at the start, only the end needs an
    # explicit anchor.
    # \Z is used rather than $: $ would also succeed just before a trailing
    # newline (and at every line end under (?m)), which is not a full match.
    # This is a hack to maintain Python 2 compatibility until this can be
    # 3-only.
    return re.compile(r'(?:%s)\Z' % self._regex)

  def _match(self, context, candidate):
    """Matches the regex against the source text spanned by the candidate.

    Returns:
      The subpattern's ``MatchInfo`` with the regex's named groups merged into
      its bindings, or ``None`` if the subpattern fails, the match has no
      span, the regex does not match the whole span, or the bindings cannot be
      merged.
    """
    matchinfo = self._subpattern.match(context, candidate)
    if matchinfo is None:
      return None

    span = matchinfo.match.span
    if span is None:
      return None  # can't search within this AST node.

    try:
      # The span is unpacked into the pos/endpos arguments of match().
      m = self._wrapped_regex.match(context.parsed_file.text, *span)
    except TypeError:
      # Deliberate best-effort: treated as "no match" rather than an error.
      return None
    if m is None:
      return None

    # TODO(b/118507248): Allow choosing a different binding type.
    bindings = matcher.merge_bindings(
        _re_match_to_bindings(self._wrapped_regex, context.parsed_file.text,
                              m), matchinfo.bindings)
    if bindings is None:
      return None
    return attr.evolve(matchinfo, bindings=bindings)

  @cached_property.cached_property
  def bind_variables(self):
    """Named groups of the regex plus the subpattern's bind variables."""
    return frozenset(
        self._wrapped_regex.groupindex) | self._subpattern.bind_variables
class HasNextSibling(matcher.Matcher):
  """Matches a node if the immediate next sibling in the node list matches ``submatcher``.

  Args:
    submatcher: The matcher to run against the next sibling.
  """

  _submatcher = matcher.submatcher_attrib()

  def _match(self, context, candidate):
    """Runs the submatcher on the candidate's next sibling.

    Returns:
      The submatcher's own result for the sibling (not a match re-anchored on
      ``candidate``), or ``None`` if there is no next sibling.
    """
    sibling = context.parsed_file.nav.get_next_sibling(candidate)
    # Compare against None explicitly so that a sibling which happens to be
    # falsy is still handed to the submatcher.
    # NOTE(review): assumes get_next_sibling signals "no sibling" with None --
    # confirm against the nav implementation.
    if sibling is None:
      return None
    return self._submatcher.match(context, sibling)
class InNamedFunction(matcher.Matcher):
  """Matches anything directly inside of a function that matches ``submatcher``.

  "Directly inside" means that the first ancestor matching
  ``NamedFunctionDefinition()`` is the one tested against ``submatcher``.

  Args:
    submatcher: The matcher the enclosing function definition must satisfy.
  """

  _submatcher = matcher.submatcher_attrib()

  @cached_property.cached_property
  def _recursive_matcher(self):
    """The ``HasFirstAncestor`` matcher that does the actual work."""
    enclosing_function = NamedFunctionDefinition()
    return HasFirstAncestor(
        first_ancestor=enclosing_function,
        also_matches=self._submatcher,
    )

  def _match(self, context, candidate):
    """Delegates to the cached ``HasFirstAncestor`` matcher."""
    return self._recursive_matcher.match(context, candidate)
class Subscript(Subscript):  # pylint: disable=undefined-variable
  """Replacement for ``Subscript`` that refuses ``slice=Bind(...)``.

  Subclasses the ``Subscript`` already in scope and redeclares ``slice`` so
  that a validator can be attached to it.
  """

  slice = matcher.submatcher_attrib(default=base_matchers.Anything())

  @slice.validator
  def _slice_validator(self, attribute, value):
    """Raises ``ValueError`` if ``slice`` is a ``Bind`` instance."""
    del attribute  # unused
    if not isinstance(value, base_matchers.Bind):
      return
    raise ValueError(
        'slice=Bind(...) not supported in Python < 3.9. It will fail to '
        'correctly match e.g. `a[:]` or `a[1,:]`. Upgrade to Python 3.9, or'
        ' work around this using AllOf(Bind(...)) if that is OK.')
class HasDescendant(matcher.Matcher):
  """Matches an AST node if any descendant matches the submatcher.

  This is equivalent to ``HasChild(IsOrHasDescendant(...))``.

  Args:
    submatcher: The matcher to look for among the descendants.
  """

  _submatcher = matcher.submatcher_attrib()

  @cached_property.cached_property
  def _recursive_matcher(self):
    """The composed ``HasChild(IsOrHasDescendant(...))`` matcher."""
    child_or_deeper = IsOrHasDescendant(self._submatcher)
    return HasChild(child_or_deeper)

  def _match(self, context, candidate):
    """Delegates to the cached composed matcher."""
    return self._recursive_matcher.match(context, candidate)
class Unless(matcher.Matcher):
  """Inverts a matcher and discards its bindings.

  Args:
    submatcher: The matcher to negate. It is declared with ``walk=False``.
  """

  _submatcher = matcher.submatcher_attrib(walk=False)

  def _match(self, context, candidate):
    """Succeeds, with no bindings, exactly when the submatcher fails."""
    if self._submatcher.match(context, candidate) is not None:
      return None
    return matcher.MatchInfo(
        matcher.create_match(context.parsed_file, candidate))

  # TODO: Maybe, someday, do stratified datalog with negation.
  type_filter = None
class IsOrHasDescendant(matcher.Matcher):
  """Matches a candidate if it or any descendant matches the submatcher.

  If the candidate directly matches, then that match is returned. Otherwise,
  the candidate is recursively traversed using ``HasChild`` until a match is
  found.

  Args:
    submatcher: The matcher to look for.
  """

  _submatcher = matcher.submatcher_attrib()

  @cached_property.cached_property
  def _recursive_matcher(self):
    """The ``RecursivelyWrapped`` matcher that recurses through ``HasChild``."""
    wrapped = base_matchers.RecursivelyWrapped(self._submatcher, HasChild)
    return wrapped

  def _match(self, context, candidate):
    """Delegates to the cached recursive matcher."""
    return self._recursive_matcher.match(context, candidate)
def _generate_syntax_matcher(cls, ast_node_type):
  """Builds a matcher class mirroring the fields of ``ast_node_type``.

  Args:
    cls: The base class of the generated class.
    ast_node_type: The AST node class to generate a matcher for. Its
      ``__name__`` names the new class and its ``_fields`` name the attributes.

  Returns:
    A frozen, keyword-only attrs class with one submatcher attribute per AST
    field, each defaulting to ``Anything()``. ``_ast_type`` is set to
    ``ast_node_type`` and ``type_filter`` to a frozenset containing it.
  """
  # One attrs field for every AST field, passed by keyword argument only.
  field_attribs = {}
  for field_name in ast_node_type._fields:
    field_attribs[field_name] = matcher.submatcher_attrib(
        default=base_matchers.Anything())

  matcher_type = attr.make_class(
      ast_node_type.__name__,
      field_attribs,
      bases=(cls,),
      frozen=True,
      kw_only=True,
  )
  matcher_type._ast_type = ast_node_type  # pylint: disable=protected-access
  matcher_type.type_filter = frozenset([ast_node_type])
  return matcher_type
class _BaseAstPattern(matcher.Matcher):
  """Base class for AST patterns.

  Subclasses should implement a _pull_ast(module_ast) method which returns the
  AST to match from that module.

  Args:
    pattern: The source text of the pattern.
    restrictions: Mapping from name to matcher, passed to
      ``_rewrite_submatchers`` along with the pattern. Defaults to an empty
      dict.
  """

  # The init parameters are kept as attributes for a pretty repr.
  pattern = attr.ib()  # type: Text
  restrictions = attr.ib(
      default=attr.Factory(dict))  # type: Dict[Text, matcher.Matcher]

  # Derived from the two fields above at construction time; it is not an
  # __init__ argument and is left out of the repr.
  _ast_matcher = matcher.submatcher_attrib(
      repr=False,
      init=False,
      default=attr.Factory(
          lambda self: self._get_matcher(),  # pylint: disable=protected-access
          takes_self=True),
  )  # type: matcher.Matcher

  def _get_matcher(self):
    """Compiles ``pattern`` and ``restrictions`` into a matcher.

    Raises:
      ValueError: If rewriting or parsing the pattern raises ``SyntaxError``.
    """
    try:
      rewritten = _rewrite_submatchers(self.pattern, self.restrictions)
      remapped_pattern, variable_names, variables = rewritten
      module_ast = ast.parse(remapped_pattern)
    except SyntaxError as e:
      raise ValueError('Failed to parse %r: %s' % (self.pattern, e))

    _verify_variables(module_ast, variable_names)
    target_ast = self._pull_ast(module_ast)
    pattern_matcher = _ast_pattern(target_ast, variables)
    return base_matchers.Rebind(
        pattern_matcher,
        on_conflict=matcher.BindConflict.MERGE,
        on_merge=matcher.BindMerge.KEEP_LAST,
    )

  @abc.abstractmethod
  def _pull_ast(self, module_ast):
    """Given an ast.Module, returns the AST to match precisely."""
    raise NotImplementedError  # not MI friendly, but whatever.

  def _match(self, context, candidate):
    """Delegates to the matcher compiled from the pattern."""
    return self._ast_matcher.match(context, candidate)

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the compiled matcher."""
    return self._ast_matcher.type_filter
class _Recurse(matcher.Matcher):
  """Recursion barrier for RecursivelyWrapped which avoids infinite loops.

  Args:
    recurse_to: The matcher that every match and ``type_filter`` lookup is
      forwarded to.
  """

  # Deliberately excluded from equality/ordering checks, since it will only
  # ever point to the RecursivelyWrapped node at a similar location. The
  # assumption is that instances are only ever created from a
  # RecursivelyWrapped, and the tie to their parent is "hidden".
  # It is also excluded from .bind_variables walking (walk=False) to avoid
  # infinite recursion.
  _recurse_to = matcher.submatcher_attrib(eq=False, order=False, walk=False)

  def _match(self, *args, **kwargs):
    """Forwards all arguments unchanged to the target's ``match``."""
    target = self._recurse_to
    return target.match(*args, **kwargs)

  def __repr__(self):
    # The target is elided from the repr; only the class name is shown.
    class_name = type(self).__name__
    return '%s(...)' % class_name

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the target matcher."""
    return self._recurse_to.type_filter
class NoComments(matcher.Matcher):
  """Filter results to only those lexical spans that have no comments inside.

  Args:
    submatcher: A Matcher matching a LexicalMatch.
  """

  _submatcher = matcher.submatcher_attrib()  # type: matcher.Matcher

  def _match(self, context, candidate):
    """Returns the submatcher's result unless it is flagged as commented.

    The decision is made by ``_result_has_comments``, which is also called
    when the submatcher returned ``None``.
    """
    result = self._submatcher.match(context, candidate)
    has_comments = _result_has_comments(context, self._submatcher, result)
    return None if has_comments else result

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the wrapped submatcher."""
    return self._submatcher.type_filter
class HasChild(matcher.Matcher):
  """Matches an AST node if a direct child matches the submatcher.

  An AST node in this context is considered to be an AST object or a list
  object. Only direct children are yielded -- ``AST.member`` or
  ``list[index]``. There is no recursive traversal of any kind.

  Fails the match if the candidate node is not an AST object or list.

  Args:
    submatcher: The matcher to try against each direct child.
  """

  _submatcher = matcher.submatcher_attrib()

  def _match(self, context, candidate):
    """Matches ``candidate`` using the first child the submatcher accepts.

    Returns:
      A ``MatchInfo`` for ``candidate`` carrying that child's bindings, or
      ``None`` if no child matches.
    """
    for child in _ast_children(candidate):
      child_info = self._submatcher.match(context, child)
      if child_info is not None:
        # Stop at the first matching child; later children are not tried.
        return matcher.MatchInfo(
            matcher.create_match(context.parsed_file, candidate),
            child_info.bindings)
    return None
class Contains(matcher.Matcher):
  """Matches a collection if any item matches the given matcher.

  Fails the match if the candidate is not iterable.

  Args:
    submatcher: The matcher to try against each item.
  """

  _submatcher = matcher.submatcher_attrib()

  def _match(self, context, candidate):
    """Matches ``candidate`` using the first item the submatcher accepts.

    Returns:
      A ``MatchInfo`` for ``candidate`` carrying that item's bindings, or
      ``None`` if ``candidate`` is not iterable or no item matches.
    """
    try:
      items = iter(candidate)
    except TypeError:
      # Not iterable, so it cannot contain anything.
      return None

    for item in items:
      item_info = self._submatcher.match(context, item)
      if item_info is None:
        continue
      return matcher.MatchInfo(
          matcher.create_match(context.parsed_file, candidate),
          item_info.bindings)
    return None
class Rebind(matcher.Matcher):
  """Change the binding settings for all bindings in a submatcher.

  For example, one might want bindings in one part of the AST matcher to merge
  with each other, but then want it to be an error if these conflict anywhere
  else.

  Args:
    submatcher: The matcher whose bindings to rewrite.
    on_conflict: A conflict resolution strategy. Must be a member of
      :class:`matcher.BindConflict <refex.python.matcher.BindConflict>`, or
      ``None`` if ``on_conflict`` is not to be changed.
    on_merge: A merge strategy. Must be a member of
      :class:`matcher.BindMerge <refex.python.matcher.BindMerge>`, or ``None``
      if ``on_merge`` is not to be changed.
  """

  _submatcher = matcher.submatcher_attrib(default=Anything())
  _on_conflict = attr.ib(
      default=None,
      validator=attr.validators.in_(
          frozenset(matcher.BindConflict) | {None}))
  _on_merge = attr.ib(
      default=None,
      validator=attr.validators.in_(frozenset(matcher.BindMerge) | {None}))

  def _match(self, context, candidate):
    """Returns the submatcher's result with every binding rebound."""
    result = self._submatcher.match(context, candidate)
    if result is None:
      return None

    rebound = {}
    for metavar, bound_value in result.bindings.items():
      rebound[metavar] = bound_value.rebind(
          on_conflict=self._on_conflict, on_merge=self._on_merge)
    return attr.evolve(result, bindings=rebound)

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the wrapped submatcher."""
    return self._submatcher.type_filter
class HasParent(matcher.Matcher):
  """Matches an AST node if its direct parent matches the submatcher.

  An AST node in this context is considered to be an AST object or a list
  object. Only direct parents are yielded -- the exact object x s.t. the
  candidate is ``x.y`` or ``x[y]``, for some ``y``. There is no recursive
  traversal of any kind.

  Fails the match if the candidate node is not an AST object or list.

  Args:
    submatcher: The matcher to run against the parent.
  """

  _submatcher = matcher.submatcher_attrib()

  def _match(self, context, candidate):
    """Matches ``candidate`` if its parent exists and matches.

    Returns:
      A ``MatchInfo`` for ``candidate`` carrying the parent match's bindings,
      or ``None``.
    """
    parent = context.parsed_file.nav.get_parent(candidate)
    parent_info = (
        None if parent is None else self._submatcher.match(context, parent))
    if parent_info is None:
      return None
    return matcher.MatchInfo(
        matcher.create_match(context.parsed_file, candidate),
        parent_info.bindings)
class HasItem(matcher.Matcher):
  """Matches a container iff ``submatcher`` matches ``container[index]``.

  Fails the match if the container doesn't contain the key, or if the
  candidate node is not a container at all.

  Args:
    index: The key or index to look up in the candidate.
    submatcher: The matcher for the looked-up item. Defaults to
      ``Anything()``.
  """

  _index = attr.ib()
  _submatcher = matcher.submatcher_attrib(default=Anything())

  def _match(self, context, candidate):
    """Matches ``candidate`` if ``candidate[index]`` exists and matches.

    Returns:
      A ``MatchInfo`` for ``candidate`` carrying the item match's bindings, or
      ``None``.
    """
    try:
      item = candidate[self._index]
    except (LookupError, TypeError):
      # Missing key/index, or the candidate is not subscriptable.
      return None

    item_info = self._submatcher.match(context, item)
    if item_info is None:
      return None
    return matcher.MatchInfo(
        matcher.create_match(context.parsed_file, candidate),
        item_info.bindings)
class WithTopLevelImport(matcher.Matcher):
  """Matches an AST node if there is a top level import for the given module.

  Args:
    submatcher: The matcher to filter results from.
    module_name: The fully-qualified module name as a string. e.g.
      ``'os.path'``.
    as_name: The variable name the module is imported as. Defaults to the name
      one would get from e.g. ``from os import path``.
  """

  # TODO: Would be nice to match on function-local imports as well.
  # TODO: Would be nice to use submatchers for module_name/as_name.
  _submatcher = matcher.submatcher_attrib()
  _module_name = attr.ib()
  _as_name = attr.ib()

  @_as_name.default
  def _as_name_default(self):
    # The last dotted component: 'os.path' -> 'path'.
    return self._module_name.rsplit('.', 1)[-1]

  # Per-AST state, shared by all instances: maps a tree to the result of
  # _top_level_imports(tree). Keys are held weakly.
  _ast_imports = weakref.WeakKeyDictionary()

  @classmethod
  def _get_ast_imports(cls, tree):
    """Returns the top-level imports of ``tree``, computing them only once."""
    cache = cls._ast_imports
    if tree not in cache:
      cache[tree] = _top_level_imports(tree)
    return cache[tree]

  def _match(self, context, candidate):
    """Runs the submatcher only if the module is imported under ``as_name``."""
    imports = self._get_ast_imports(context.parsed_file.tree)
    if self._module_name not in imports:
      return None
    if imports[self._module_name] == self._as_name:
      return self._submatcher.match(context, candidate)
    return None

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the wrapped submatcher."""
    return self._submatcher.type_filter
class _SubmatcherAttribsClass(object):
  """Holder class declaring one attribute of each submatcher attrib kind."""

  # A single submatcher, defaulting to Anything().
  submatcher = matcher.submatcher_attrib(
      default=base_matchers.Anything(),
  )
  # A sequence of submatchers, defaulting to a one-element tuple.
  submatcher_list = matcher.submatcher_list_attrib(
      default=(base_matchers.Anything(),),
  )
class Bind(matcher.Matcher):
  """Binds an AST-matcher expression to a name in the result.

  Args:
    name: The name to bind to. Valid names must be words that don't begin with
      a double-underscore (``__``).
    submatcher: The matcher whose result will be bound to ``name``.
    on_conflict: A conflict resolution strategy. Must be a member of
      :class:`matcher.BindConflict <refex.python.matcher.BindConflict>`, or
      ``None`` for the default strategy (``ACCEPT``).
    on_merge: A merge strategy. Must be a member of
      :class:`matcher.BindMerge <refex.python.matcher.BindMerge>`, or None for
      the default strategy (``KEEP_LAST``).
  """

  # An identifier-like word that does not start with a double underscore.
  _NAME_REGEX = re.compile(r'\A(?!__)[a-zA-Z_]\w*\Z')

  name = attr.ib()
  _submatcher = matcher.submatcher_attrib(default=Anything())
  _on_conflict = attr.ib(
      default=None,
      validator=attr.validators.in_(
          frozenset(matcher.BindConflict) | {None}))
  _on_merge = attr.ib(
      default=None,
      validator=attr.validators.in_(frozenset(matcher.BindMerge) | {None}))

  @name.validator
  def _name_validator(self, attribute, value):
    """Raises ``ValueError`` unless ``value`` matches ``_NAME_REGEX``."""
    if self._NAME_REGEX.match(value) is not None:
      return
    raise ValueError(
        "invalid bind name: {value!r} doesn't match {regex}".format(
            value=value, regex=self._NAME_REGEX))

  def _match(self, context, candidate):
    """Returns the submatcher's match, with a binding introduced by this Bind.

    Args:
      context: The match context.
      candidate: The candidate node to be matched.

    Returns:
      An extended :class:`~refex.python.matcher.MatchInfo` with the new binding
      specified in the constructor. Conflicts are merged according to
      ``on_conflict``. If there was no match, or on_conflict result in a skip,
      then this returns ``None``.

      See matcher.merge_bindings for more details.
    """
    result = self._submatcher.match(context, candidate)
    if result is None:
      return None

    own_binding = {
        self.name:
            matcher.BoundValue(
                result.match,
                on_conflict=self._on_conflict,
                on_merge=self._on_merge),
    }
    merged = matcher.merge_bindings(result.bindings, own_binding)
    return None if merged is None else attr.evolve(result, bindings=merged)

  @cached_property.cached_property
  def bind_variables(self):
    """This bind's own name plus the submatcher's bind variables."""
    return frozenset([self.name]) | self._submatcher.bind_variables

  @cached_property.cached_property
  def type_filter(self):
    """The type filter of the wrapped submatcher."""
    return self._submatcher.type_filter
class StmtFromFunctionPattern(matcher.Matcher):
  """A StmtPattern, but using a function to define the syntax.

  Instead of using metavars with `$`, they must be defined in the function
  arguments. So for example::

    def before(foo):
      foo.bar = 5
    matcher = StmtFromFunctionPattern(before)

  is equivalent to::

    StmtPattern('$foo.bar = 5')

  This makes it much more obvious that patterns like the following will not
  work as expected::

    def before(x):
      import x

  FunctionPatterns may, optionally, include a docstring describing what the
  pattern should match. This will be ignored by the matcher.

  The name of the function is arbitrary, but metavar names must be defined in
  the function arguments.

  FunctionPatterns are resolved using `inspect.getsource`. This leads to a few
  limitations, importantly the functions used cannot be lambdas, and the
  matcher will fail (with weird errors) if you attempt to define and use a
  FromFunction matcher in an interactive session or other situations where
  source code isn't accessible.

  Args:
    func: The function whose first body statement (after an optional
      docstring) is the pattern to match.
  """

  func = attr.ib()  # type: Callable

  # Derived from `func` at construction time; it is not an __init__ argument
  # and is left out of the repr.
  _ast_matcher = matcher.submatcher_attrib(
      repr=False,
      init=False,
      default=attr.Factory(
          lambda self: self._get_matcher(),  # pylint: disable=protected-access
          takes_self=True),
  )  # type: matcher.Matcher

  def _get_matcher(self):
    """Override of get_matcher to pull things from a function object.

    Returns:
      A matcher for the first non-docstring statement of ``func``, in which
      every argument name of ``func`` is a bound metavariable.

    Raises:
      ValueError: If the source of ``func`` does not parse, if its body is a
        docstring alone, or if its first real statement is ``pass``.
    """
    # `inspect.getsource` doesn't, say, introspect the code object for its
    # source. Python, despite its dynamism, doesn't support that much magic.
    # Instead, it gets the file and line number information from the code
    # object, and returns those lines as the source. This leads to a few
    # interesting consequences:
    #   - Functions that exist within a class or closure are by default
    #     `IndentationError`, the code block must be textwrap-dedented before
    #     being used.
    #   - This won't work in interactive modes (shell, ipython, etc.)
    #   - Functions are normally statements, so treating everything from the
    #     first line to the last as part of the function is probably fine.
    #     There are a few cases where this will break, namely
    #     - A lambda will likely be a syntax error, the tool will see
    #       `lambda x: x)`, where `)` is the closing paren of the enclosing
    #       scope.
    source = textwrap.dedent(inspect.getsource(self.func))
    args = _args(self.func)
    try:
      parsed = ast.parse(source)
    except SyntaxError:
      raise ValueError(
          'Function {} appears to have invalid syntax. Is it a'
          ' lambda?'.format(self.func.__name__))

    actual_body = parsed.body[0].body

    # Detect a leading docstring without touching `ast.Str`: that alias is
    # deprecated since Python 3.8 and absent from newer releases. Since 3.8
    # the parser emits `ast.Constant` for string literals; older parsers emit
    # a node whose class is named `Str`, which is recognised by name so the
    # deprecated attribute is never looked up.
    first_stmt = actual_body[0]
    has_docstring = False
    if isinstance(first_stmt, ast.Expr):
      value = first_stmt.value
      constant_type = getattr(ast, 'Constant', None)
      if constant_type is not None and isinstance(value, constant_type):
        has_docstring = isinstance(value.value, str)
      else:
        has_docstring = type(value).__name__ == 'Str'

    if has_docstring:
      # Strip the docstring, if it exists.
      actual_body = actual_body[1:]
    if not actual_body:
      raise ValueError('Format function must include an actual body, a '
                       'docstring alone is invalid.')
    if isinstance(actual_body[0], ast.Pass):
      raise ValueError(
          'If you *really* want to rewrite a function whose body '
          'is just `pass`, use a regex replacer.')

    # Since we don't need to mangle names, we just generate bindings.
    bindings = {}
    for name in args:
      bindings[name] = base_matchers.Bind(
          name,
          base_matchers.Anything(),
          on_conflict=matcher.BindConflict.MERGE_EQUIVALENT_AST)

    # Only the first remaining statement is used as the pattern.
    return base_matchers.Rebind(
        _ast_pattern(actual_body[0], bindings),
        on_conflict=matcher.BindConflict.MERGE,
        on_merge=matcher.BindMerge.KEEP_LAST,
    )

  def _match(self, context, candidate):
    """Delegates to the matcher compiled from ``func``."""
    return self._ast_matcher.match(context, candidate)