def test_extract_precedence_sequence(self) -> None:
    """When one capture name is saved twice in a sequence, the later save wins."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    call_matcher = m.Call(
        args=[
            m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
            m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
        ]
    )
    pattern = m.Tuple(elements=[m.DoNotCare(), m.Element(call_matcher)])
    nodes = m.extract(expression, pattern)

    # The value expected under "arg" is the *second* argument of ``d(...)``.
    tuple_node = cst.ensure_type(expression, cst.Tuple)
    call = cst.ensure_type(tuple_node.elements[1].value, cst.Call)
    self.assertEqual(nodes, {"arg": call.args[1].value})
def visit_With(self, node: cst.With):
    """
    Collects the names referenced by the items of a ``with`` statement and
    records them as possible uses for the enclosing scopes.

    Every item is inspected, so ``with a as x, b as y:`` contributes the
    names found in both ``a as x`` and ``b as y``. Previously only the first
    item was examined and names in the remaining items were dropped.
    The body of the ``with`` statement is not searched here.
    """
    # Names are gathered item by item, in source order, so the names from
    # the first item still come first in the list.
    with_names = [
        name_node.value
        for item in node.items
        for name_node in match.findall(item, match.Name())
    ]
    # NOTE(review): presumably ``self.stack`` is non-empty while visiting a
    # function body — confirm against the rest of the visitor.
    if len(self.stack) > 0:
        self.fn_may_args_var_use.append(with_names)
    if len(self.cls_stack) > 0:
        if self.cls_stack[0].name in with_names:
            self.cls_may_vars_use.append(with_names)
    self.__find_module_vars_use(with_names)
def test_extract_optional_wildcard_present(self) -> None:
    """A ZeroOrOne wildcard that does match still records its capture."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    optional_attribute_arg = m.ZeroOrOne(
        m.Arg(m.SaveMatchedNode(m.Attribute(), "arg"))
    )
    call_matcher = m.Call(
        args=[m.DoNotCare(), m.DoNotCare(), optional_attribute_arg]
    )
    nodes = m.extract(
        expression,
        m.Tuple(elements=[m.DoNotCare(), m.Element(call_matcher)]),
    )

    # ``h.i.j`` is the third argument of ``d(...)``.
    tuple_node = cst.ensure_type(expression, cst.Tuple)
    call = cst.ensure_type(tuple_node.elements[1].value, cst.Call)
    self.assertEqual(nodes, {"arg": call.args[2].value})
def test_extract_sequence_multiple_wildcards(self) -> None:
    """Captures on ZeroOrMore wildcards hold tuples of the elements consumed."""
    expression = cst.parse_expression("1, 2, 3, 4")
    pattern = m.Tuple(
        elements=(
            m.SaveMatchedNode(m.ZeroOrMore(), "head"),
            m.SaveMatchedNode(
                m.Element(value=m.Integer(value="3")), "element"
            ),
            m.SaveMatchedNode(m.ZeroOrMore(), "tail"),
        )
    )
    nodes = m.extract(expression, pattern)

    elements = cst.ensure_type(expression, cst.Tuple).elements
    expected = {
        "head": tuple(elements[:2]),
        "element": elements[2],
        "tail": tuple(elements[3:]),
    }
    self.assertEqual(nodes, expected)
def visit_ImportAlias(self, node: cst.ImportAlias):
    """
    Extracts imports.

    Even if an import has no alias, the Import node still carries an
    ImportAlias holding its real name. The import name is always recorded
    in ``self.imports``; the alias, when present, is recorded as well.
    """
    # Optional alias: either an AsName whose name is a non-empty Name (its
    # string value is saved as "alias"), or no AsName at all. OR-ing the
    # AsName matcher with its negation makes this part always succeed.
    alias_matcher = match.AsName(
        name=match.Name(
            value=match.SaveMatchedNode(match.MatchRegex(r'(.)+'), "alias")
        )
    ) | ~match.AsName()
    import_matcher = match.ImportAlias(
        asname=alias_matcher,
        # Save the whole name node; dotted names (a.b.c) are Attribute nodes.
        name=match.SaveMatchedNode(match.DoNotCare(), "name"),
    )
    captured = match.extract(node, import_matcher)

    if "name" in captured:
        # Converting the node back to source text flattens multi-level
        # imports (e.g. ``import x.y.z``) into a single string.
        # TODO: This might be un-needed after implementing import type extraction change.
        # TODO: So far, no differentiation between import and from imports.
        code = self.__convert_node_to_code(captured["name"])
        self.imports.append(self.__clean_string_whitespace(code))
    if "alias" in captured:
        self.imports.append(self.__clean_string_whitespace(captured["alias"]))

    # Children need no traversal: the matcher above already inspected them.
    return False
def test_extract_multiple(self) -> None:
    """Captures from different tuple elements are merged into one result dict."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    left_element = m.Element(
        m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))
    )
    call_element = m.Element(m.Call(func=m.SaveMatchedNode(m.Name(), "func")))
    nodes = m.extract(expression, m.Tuple(elements=[left_element, call_element]))

    first, second = cst.ensure_type(expression, cst.Tuple).elements
    binary_op = cst.ensure_type(first.value, cst.BinaryOperation)
    call = cst.ensure_type(second.value, cst.Call)
    self.assertEqual(nodes, {"left": binary_op.left, "func": call.func})
def __extract_variable_name(self, node: cst.AssignTarget):
    """
    Extracts the variable name(s) bound by an assignment target.

    Recognised target shapes:
      * plain name (``x = ...``): returns the captures with ``"name"`` plus
        a ``"type"`` entry;
      * tuple (``a, b = ...``): returns ``{"names": ...}`` built by
        ``__extract_names_multi_assign`` from the tuple's elements;
      * attribute on a name (``self.x = ...``): returns ``"obj_name"`` and
        ``"name"`` plus a ``"type"`` entry.
    Any other target yields ``None``.

    ``"type"`` is ``(t, INF_TYPE_ANNOT)`` when ``__get_type_from_metadata``
    returns a truthy ``t``, otherwise ``(t, UNK_TYPE_ANNOT)``.
    """
    non_empty = r'(.)+'
    # ``x = ...``
    plain_name = match.Name(
        value=match.SaveMatchedNode(match.MatchRegex(non_empty), "name")
    )
    # ``a, b = ...`` -- keep the raw element sequence for later unpacking.
    tuple_target = match.Tuple(
        elements=match.SaveMatchedNode(match.DoNotCare(), "names")
    )
    # ``self.x = ...`` (e.g. inside ``__init__`` without a type annotation).
    attribute_target = match.Attribute(
        value=match.Name(
            value=match.SaveMatchedNode(match.MatchRegex(non_empty), "obj_name")
        ),
        attr=match.Name(
            match.SaveMatchedNode(match.MatchRegex(non_empty), "name")
        ),
    )
    captured = match.extract(
        node,
        match.AssignTarget(
            target=match.OneOf(plain_name, tuple_target, attribute_target)
        ),
    )

    if captured is None:
        return None
    if "name" in captured:
        inferred = self.__get_type_from_metadata(node.target)
        captured['type'] = (inferred, INF_TYPE_ANNOT if inferred else UNK_TYPE_ANNOT)
        return captured
    if "names" in captured:
        return {
            'names': self.__extract_names_multi_assign(list(captured['names']))
        }
    return captured
def _is_awaitable_callable(annotation: str) -> bool: if not (annotation.startswith("typing.Callable") or annotation.startswith("typing.ClassMethod") or annotation.startswith("StaticMethod")): # Exit early if this is not even a `typing.Callable` annotation. return False try: # Wrap this in a try-except since the type annotation may not be parse-able as a module. # If it is not parse-able, we know it's not what we are looking for anyway, so return `False`. parsed_ann = cst.parse_module(annotation) except Exception: return False # If passed annotation does not match the expected annotation structure for a `typing.Callable` with # typing.Coroutine as the return type, matched_callable_ann will simply be `None`. # The expected structure of an awaitable callable annotation from Pyre is: typing.Callable()[[...], typing.Coroutine[...]] matched_callable_ann: Optional[Dict[str, Union[ Sequence[cst.CSTNode], cst.CSTNode]]] = m.extract( parsed_ann, m.Module(body=[ m.SimpleStatementLine(body=[ m.Expr(value=m.Subscript(slice=[ m.SubscriptElement(), m.SubscriptElement(slice=m.Index(value=m.Subscript( value=m.SaveMatchedNode( m.Attribute(), "base_return_type", )))), ], )) ]), ]), ) if (matched_callable_ann is not None and "base_return_type" in matched_callable_ann): base_return_type = get_full_name_for_node( cst.ensure_type(matched_callable_ann["base_return_type"], cst.CSTNode)) return (base_return_type is not None and base_return_type == "typing.Coroutine") return False
def __extract_assign_newtype(self, node: cst.Assign):
    """
    Attempts extracting a NewType declaration from the provided Assign node.

    Matches ``Name = NewType("...", ...)``: a single plain-name target
    assigned from a call to the bare name ``NewType`` whose first argument
    is a string literal. On a match, the target name is appended to
    ``self.class_defs``; otherwise nothing happens.
    """
    # Exactly one assignment target, which must be a non-empty plain name.
    single_name_target = [
        match.AssignTarget(
            target=match.Name(
                value=match.SaveMatchedNode(match.MatchRegex(r'(.)+'), "type")
            )
        )
    ]
    # ``NewType(<string literal>, ...)`` -- any further arguments accepted.
    newtype_call = match.Call(
        func=match.Name(value="NewType"),
        args=[match.Arg(value=match.SimpleString()), match.ZeroOrMore()],
    )
    captured = match.extract(
        node, match.Assign(targets=single_name_target, value=newtype_call)
    )
    if captured is None:
        return

    # TODO: Either rename class defs, or create new list for additional types
    # NOTE(review): "type" holds an identifier (a Name's value), so stripping
    # single quotes appears to be a no-op — confirm whether the string
    # argument was meant to be captured instead.
    self.class_defs.append(captured["type"].strip("\'"))
def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
    """
    Report ``"..." % ...`` string formatting.

    When the left operand passes ``_match_simple_string`` and the right
    operand passes the simple-expression check, an f-string replacement is
    attached. This happens only if each ``%s`` placeholder maps one-to-one
    onto an operand. Otherwise, and for any other %-formatting of a ``str``
    literal, the node is reported without a replacement.
    """
    expr_key = "expr"
    extracts = m.extract(
        node,
        m.BinaryOperation(
            left=m.MatchIfTrue(_match_simple_string),
            operator=m.Modulo(),
            right=m.SaveMatchedNode(
                m.MatchIfTrue(
                    _gen_match_simple_expression(self.context.wrapper.module)
                ),
                expr_key,
            ),
        ),
    )
    if extracts:
        expr = extracts[expr_key]
        simple_string = cst.ensure_type(node.left, cst.SimpleString)
        # Literal braces are significant inside an f-string, so double them.
        innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
        tokens = innards.split("%s")
        # A tuple operand supplies one value per placeholder.
        expressions = (
            [elm.value for elm in expr.elements]
            if isinstance(expr, cst.Tuple)
            else [expr]
        )
        if len(expressions) != len(tokens) - 1:
            # Placeholders and operands do not pair up one-to-one (too few
            # *or* too many, e.g. ``"%s" % (a, b)``). An f-string would change
            # behaviour, so only warn.
            self.report(node)
            return

        escape_transformer = EscapeStringQuote(simple_string.quote)
        parts = []
        if tokens[0]:
            parts.append(cst.FormattedStringText(value=tokens[0]))
        for expression, token in zip(expressions, tokens[1:]):
            try:
                parts.append(
                    cst.FormattedStringExpression(
                        expression=cast(
                            cst.BaseExpression,
                            expression.visit(escape_transformer),
                        )
                    )
                )
            except Exception:
                # The operand could not be embedded; warn without a fix.
                self.report(node)
                return
            if token:
                parts.append(cst.FormattedStringText(value=token))

        start = f"f{simple_string.prefix}{simple_string.quote}"
        replacement = cst.FormattedString(
            parts=parts, start=start, end=simple_string.quote
        )
        self.report(node, replacement=replacement)
    elif m.matches(
        node, m.BinaryOperation(left=m.SimpleString(), operator=m.Modulo())
    ) and isinstance(
        cst.ensure_type(node.left, cst.SimpleString).evaluated_value, str
    ):
        # Other %-formatting of a (non-bytes) string literal: warn only.
        self.report(node)
def visit_Call(self, node: cst.Call) -> None:
    """
    Flag ``self.assertTrue(x, <literal>)`` calls and suggest a replacement.

    The literal positional second argument (number, tuple, list, set, dict,
    ``None``, ``True`` or ``False``) is treated as an intended expected value:
      * ``True``  -> ``self.assertTrue(x)``
      * ``None``  -> ``self.assertIsNone(x)``
      * ``False`` -> ``self.assertFalse(x)``
      * otherwise -> ``self.assertEqual(x, <literal>)``

    An explicit keyword second argument (``msg=...``) is not matched: it is
    a genuine failure message, and rewriting e.g. ``assertTrue(x, msg=None)``
    to ``assertIsNone(x)`` would change the test's meaning.
    """
    result = m.extract(
        node,
        m.Call(
            func=m.Attribute(value=m.Name("self"), attr=m.Name("assertTrue")),
            args=[
                m.DoNotCare(),
                m.Arg(
                    # Positional only; see the docstring.
                    keyword=None,
                    value=m.SaveMatchedNode(
                        m.OneOf(
                            m.Integer(),
                            m.Float(),
                            m.Imaginary(),
                            m.Tuple(),
                            m.List(),
                            m.Set(),
                            m.Dict(),
                            m.Name("None"),
                            m.Name("True"),
                            m.Name("False"),
                        ),
                        "second",
                    ),
                ),
            ],
        ),
    )
    if not result:
        return

    second_arg = result["second"]
    if isinstance(second_arg, Sequence):
        second_arg = second_arg[0]

    method_name = cst.ensure_type(node.func, cst.Attribute).attr
    # The call rebuilt with only its first argument, trailing comma dropped.
    first_arg_only = [
        node.args[0].with_changes(comma=cst.MaybeSentinel.DEFAULT)
    ]

    def renamed_func(new_name: str) -> cst.BaseExpression:
        # ``self.<new_name>`` keeping the original node's formatting.
        return node.func.with_deep_changes(old_node=method_name, value=new_name)

    if m.matches(second_arg, m.Name("True")):
        new_call = node.with_changes(args=first_arg_only)
    elif m.matches(second_arg, m.Name("None")):
        new_call = node.with_changes(
            func=renamed_func("assertIsNone"), args=first_arg_only
        )
    elif m.matches(second_arg, m.Name("False")):
        new_call = node.with_changes(
            func=renamed_func("assertFalse"), args=first_arg_only
        )
    else:
        new_call = node.with_deep_changes(
            old_node=method_name, value="assertEqual"
        )
    self.report(node, replacement=new_call)
def leave_BinaryOperation(
    self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
) -> cst.BaseExpression:
    """
    Rewrite ``"...%s..." % operand`` into an equivalent f-string.

    When the node is not a convertible %-format, ``updated_node`` is returned
    unchanged. This covers placeholder/operand count mismatches and operands
    that cannot be embedded. Returning ``updated_node`` rather than
    ``original_node`` keeps rewrites already applied to descendants, e.g.
    both halves of ``"%s" % a + "%s" % b``.
    """
    expr_key = "expr"
    extracts = m.extract(
        original_node,
        m.BinaryOperation(
            # pyre-fixme[6]: Expected `Union[m._matcher_base.AllOf[typing.Union[m...
            left=m.MatchIfTrue(_match_simple_string),
            operator=m.Modulo(),
            # pyre-fixme[6]: Expected `Union[m._matcher_base.AllOf[typing.Union[m...
            right=m.SaveMatchedNode(
                m.MatchIfTrue(_gen_match_simple_expression(self.module)),
                expr_key,
            ),
        ),
    )
    if not extracts:
        return updated_node

    exprs = extracts[expr_key]
    exprs = (exprs,) if not isinstance(exprs, Sequence) else exprs
    simple_string = cst.ensure_type(original_node.left, cst.SimpleString)
    # Literal braces are significant inside an f-string, so double them.
    innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
    tokens = innards.split("%s")
    # Flatten tuple operands (``% (a, b)``) into their elements. The previous
    # ``list(*itertools.chain(...))`` passed each sub-list as a separate
    # argument to ``list()`` and would raise for more than one saved operand.
    expressions: List[cst.CSTNode] = list(
        itertools.chain.from_iterable(
            [elm.value for elm in expr.elements]
            if isinstance(expr, cst.Tuple)
            else [expr]
            for expr in exprs
        )
    )
    if len(expressions) != len(tokens) - 1:
        # Placeholders and operands do not pair up one-to-one (too few *or*
        # too many, e.g. ``"%s" % (a, b)``); an f-string would change behaviour.
        return updated_node

    escape_transformer = EscapeStringQuote(simple_string.quote)
    parts = []
    if tokens[0]:
        parts.append(cst.FormattedStringText(value=tokens[0]))
    for expression, token in zip(expressions, tokens[1:]):
        try:
            parts.append(
                cst.FormattedStringExpression(
                    expression=cast(
                        cst.BaseExpression,
                        expression.visit(escape_transformer),
                    )
                )
            )
        except Exception:
            # The operand could not be embedded; leave the expression alone.
            return updated_node
        if token:
            parts.append(cst.FormattedStringText(value=token))

    start = f"f{simple_string.prefix}{simple_string.quote}"
    return cst.FormattedString(parts=parts, start=start, end=simple_string.quote)
def visit_Call(self, node: cst.Call) -> None:
    """
    Suggest ``assertIsNone`` / ``assertIsNotNone`` for ``assertTrue`` /
    ``assertFalse`` calls whose only argument is a single ``is None`` or
    ``is not None`` comparison, optionally wrapped in ``not``.

    ``assertFalse``, a leading ``not`` and ``is not`` each invert the check
    once. An even number of inversions yields ``assertIsNone``; an odd
    number yields ``assertIsNotNone``.
    """
    is_none_target = m.ComparisonTarget(
        m.SaveMatchedNode(m.OneOf(m.Is(), m.IsNot()), "comparison_type"),
        comparator=m.Name("None"),
    )
    is_none_comparison = m.Comparison(comparisons=[is_none_target])
    call_matcher = m.Call(
        func=m.Attribute(
            value=m.Name("self"),
            attr=m.SaveMatchedNode(
                m.OneOf(m.Name("assertTrue"), m.Name("assertFalse")),
                "assertion_name",
            ),
        ),
        args=[
            m.Arg(
                m.SaveMatchedNode(
                    m.OneOf(
                        is_none_comparison,
                        m.UnaryOperation(
                            operator=m.Not(), expression=is_none_comparison
                        ),
                    ),
                    "argument",
                )
            )
        ],
    )
    result = m.extract(node, call_matcher)
    if not result:
        return

    def first(value):
        # Captures may come back as sequences; only the first node matters.
        return value[0] if isinstance(value, Sequence) else value

    assertion_name = first(result["assertion_name"])
    argument = first(result["argument"])
    comparison_type = first(result["comparison_type"])

    if m.matches(argument, m.Comparison()):
        subject = ensure_type(argument, cst.Comparison).left
    else:
        negated = ensure_type(argument, cst.UnaryOperation).expression
        subject = ensure_type(negated, cst.Comparison).left

    inversions = sum(
        (
            m.matches(assertion_name, m.Name("assertFalse")),
            m.matches(argument, m.UnaryOperation()),
            m.matches(comparison_type, m.IsNot()),
        )
    )
    new_attr = "assertIsNotNone" if inversions % 2 else "assertIsNone"
    replacement = node.with_changes(
        func=cst.Attribute(value=cst.Name("self"), attr=cst.Name(new_attr)),
        args=[cst.Arg(subject)],
    )
    self.report(node, replacement=replacement)