class NamedFunctionDefinition(matcher.Matcher):
  """A matcher for a named function definition.

  This includes both regular functions and async functions.

  Args:
    body: The matcher for the function body.
    returns: The matcher for the return type annotation.
  """

  _body = matcher.submatcher_attrib(default=base_matchers.Anything())
  _returns = matcher.submatcher_attrib(default=base_matchers.Anything())

  @cached_property.cached_property
  def _matcher(self):
    # Probe for the `returns` AST field instead of checking the Python
    # version, so that backports of the annotation syntax to Python 2 are
    # still supported.
    shared_kwargs = {'body': self._body}
    if 'returns' in attr.fields_dict(ast_matchers.FunctionDef):
      shared_kwargs['returns'] = self._returns
    plain_def = ast_matchers.FunctionDef(**shared_kwargs)
    if not six.PY3:
      return plain_def
    return base_matchers.AnyOf(
        ast_matchers.AsyncFunctionDef(**shared_kwargs),
        plain_def,
    )

  def _match(self, context, candidate):
    return self._matcher.match(context, candidate)

  @cached_property.cached_property
  def type_filter(self):
    return self._matcher.type_filter
def test_explicit_anything(self):
  parsed, e = expression('~a')
  # Explicitly-passed Anything() submatchers behave like the defaults.
  actual = ast_matchers.UnaryOp(
      op=base_matchers.Anything(),
      operand=base_matchers.Anything()).match(
          matcher.MatchContext(parsed), e)
  expected = matcher.MatchInfo(
      matcher.LexicalASTMatch(e, parsed.text, e.first_token, e.last_token))
  self.assertEqual(actual, expected)
def test_key(self):
  shared = 'shared'
  once_any = base_matchers.Once(base_matchers.Anything(), key=shared)
  once_never = base_matchers.Once(
      base_matchers.Unless(base_matchers.Anything()), key=shared)
  # Because they share a key, once_never reuses the cached (successful)
  # result of once_any, so the combined match succeeds.
  combined = base_matchers.ItemsAre([once_any, once_never])
  self.assertIsNotNone(combined.match(_FAKE_CONTEXT.new(), [1, 2]))
def test_nomatch(self):
  source = 'a\nb\n'
  # Once() wrapping a never-matching submatcher yields no matches at all.
  never = base_matchers.Once(base_matchers.Unless(base_matchers.Anything()))
  self.assertEqual(self.get_all_match_strings(never, source), [])
def test_ancestor(self):
  """The matcher won't traverse into child nodes."""
  parsed = matcher.parse_ast('~a', '<string>')
  result = ast_matchers.UnaryOp(
      op=base_matchers.Unless(base_matchers.Anything())).match(
          matcher.MatchContext(parsed), parsed.tree.body[0])
  self.assertIsNone(result)
def test_bind_2arg(self):
  # A two-argument Bind records the submatcher's result under the name.
  result = base_matchers.Bind('foo', base_matchers.Anything()).match(
      _FAKE_CONTEXT, 1)
  self.assertEqual(
      result,
      matcher.MatchInfo(
          match.ObjectMatch(1),
          {'foo': matcher.BoundValue(match.ObjectMatch(1))}))
class Str(matcher.Matcher):
  """Matches a string constant via `_constant_match`."""

  s = matcher.submatcher_attrib(default=base_matchers.Anything())

  def _match(self, context, candidate):
    return _constant_match(context, candidate, self.s, (str,))

  type_filter = frozenset({ast.Constant})
def test_renamed_import(self):
  anything = base_matchers.Anything()
  context = matcher.MatchContext(
      matcher.parse_ast('import os as renamed'))
  # The renamed form must be requested explicitly to match.
  self.assertIsNotNone(
      syntax_matchers.WithTopLevelImport(anything, 'os', 'renamed').match(
          context, 1))
  self.assertIsNone(
      syntax_matchers.WithTopLevelImport(anything, 'os').match(context, 1))
def test_renamed_fromimport(self, import_stmt):
  anything = base_matchers.Anything()
  context = matcher.MatchContext(matcher.parse_ast(import_stmt))
  # Only the matcher that names the alias should succeed.
  self.assertIsNotNone(
      syntax_matchers.WithTopLevelImport(
          anything, 'os.path', 'renamed').match(context, 1))
  self.assertIsNone(
      syntax_matchers.WithTopLevelImport(
          anything, 'os.path').match(context, 1))
class NameConstant(matcher.Matcher):
  """Matches a True/False/None constant via `_constant_match`."""

  value = matcher.submatcher_attrib(default=base_matchers.Anything())

  def _match(self, context, candidate):
    return _constant_match(
        context, candidate, self.value, (bool, type(None)))

  type_filter = frozenset({ast.Constant})
class Num(matcher.Matcher):
  """Matches a numeric constant via `_constant_match`."""

  n = matcher.submatcher_attrib(default=base_matchers.Anything())

  def _match(self, context, candidate):
    return _constant_match(
        context, candidate, self.n, (int, float, complex))

  type_filter = frozenset({ast.Constant})
def test_inner_nomatches(self):
  parsed = matcher.parse_ast('xy = 2', '<string>')
  # The regex matches everywhere, but the inner submatcher never does.
  never = base_matchers.Unless(base_matchers.Anything())
  found = list(
      matcher.find_iter(base_matchers.MatchesRegex(r'', never), parsed))
  self.assertEqual(found, [])
def test_type_filter_ordered(self):
  """Tests that type optimizations don't mess with matcher order."""
  any_of = base_matchers.AnyOf(
      base_matchers.Bind('a', base_matchers.Anything()),
      base_matchers.Bind('b', base_matchers.TypeIs(int)),
  )
  # 'a' comes first, so it must win even though 'b' also matches an int.
  self.assertEqual(any_of.match(_FAKE_CONTEXT, 4).bindings.keys(), {'a'})
def test_has_nested_import(self, import_stmt):
  anything = base_matchers.Anything()
  context = matcher.MatchContext(matcher.parse_ast(import_stmt))
  # Both the plain and the explicitly-aliased forms should match.
  for m in [
      syntax_matchers.WithTopLevelImport(anything, 'os.path'),
      syntax_matchers.WithTopLevelImport(anything, 'os.path', 'path'),
  ]:
    with self.subTest(m=m):
      self.assertIsNotNone(m.match(context, 1))
class Subscript(Subscript):  # pylint: disable=undefined-variable
  """Subscript matcher whose `slice` submatcher rejects bare Bind()."""

  slice = matcher.submatcher_attrib(default=base_matchers.Anything())

  @slice.validator
  def _slice_validator(self, attribute, value):
    # `attribute` is required by the attrs validator protocol but unused.
    del attribute
    if isinstance(value, base_matchers.Bind):
      raise ValueError(
          'slice=Bind(...) not supported in Python < 3.9. It will fail to '
          'correctly match e.g. `a[:]` or `a[1,:]`. Upgrade to Python 3.9, or'
          ' work around this using AllOf(Bind(...)) if that is OK.')
def _get_matcher(self):
  """Override of get_matcher to pull things from a function object."""
  # `inspect.getsource` doesn't introspect the code object for its source;
  # despite Python's dynamism, it only reads the file and line information
  # from the code object and returns those lines. Interesting consequences:
  # - Functions defined in a class or closure come back indented (which would
  #   be an `IndentationError`), so the text must be textwrap-dedented first.
  # - This won't work in interactive modes (shell, ipython, etc.)
  # - Functions are normally statements, so treating everything from the
  #   first line to the last as part of the function is probably fine. A few
  #   cases break, e.g. a lambda will likely be a syntax error: the tool
  #   sees `lambda x: x)`, where `)` closes the enclosing scope.
  source = textwrap.dedent(inspect.getsource(self.func))
  args = _args(self.func)
  try:
    parsed = ast.parse(source)
  except SyntaxError:
    raise ValueError(
        'Function {} appears to have invalid syntax. Is it a'
        ' lambda?'.format(self.func.__name__))
  actual_body = parsed.body[0].body
  # Strip a leading docstring, if one exists.
  if (isinstance(actual_body[0], ast.Expr) and
      isinstance(actual_body[0].value, ast.Str)):
    actual_body = actual_body[1:]
  if not actual_body:
    raise ValueError('Format function must include an actual body, a '
                     'docstring alone is invalid.')
  if isinstance(actual_body[0], ast.Pass):
    raise ValueError(
        'If you *really* want to rewrite a function whose body '
        'is just `pass`, use a regex replacer.')
  # Since we don't need to mangle names, we just generate bindings.
  bindings = {
      name: base_matchers.Bind(
          name,
          base_matchers.Anything(),
          on_conflict=matcher.BindConflict.MERGE_EQUIVALENT_AST)
      for name in args
  }
  return base_matchers.Rebind(
      _ast_pattern(actual_body[0], bindings),
      on_conflict=matcher.BindConflict.MERGE,
      on_merge=matcher.BindMerge.KEEP_LAST,
  )
def _generate_syntax_matcher(cls, ast_node_type):
  """Builds a matcher class with one keyword-only attrib per AST field."""
  field_attribs = {
      field: matcher.submatcher_attrib(default=base_matchers.Anything())
      for field in ast_node_type._fields
  }
  generated = attr.make_class(
      ast_node_type.__name__,
      field_attribs,
      bases=(cls,),
      frozen=True,
      kw_only=True,
  )
  generated._ast_type = ast_node_type  # pylint: disable=protected-access
  generated.type_filter = frozenset({ast_node_type})
  return generated
def _matchers_for_matches(matches):
  """Returns AST matchers for all expressions in `matches`.

  Args:
    matches: A mapping of variable name -> match

  Returns:
    A mapping of <variable name> -> <matcher that must match>. Only variables
    that can be parenthesized are in this mapping, and the matcher must match
    where those variables are substituted in.
  """
  result = {}
  for name, m in matches.items():
    is_lexical_expr = (
        isinstance(m, matcher.LexicalASTMatch) and
        isinstance(m.matched, ast.expr))
    if is_lexical_expr:
      result[name] = syntax_matchers.ast_matchers_matcher(m.matched)
    else:
      # As a fallback, treat it as a black box, and assume that the rest of
      # the expression will catch things.
      result[name] = base_matchers.Anything()
  return result
def _rewrite_submatchers(pattern, restrictions):
  """Rewrites pattern/restrictions to erase metasyntactic variables.

  Args:
    pattern: a pattern containing $variables.
    restrictions: a dictionary of variables to submatchers. If a variable is
      missing, Anything() is used instead.

  Returns:
    (remapped_pattern, variables, new_submatchers)
    * remapped_pattern has all variables replaced with new unique names that
      are valid Python syntax.
    * variables is the mapping of the original name to the remapped name.
    * new_submatchers is a dict from remapped names to submatchers. Every
      variable is put in a Bind() node, which has a submatcher taken from
      `restrictions`.

  Raises:
    KeyError: if restrictions has a key that isn't a variable name.
  """
  pattern, variables = _remap_macro_variables(pattern)
  unknown_variables = set(restrictions) - set(variables)
  if unknown_variables:
    raise KeyError('Some variables specified in restrictions were missing. '
                   'Did you misplace a "$"? Missing variables: %r' %
                   unknown_variables)
  submatchers = {
      new_name: base_matchers.Bind(
          old_name,
          restrictions.get(old_name, base_matchers.Anything()),
          on_conflict=matcher.BindConflict.MERGE_EQUIVALENT_AST,
      )
      for old_name, new_name in variables.items()
  }
  return pattern, variables, submatchers
def test_missing_import(self, import_stmt):
  # No matching top-level import is present, so the match must fail.
  m = syntax_matchers.WithTopLevelImport(
      base_matchers.Anything(), 'os.path')
  context = matcher.MatchContext(matcher.parse_ast(import_stmt))
  self.assertIsNone(m.match(context, 1))
import ast
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from six.moves import range

from refex import match
from refex.python import evaluate
from refex.python import matcher
from refex.python import matcher_test_util
from refex.python.matchers import ast_matchers
from refex.python.matchers import base_matchers

# A matcher that never matches anything.
_NOTHING = base_matchers.Unless(base_matchers.Anything())
# A minimal context over an empty parsed module, for matcher unit tests.
_FAKE_CONTEXT = matcher.MatchContext(matcher.parse_ast('', 'foo.py'))


class BindTest(absltest.TestCase):
  """Tests name validation for Bind() and SystemBind()."""

  def test_bind_name_invalid(self):
    # Dunder-prefixed names are reserved for SystemBind.
    with self.assertRaises(ValueError):
      base_matchers.Bind('__foo')

  def test_systembind_name_valid(self):
    base_matchers.SystemBind('__foo')

  def test_systembind_name_invalid(self):
    # SystemBind requires the dunder prefix that Bind forbids.
    with self.assertRaises(ValueError):
      base_matchers.SystemBind('foo')
def test_contains_wrongtype(self):
  """It's useful to run a Contains() check against arbitrary objects."""
  contains_anything = base_matchers.Contains(base_matchers.Anything())
  self.assertIsNone(contains_anything.match(_FAKE_CONTEXT, object()))
class _SubmatcherAttribsClass(object):
  """Fixture exercising both the scalar and list submatcher attribs."""

  submatcher = matcher.submatcher_attrib(default=base_matchers.Anything())
  submatcher_list = matcher.submatcher_list_attrib(
      default=(base_matchers.Anything(),))
def test_anything(self):
  # Anything() matches any candidate and records it as an ObjectMatch.
  result = base_matchers.Anything().match(_FAKE_CONTEXT, 1)
  self.assertEqual(result, matcher.MatchInfo(match.ObjectMatch(1)))
def test_submatcher_fail(self):
  parsed, e = expression('~a')
  # A never-matching `op` submatcher makes the whole UnaryOp match fail.
  failing = ast_matchers.UnaryOp(
      op=base_matchers.Unless(base_matchers.Anything()))
  self.assertIsNone(failing.match(matcher.MatchContext(parsed), e))
def test_submatcher_wrong(self):
  # The item count is right, but the submatcher rejects the item.
  never = base_matchers.Unless(base_matchers.Anything())
  self.assertIsNone(
      base_matchers.ItemsAre([never]).match(_FAKE_CONTEXT, [1]))
def test_too_short(self):
  # An empty candidate can't satisfy a one-item ItemsAre.
  one_item = base_matchers.ItemsAre([base_matchers.Anything()])
  self.assertIsNone(one_item.match(_FAKE_CONTEXT, []))
def test_wrongtype(self):
  # object() is not subscriptable, so HasItem fails cleanly.
  has_item = base_matchers.HasItem(0, base_matchers.Anything())
  self.assertIsNone(has_item.match(_FAKE_CONTEXT, object()))
def test_submatch_rejects(self):
  # The item at index -1 exists, but the submatcher rejects it.
  never = base_matchers.Unless(base_matchers.Anything())
  self.assertIsNone(
      base_matchers.HasItem(-1, never).match(_FAKE_CONTEXT, ['xyz']))
def test_unless_bindings(self):
  # Unless() exposes no bindings from its (inverted) submatcher.
  unless_bind = base_matchers.Unless(
      base_matchers.Bind('name', base_matchers.Anything()))
  self.assertEqual(unless_bind.bind_variables, set())