def test_no_sibling(self, expr):
    """No Name in the parameterized `expr` has a Name as its next sibling."""
    name_then_name = base_matchers.AllOf(
        ast_matchers.Name(),
        syntax_matchers.HasNextSibling(ast_matchers.Name()),
    )
    found = self.get_all_match_strings(name_then_name, expr)
    self.assertEmpty(found)
def test_compound_sibling(self):
    """Only the first BinOp in the list literal has a BinOp next sibling."""
    binop_then_binop = base_matchers.AllOf(
        ast_matchers.BinOp(),
        syntax_matchers.HasNextSibling(ast_matchers.BinOp()),
    )
    found = self.get_all_match_strings(binop_then_binop, '[a+c, a+b]')
    self.assertEqual(found, ['a+c'])
def test_expression_sibling(self, expr):
    """In each parameterized `expr`, Names `a` and `b` have a Name next sibling."""
    name_then_name = base_matchers.AllOf(
        ast_matchers.Name(),
        syntax_matchers.HasNextSibling(ast_matchers.Name()),
    )
    found = self.get_all_match_strings(name_then_name, expr)
    self.assertEqual(found, ['a', 'b'])
def test_has_hasparent(self, matcher_type):
    """Children of the outer BinOp match; the nested BinOp is not re-entered."""
    under_binop = base_matchers.AllOf(
        # Add can't stringify.
        base_matchers.Unless(ast_matchers.Add()),
        matcher_type(ast_matchers.BinOp()),
    )
    # even for IsOrHasAncestor, find_iter doesn't recurse into a matched node.
    found = self.get_all_match_strings(under_binop, '1 + (2 + 3)')
    self.assertEqual(found, ['1', '2 + 3'])
def test_type_filter_nonempty_disjoint(self):
    """AllOf over two type filters with no type in common has an empty filter.

    The expected value is an empty frozenset, which is distinct from the
    `None` filter of an AllOf with no submatchers.
    """
    combined = base_matchers.AllOf(
        base_matchers.TypeIs(int),
        base_matchers.TypeIs(float),
    )
    # `frozenset()` is the empty set; the previous `frozenset({})` only worked
    # because iterating an empty *dict* literal yields nothing.
    self.assertEqual(combined.type_filter, frozenset())
def test_matches(self):
    """A Num matches when the file text contains the regex `hello`."""
    module = matcher.parse_ast('var_hello = 42', '<string>')
    num_in_hello_file = base_matchers.AllOf(
        base_matchers.FileMatchesRegex('hello'),
        ast_matchers.Num(),
    )
    results = list(matcher.find_iter(num_in_hello_file, module))
    self.assertLen(results, 1)
def test_doesnt_match(self):
    """Nothing matches when the file text lacks the regex `hello`."""
    module = matcher.parse_ast('hi = 42', '<string>')
    num_in_hello_file = base_matchers.AllOf(
        base_matchers.FileMatchesRegex('hello'),
        ast_matchers.Num(),
    )
    results = list(matcher.find_iter(num_in_hello_file, module))
    self.assertEqual(results, [])
def test_type_filter_skipped_micro(self):
    """Matchers are skipped if they do not match the type filter."""
    # This branch would raise if it were run against the int candidate below.
    float_only_branch = base_matchers.AllOf(
        base_matchers.TestOnlyRaise('this should be skipped'),
        base_matchers.TypeIs(float),
    )
    either = base_matchers.AnyOf(float_only_branch, base_matchers.TypeIs(int))
    self.assertIsNotNone(either.match(_FAKE_CONTEXT, 4))
def test_multi_overlap(self):
    """Binding the same name twice in one AllOf yields a single binding."""
    # TODO: it'd be nice to give a good error at some point, instead.
    doubled = base_matchers.AllOf(
        base_matchers.Bind('foo'),
        base_matchers.Bind('foo'),
    )
    actual = doubled.match(_FAKE_CONTEXT, 1)
    self.assertEqual(
        actual,
        matcher.MatchInfo(
            match.ObjectMatch(1),
            {'foo': matcher.BoundValue(match.ObjectMatch(1))},
        ),
    )
def test_cached(self):
    """A Once() submatcher inside AllOf still lets every line match."""
    source = 'b\na\nb\n'
    any_line_with_once_a = base_matchers.AllOf(
        base_matchers.MatchesRegex(r'a|b'),
        base_matchers.Once(base_matchers.MatchesRegex(r'a')),
    )
    found = self.get_all_match_strings(any_line_with_once_a, source)
    self.assertEqual(found, ['a', 'b'])
def test_multi_bind(self):
    """Distinct Bind names in one AllOf all bind the same matched object."""
    two_names = base_matchers.AllOf(
        base_matchers.Bind('foo'),
        base_matchers.Bind('bar'),
    )
    actual = two_names.match(_FAKE_CONTEXT, 1)
    expected_bindings = {
        'foo': matcher.BoundValue(match.ObjectMatch(1)),
        'bar': matcher.BoundValue(match.ObjectMatch(1)),
    }
    self.assertEqual(
        actual, matcher.MatchInfo(match.ObjectMatch(1), expected_bindings))
def test_bindings(self):
    """A named group in FileMatchesRegex is exposed as a binding."""
    module = matcher.parse_ast('x = 2', '<string>')
    num_with_var = base_matchers.AllOf(
        base_matchers.FileMatchesRegex(r'(?P<var>x)'),
        ast_matchers.Num(),
    )
    results = list(matcher.find_iter(num_with_var, module))
    self.assertLen(results, 1)
    (only_result,) = results
    self.assertIn('var', only_result.bindings)
    # The `var` group covers the `x` at offsets [0, 1) of the file.
    self.assertEqual(only_result.bindings['var'].value.span, (0, 1))
def test_variable_conflict(self):
    """Variables use the default conflict resolution outside of the pattern.

    Inside of the pattern, they use MERGE_EQUIVALENT_AST, but this is opaque to
    callers.
    """
    # The AllOf shouldn't make a difference, because the $x variable is just
    # a regular Bind() variable outside of the pattern, and merges via KEEP_LAST
    # per normal.
    pattern_and_bind = base_matchers.AllOf(
        syntax_matchers.StmtPattern('$x'),
        base_matchers.Bind('x'),
    )
    found = self.get_all_match_strings(pattern_and_bind, '1')
    self.assertEqual(found, ['1'])
def test_multi_regex(self):
    """Tests that the lazy dictionary doesn't walk over itself or something."""
    module = matcher.parse_ast('var_hello = 42', '<string>')
    file_regexes = [
        base_matchers.FileMatchesRegex('var'),
        base_matchers.FileMatchesRegex('hello'),
        base_matchers.FileMatchesRegex('42'),
        base_matchers.FileMatchesRegex(r'\Avar_hello = 42\Z'),
    ]
    every_regex_and_num = base_matchers.AllOf(*file_regexes, ast_matchers.Num())
    results = list(matcher.find_iter(every_regex_and_num, module))
    self.assertLen(results, 1)
def test_matches_only_immediate_siblings(self):
    """Only an Assign whose *immediate* next sibling is a ClassDef matches."""
    assign_before_class = base_matchers.AllOf(
        ast_matchers.Assign(),
        syntax_matchers.HasNextSibling(ast_matchers.ClassDef()),
    )
    source = textwrap.dedent("""\
        class Before:
          pass
        a = 1
        b = 2
        class After:
          pass
        if a:
          c = 3
        else:
          class AlsoBefore:
            pass
          d = 4
        """)
    found = self.get_all_match_strings(assign_before_class, source)
    self.assertEqual(found, ['b = 2'])
def test_in_named_function(self):
    """Only Names directly in `foo` match; `in_foo_nested` is excluded."""
    name_in_foo = base_matchers.AllOf(
        ast_matchers.Name(),
        syntax_matchers.InNamedFunction(ast_matchers.FunctionDef(name='foo')),
    )
    source = textwrap.dedent("""\
        in_nothing
        def parent():
          in_parent
          def foo():
            in_foo
            def foo_nested():
              in_foo_nested
          def bar():
            in_bar
        """)
    found = self.get_all_match_strings(name_in_foo, source)
    self.assertEqual(found, ['in_foo'])
def _match(self, context, candidate): if not isinstance(candidate, ast.AST): return None # Walk the AST to collect the answer: values = [] for node in ast.walk(candidate): # Every node must either be a Constant/Num or an addition node. if isinstance(node, ast.Constant): values.append(node.value) elif isinstance(node, ast.Num): # older pythons values.append(node.n) elif isinstance(node, ast.BinOp) or isinstance(node, ast.Add): # Binary operator nodes are allowed, but only if they have an Add() op. pass else: return None # not a +, not a constant # For more complex tasks, or for tasks which integrate into how Refex # builds results and bindings, it can be helpful to defer work into a # submatcher, such as by running BinOp(op=Add()).match(context, candidate) # Having walked the AST, we have determined that the whole tree is addition # of constants, and have collected all of those constants in a list. if len(values) <= 1: # Don't bother emitting a replacement for e.g. 7 with itself. return None result = str(sum(values)) # Finally, we want to return the answer to Refex: # 1) bind the result to a variable # 2) return the tree itself as the matched value # We can do this by deferring to a matcher that does the right thing. # StringMatch() will produce a string literal match, and AllOf will retarget # the returned binding to the AST node which was passed in. submatcher = base_matchers.AllOf( base_matchers.Bind("sum", base_matchers.StringMatch(result))) return submatcher.match(context, candidate)
syntax_matchers.ExprPattern('None'))))) _NONE_RETURNS_FIXERS = [ fixer.SimplePythonFixer( message= 'If a function ever returns a value, all the code paths should have a return statement with a return value.', url= 'https://refex.readthedocs.io/en/latest/guide/fixers/return_none.html', significant=False, category=_NONE_RETURNS_CATEGORY, matcher=base_matchers.AllOf( syntax_matchers.StmtPattern('return'), syntax_matchers.InNamedFunction( _function_containing(_NON_NONE_RETURN)), # Nested functions are too weird to consider right now. # TODO: Add matchers to match only the first ancestor # function and a way to use IsOrHasDescendant that doesn't recurse # into nested functions. base_matchers.Unless( syntax_matchers.InNamedFunction( _function_containing( syntax_matchers.NamedFunctionDefinition())))), replacement=syntactic_template.PythonStmtTemplate('return None'), example_fragment=textwrap.dedent(""" def f(x): if x: return return -1 """), example_replacement=textwrap.dedent(""" def f(x): if x:
def test_nonchild_descendant_haschild(self):
    """HasChild ignores a Num that is a deeper descendant, not a direct child."""
    call_with_num_child = base_matchers.AllOf(
        ast_matchers.Call(),
        syntax_matchers.HasChild(ast_matchers.Num()),
    )
    found = self.get_all_match_strings(call_with_num_child, 'foo(x + 1)')
    self.assertEqual(found, [])
def test_nonparent_ancestor(self, matcher_type):
    """The Num inside `x + 1` matches a Call that is an ancestor, not a parent."""
    num_under_call = base_matchers.AllOf(
        ast_matchers.Num(),
        matcher_type(ast_matchers.Call()),
    )
    found = self.get_all_match_strings(num_under_call, 'foo(x + 1)')
    self.assertEqual(found, ['1'])
def test_nonparent_ancestor_hasparent(self):
    """HasParent ignores a Call that is a grandparent rather than the parent."""
    num_with_call_parent = base_matchers.AllOf(
        ast_matchers.Num(),
        syntax_matchers.HasParent(ast_matchers.Call()),
    )
    found = self.get_all_match_strings(num_with_call_parent, 'foo(x + 1)')
    self.assertEqual(found, [])
def test_multi_bind_fail(self):
    """AllOf returns None when its Bind submatchers wrap `_NOTHING`."""
    both_fail = base_matchers.AllOf(
        base_matchers.Bind('foo', _NOTHING),
        base_matchers.Bind('bar', _NOTHING),
    )
    self.assertIsNone(both_fail.match(_FAKE_CONTEXT, 1))
def test_empty(self):
    """An AllOf with no submatchers matches the candidate with no bindings."""
    no_submatchers = base_matchers.AllOf()
    actual = no_submatchers.match(_FAKE_CONTEXT, 1)
    self.assertEqual(actual, matcher.MatchInfo(match.ObjectMatch(1)))
def test_type_filter_empty(self):
    """An AllOf with no submatchers has no type filter (None)."""
    no_submatchers = base_matchers.AllOf()
    self.assertIsNone(no_submatchers.type_filter)
def test_nonchild_descendant(self, matcher_type):
    """The Call matches via a Num that is a deeper descendant, not a child."""
    call_over_num = base_matchers.AllOf(
        ast_matchers.Call(),
        matcher_type(ast_matchers.Num()),
    )
    found = self.get_all_match_strings(call_over_num, 'foo(x + 1)')
    self.assertEqual(found, ['foo(x + 1)'])