def test_extract_simple(self) -> None:
    # extract() returns {name: node} for a SaveMatchedNode inside a pattern
    # that matches, and None when the overall pattern does not match.
    expression = cst.parse_expression("a + b[c], d(e, f * g)")

    def pattern_saving_left(left_matcher):
        # Two-element tuple: a binary operation whose left operand is
        # captured under "left", followed by any call.
        return m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(
                        left=m.SaveMatchedNode(left_matcher, "left")
                    )
                ),
                m.Element(m.Call()),
            ]
        )

    first_value = cst.ensure_type(expression, cst.Tuple).elements[0].value
    expected_left = cst.ensure_type(first_value, cst.BinaryOperation).left

    # Verify true behavior
    captured = m.extract(expression, pattern_saving_left(m.Name()))
    self.assertEqual(captured, {"left": expected_left})

    # Verify false behavior
    captured = m.extract(expression, pattern_saving_left(m.Subscript()))
    self.assertIsNone(captured)
def test_extract_sequence_element(self) -> None:
    # A SaveMatchedNode wrapping a ZeroOrMore wildcard inside a sequence
    # pattern captures the run of elements that the wildcard consumed.
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")

    def pattern_saving_args(wildcard):
        # Any first element, then a call whose whole argument list is
        # covered by the given wildcard and stored under "args".
        return m.Tuple(
            elements=[
                m.DoNotCare(),
                m.Element(m.Call(args=[m.SaveMatchedNode(wildcard, "args")])),
            ]
        )

    call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )

    # Verify true behavior
    captured = m.extract(expression, pattern_saving_args(m.ZeroOrMore()))
    self.assertEqual(captured, {"args": call.args})

    # Verify false behavior
    captured = m.extract(
        expression,
        pattern_saving_args(m.ZeroOrMore(m.Arg(m.Subscript()))),
    )
    self.assertIsNone(captured)
def __name2annotation(self, type_name: str):
    """
    Converts Name nodes to valid annotation nodes

    The type name is spliced into the statement ``x: <type>=None``, the
    statement is parsed, and the node held by the resulting AnnAssign's
    ``annotation`` attribute is returned.

    If ``type_name`` does not parse, the last visited name is used instead:
    its cached qualified name when ``self.q_names_cache`` has an IMPORT entry
    for it, otherwise the bare name. A syntax error raised by that fallback
    parse propagates to the caller.
    """

    def parse_annotation(annotation_src):
        # Parse a throwaway annotated assignment and pull out its annotation.
        statement = cst.parse_module(
            "x: %s=None" % annotation_src).body[0].body[0]
        return match.extract(
            statement,
            match.AnnAssign(
                target=match.Name(value=match.DoNotCare()),
                annotation=match.SaveMatchedNode(match.DoNotCare(), "type"),
            ),
        )['type']

    try:
        return parse_annotation(type_name)
    except cst.ParserSyntaxError:
        # To handle a bug in LibCST's scope provider where a local name
        # shadows a type annotation with the same name
        last_name = self.last_visited_name.value
        cache_key = (last_name, cst.metadata.QualifiedNameSource.IMPORT)
        if cache_key in self.q_names_cache:
            return parse_annotation(self.q_names_cache[cache_key])
        return parse_annotation(last_name)
def handle_any_string(
    self, node: Union[cst.SimpleString, cst.ConcatenatedString]
) -> None:
    """
    Parse a string literal's value as Python source and record the names it
    mentions in ``self.names``.

    Does nothing when the literal has no evaluated value. Bare names are
    recorded directly; for attribute chains, every dotted name generated by
    the scope provider's ``_gen_dotted_names`` helper is recorded.
    """
    source = node.evaluated_value
    if source is None:
        return

    parsed = cst.parse_module(source)

    # A Name counts as "bare" only if its parent is not an Attribute;
    # names inside attribute chains are handled via the Attribute capture.
    bare_name = m.Name(
        value=m.SaveMatchedNode(m.DoNotCare(), "name"),
        metadata=m.MatchMetadataIfTrue(
            cst.metadata.ParentNodeProvider,
            lambda parent: not isinstance(parent, cst.Attribute),
        ),
    )
    dotted_access = m.SaveMatchedNode(m.Attribute(), "attribute")
    captures = m.extractall(
        parsed,
        bare_name | dotted_access,
        metadata_resolver=MetadataWrapper(parsed, unsafe_skip_copy=True),
    )

    found = set()
    for capture in captures:
        if "name" in capture:
            found.add(cast(str, capture["name"]))
        if "attribute" in capture:
            attribute = cast(cst.Attribute, capture["attribute"])
            dotted = cst.metadata.scope_provider._gen_dotted_names(attribute)
            found.update(dotted_name for dotted_name, _ in dotted)
    self.names.update(found)
def __extract_variable_name_type(self, node: cst.AnnAssign):
    """
    Extracts a variable's identifier name and its type annotation

    Returns the dictionary produced by ``match.extract``: ``"name"`` holds
    the identifier, ``"type"`` holds whatever the ``annotation`` attribute
    contains, and ``"obj_name"`` is present when the target is an attribute
    access on a plain name. The result of ``match.extract`` is returned
    unchanged, including when the node does not match.
    """

    def saved_identifier(key):
        # Capture a non-empty identifier string under `key`.
        return match.SaveMatchedNode(match.MatchRegex(r'(.)+'), key)

    # Plain variable target, e.g. `x: int = 2`.
    plain_target = match.Name(value=saved_identifier("name"))
    # Attribute target on a plain object name, e.g. `self.x: int = 2`
    # as typically found inside __init__.
    attribute_target = match.Attribute(
        value=match.Name(value=saved_identifier("obj_name")),
        attr=match.Name(saved_identifier("name")),
    )
    pattern = match.AnnAssign(
        target=match.OneOf(plain_target, attribute_target),
        # Any annotation is accepted and saved as-is.
        annotation=match.SaveMatchedNode(match.DoNotCare(), "type"),
    )
    return match.extract(node, pattern)
def test_extract_optional_wildcard_tail(self) -> None:
    # Trailing ZeroOrMore wildcards that consume nothing are still reported,
    # each as an empty tuple under its own name.
    expression = cst.parse_expression("[3]")
    trailing_wildcards = [
        m.SaveMatchedNode(m.ZeroOrMore(), key) for key in ("tail1", "tail2")
    ]
    pattern = m.List(
        elements=[
            m.Element(value=m.Integer(value="3")),
            *trailing_wildcards,
        ]
    )
    captured = m.extract(expression, pattern)
    self.assertEqual(captured, {"tail1": (), "tail2": ()})
def test_extract_sentinel(self) -> None:
    # Verify behavior when provided a sentinel: extracting from a sentinel
    # value instead of a real node yields None.
    pattern = m.Call(func=m.SaveMatchedNode(m.Name(), name="func"))
    for sentinel in (cst.RemovalSentinel.REMOVE, cst.MaybeSentinel.DEFAULT):
        nothing = m.extract(sentinel, pattern)
        self.assertIsNone(nothing)
def __get_var_names_counter(self, node, scope):
    """
    Count the names assigned to under ``node``, restricted to one scope type.

    Targets of both plain and annotated assignments are collected. A name is
    counted only when its ScopeProvider metadata is an instance of ``scope``.
    Returns a ``Counter`` keyed by the name's string value.
    """

    def saved_target():
        # Capture an assignment target that is a plain Name.
        return match.SaveMatchedNode(
            match.Name(value=match.DoNotCare()), "name")

    captures = match.extractall(
        node,
        match.OneOf(
            match.AssignTarget(target=saved_target()),
            match.AnnAssign(target=saved_target()),
        ),
    )

    counts = Counter()
    for capture in captures:
        name_node = capture['name']
        name_scope = self.get_metadata(cst.metadata.ScopeProvider, name_node)
        if isinstance(name_scope, scope):
            counts[name_node.value] += 1
    return counts
def test_extract_metadata(self) -> None:
    # A SaveMatchedNode placed on a `metadata=` matcher saves the node whose
    # metadata matched; a position that does not match makes extract()
    # return None.
    module = cst.parse_module("a + b[c], d(e, f * g)")
    wrapper = cst.MetadataWrapper(module)
    statement = cst.ensure_type(wrapper.module.body[0], cst.SimpleStatementLine)
    expression = cst.ensure_type(statement.body[0], cst.Expr).value

    def extract_left_spanning(start, end):
        # Require the left operand to be a Name occupying exactly the code
        # range start..end, and save it under "left".
        position = m.MatchMetadata(
            meta.PositionProvider,
            self._make_coderange(start, end),
        )
        pattern = m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(
                        left=m.Name(
                            metadata=m.SaveMatchedNode(position, "left")
                        )
                    )
                ),
                m.Element(m.Call()),
            ]
        )
        return m.extract(expression, pattern, metadata_resolver=wrapper)

    first_value = cst.ensure_type(expression, cst.Tuple).elements[0].value
    expected_left = cst.ensure_type(first_value, cst.BinaryOperation).left

    # Verify true behavior
    self.assertEqual(
        extract_left_spanning((1, 0), (1, 1)), {"left": expected_left}
    )

    # Verify false behavior
    self.assertIsNone(extract_left_spanning((1, 0), (1, 2)))
def test_replace_add_one_to_foo_args(self) -> None:
    # replace() hands the callback both the matched node and the extraction
    # dictionary; here the callback bumps the captured integer argument.
    def _add_one_to_arg(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        # This can be either a node or a sequence, pyre doesn't know.
        old_arg = cst.ensure_type(extraction["arg"], cst.CSTNode)
        # Grab the arg and add one to its value.
        bumped = int(cst.ensure_type(extraction["arg"], cst.Integer).value) + 1
        return node.deep_replace(old_arg, cst.Integer(str(bumped)))

    # Verify way more complex transform behavior.
    original = cst.parse_module(
        "foo: int = 37\n"
        "def bar(baz: int) -> int:\n"
        " return baz\n"
        "\n"
        "biz: int = bar(41)\n"
    )
    call_to_bar = m.Call(
        func=m.Name("bar"),
        args=[m.Arg(m.SaveMatchedNode(m.Integer(), "arg"))],
    )
    rewritten = m.replace(original, call_to_bar, _add_one_to_arg)
    replaced = cst.ensure_type(rewritten, cst.Module).code
    self.assertEqual(
        replaced,
        "foo: int = 37\n"
        "def bar(baz: int) -> int:\n"
        " return baz\n"
        "\n"
        "biz: int = bar(42)\n",
    )
def test_replace_sequence_extract(self) -> None:
    # Sequence captures are also passed to the replace() callback; here the
    # captured parameter list is reversed.
    def _reverse_params(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        function = cst.ensure_type(node, cst.FunctionDef)
        # pyre-ignore We know "params" is a Sequence[Parameters] but asserting that
        # to pyre is difficult.
        flipped = list(reversed(extraction["params"]))
        return function.with_changes(params=cst.Parameters(params=flipped))

    # Verify that we can still extract sequences with replace.
    original = cst.parse_module(
        "def bar(baz: int, foo: int, ) -> int:\n return baz + foo\n")
    any_function = m.FunctionDef(
        params=m.Parameters(
            params=m.SaveMatchedNode([m.ZeroOrMore(m.Param())], "params")
        )
    )
    rewritten = m.replace(original, any_function, _reverse_params)
    replaced = cst.ensure_type(rewritten, cst.Module).code
    self.assertEqual(
        replaced,
        "def bar(foo: int, baz: int, ) -> int:\n return baz + foo\n")
def test_replace_updated_node_changes(self) -> None:
    # Each call that has a call as its sole argument is rewritten so that
    # its argument becomes `<inner function name>_immediate`. The expected
    # output below pins the result for a four-deep nested call.
    def _replace_nested(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        outer_call = cst.ensure_type(node, cst.Call)
        inner_call = cst.ensure_type(extraction["inner"], cst.Call)
        inner_name = cst.ensure_type(inner_call.func, cst.Name).value
        new_arg = cst.Arg(cst.Name(value=inner_name + "_immediate"))
        return outer_call.with_changes(args=[new_arg])

    preamble = (
        "def foo(val: int) -> int:\n"
        " return val\n"
        "bar = foo\n"
        "baz = foo\n"
        "biz = foo\n"
    )
    original = cst.parse_module(preamble + "foo(bar(baz(biz(5))))\n")
    call_wrapping_call = m.Call(
        args=[m.Arg(m.SaveMatchedNode(m.Call(), "inner"))]
    )
    rewritten = m.replace(original, call_wrapping_call, _replace_nested)
    replaced = cst.ensure_type(rewritten, cst.Module).code
    self.assertEqual(replaced, preamble + "foo(bar_immediate)\n")
def test_extract_precedence_sequence(self) -> None:
    # Both Arg matchers save under the same key "arg"; the assertion pins
    # that the value kept is the one from the second argument.
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    two_saved_args = [
        m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
        m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
    ]
    pattern = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(m.Call(args=two_saved_args)),
        ]
    )
    captured = m.extract(expression, pattern)

    call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )
    second_arg_value = call.args[1].value
    self.assertEqual(captured, {"arg": second_arg_value})
def visit_With(self, node: cst.With):
    """
    Record the names referenced by a ``with`` statement's first item.

    The names are appended to the function-level usage list when a function
    is on ``self.stack``, appended to the class-level usage list when the
    outermost class's name is among them, and always forwarded to the
    module-level usage check.
    """
    captured_items = match.extract(
        node,
        match.With(
            items=match.SaveMatchedNode(match.DoNotCare(), 'with_items')),
    )['with_items']
    # NOTE(review): only the first with-item is searched; names that appear
    # only in later items of `with a, b:` are not collected. Confirm this is
    # intended.
    first_item = captured_items[0]
    any_name = match.Name(
        value=match.SaveMatchedNode(match.DoNotCare(), 'name'))
    with_names = [found.value for found in match.findall(first_item, any_name)]

    if self.stack:
        self.fn_may_args_var_use.append(with_names)

    if self.cls_stack and self.cls_stack[0].name in with_names:
        self.cls_may_vars_use.append(with_names)

    self.__find_module_vars_use(with_names)
def leave_BinaryOperation(
    self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
) -> cst.BaseExpression:
    """
    Convert ``"...%s..." % expr`` into the equivalent f-string.

    The conversion applies when the left operand satisfies
    ``_match_simple_string``, the operator is ``%`` and the right operand
    satisfies the matcher built by ``_gen_match_simple_expression``. A tuple
    on the right supplies one expression per ``%s`` placeholder; any other
    expression supplies exactly one.

    Returns a ``cst.FormattedString`` on success. In every other case
    ``updated_node`` is returned, so rewrites already applied to nested
    children (for example an inner ``%`` operation inside a larger binary
    expression) are kept rather than thrown away.
    """
    expr_key = "expr"
    extracts = m.extract(
        original_node,
        m.BinaryOperation(
            left=m.MatchIfTrue(_match_simple_string),
            operator=m.Modulo(),
            right=m.SaveMatchedNode(
                m.MatchIfTrue(_gen_match_simple_expression(self.module)),
                expr_key,
            ),
        ),
    )
    if not extracts:
        return updated_node

    expr = extracts[expr_key]
    simple_string = cst.ensure_type(original_node.left, cst.SimpleString)
    # Literal braces must be doubled to survive inside an f-string.
    innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
    tokens = innards.split("%s")
    expressions = (
        [elm.value for elm in expr.elements]
        if isinstance(expr, cst.Tuple)
        else [expr]
    )
    # Splitting on "%s" yields one more text token than placeholders.
    if len(tokens) - 1 != len(expressions):
        # the %-string doesn't come with same number of elements in tuple;
        # converting would either fail or silently drop operands.
        return updated_node

    parts = []
    if tokens[0]:
        parts.append(cst.FormattedStringText(value=tokens[0]))

    escape_transformer = EscapeStringQuote(simple_string.quote)
    for expression, trailing_text in zip(expressions, tokens[1:]):
        try:
            parts.append(
                cst.FormattedStringExpression(
                    expression=cast(
                        cst.BaseExpression,
                        expression.visit(escape_transformer),
                    )
                )
            )
        except Exception:
            # Deliberate best effort: if the expression cannot be embedded,
            # leave the %-format expression alone.
            return updated_node
        if trailing_text:
            parts.append(cst.FormattedStringText(value=trailing_text))

    start = f"f{simple_string.prefix}{simple_string.quote}"
    return cst.FormattedString(
        parts=parts, start=start, end=simple_string.quote
    )
def test_extractall_simple(self) -> None:
    # extractall() returns one extraction dictionary per matching node; here
    # every Arg whose value is not a plain Name.
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    non_name_arg = m.Arg(m.SaveMatchedNode(~m.Name(), "expr"))
    matches = extractall(expression, non_name_arg)

    call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )
    # The first argument `e` is a Name, so only the second and third remain.
    expected = [{"expr": call.args[index].value} for index in (1, 2)]
    self.assertEqual(matches, expected)
def __name2annotation(self, type_name: str):
    """
    Convert a type name into an annotation node.

    The type name is spliced into the statement ``x: <type>=None``, the
    statement is parsed, and the node held by the resulting AnnAssign's
    ``annotation`` attribute is returned.

    Returns None when the spliced statement is not valid Python.
    """
    try:
        # Only parsing can raise the syntax error, so keep the try minimal.
        parsed = cst.parse_module("x: %s=None" % type_name)
    except cst.ParserSyntaxError:
        return None

    return match.extract(
        parsed.body[0].body[0],
        match.AnnAssign(
            target=match.Name(value=match.DoNotCare()),
            annotation=match.SaveMatchedNode(match.DoNotCare(), "type"),
        ),
    )['type']
def visit_Element(self, node: libcst.Element) -> bool:
    """
    Record an element's value as an explicitly exported object when the
    element is a simple string literal.

    Always returns False, so the element's children are not visited.
    """
    # See if this is a entry that is a string.
    string_element = m.Element(m.SaveMatchedNode(m.SimpleString(), "string"))
    captured = self.extract(node, string_element)
    if captured:
        literal = ensure_type(captured["string"], libcst.SimpleString)
        self.explicit_exported_objects.add(literal.evaluated_value)

    # Don't need to visit children
    return False
def test_extract_sequence_multiple_wildcards(self) -> None:
    # Wildcards on both sides of a fixed element: each wildcard capture is a
    # tuple of the elements it consumed, the fixed element is saved as a
    # single node.
    expression = cst.parse_expression("1, 2, 3, 4")
    pattern = m.Tuple(
        elements=(
            m.SaveMatchedNode(m.ZeroOrMore(), "head"),
            m.SaveMatchedNode(
                m.Element(value=m.Integer(value="3")), "element"
            ),
            m.SaveMatchedNode(m.ZeroOrMore(), "tail"),
        )
    )
    captured = m.extract(expression, pattern)

    elements = cst.ensure_type(expression, cst.Tuple).elements
    before, three, after = elements[:2], elements[2], elements[3:]
    expected = {
        "head": tuple(before),
        "element": three,
        "tail": tuple(after),
    }
    self.assertEqual(captured, expected)
def visit_ImportAlias(self, node: cst.ImportAlias):
    """
    Extracts imports.

    Even if an Import has no alias, the Import node will still have an
    ImportAlias with its real name. The import's name (rendered as code and
    whitespace-cleaned) is appended to ``self.imports``, followed by the
    alias when one is present.

    Always returns False: children need no traversal because the matching
    is done here.
    """
    # Extract information from the Import Alias node.
    # Both the (real) import name and the alias are extracted.
    # Result is a dictionary with a name:node KV pair,
    # and an alias:name KV pair if an alias is present.
    saved_alias = match.SaveMatchedNode(
        match.MatchRegex(r'(.)+'),  # Any non-empty alias string
        "alias")
    import_info = match.extract(
        node,
        match.ImportAlias(
            # Either an AsName whose alias is captured, or anything that is
            # not an AsName; the OR with the negation keeps the pattern
            # matching when there is no alias.
            asname=match.AsName(name=match.Name(value=saved_alias))
            | ~match.AsName(),
            # Match & save name of import
            name=match.SaveMatchedNode(match.DoNotCare(), "name")))

    # extract() yields None when the pattern does not match; membership
    # tests on None would raise TypeError, so bail out explicitly.
    if import_info is None:
        return False

    # Add import if a name could be extracted.
    if "name" in import_info:
        # Append import name to imports list.
        # We convert the ImportAlias node to the code representation for
        # simplified conversion of multi-level imports (e.g. import x.y.z)
        # TODO: This might be un-needed after implementing import type extraction change.
        # TODO: So far, no differentiation between import and from imports.
        import_name = self.__convert_node_to_code(import_info["name"])
        import_name = self.__clean_string_whitespace(import_name)
        self.imports.append(import_name)

    if "alias" in import_info:
        import_name = self.__clean_string_whitespace(import_info["alias"])
        self.imports.append(import_name)

    # No need to traverse children, as we already performed the matching.
    return False
def test_extract_tautology(self) -> None:
    # A SaveMatchedNode at the root of the pattern saves the node that was
    # passed to extract().
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    whole_tuple = m.Tuple(
        elements=[
            m.Element(m.BinaryOperation()),
            m.Element(m.Call()),
        ]
    )
    captured = m.extract(
        expression, m.SaveMatchedNode(whole_tuple, name="node")
    )
    self.assertEqual(captured, {"node": expression})
def visit_Assign(self, node: cst.Assign) -> None:
    """
    Collect string entries from an assignment of a list literal to the
    single target ``CARBON_EXTS``.

    The evaluated value of each simple-string element is appended to
    ``self.extension_names``; other assignments and non-string elements are
    ignored.
    """
    pattern = m.Assign(
        targets=(m.AssignTarget(target=m.Name("CARBON_EXTS")), ),
        value=m.SaveMatchedNode(m.List(), "list"),
    )
    captured = m.extract(node, pattern)
    if not captured:
        return

    list_node = captured["list"]
    assert isinstance(list_node, cst.List)
    string_values = (
        element.value
        for element in list_node.elements
        if isinstance(element.value, cst.SimpleString)
    )
    for literal in string_values:
        self.extension_names.append(literal.evaluated_value)
def __extract_variable_name(self, node: cst.AssignTarget):
    """
    Extract the variable name(s) bound by an assignment target.

    Three target shapes are recognised: a single name, a tuple of targets,
    and an attribute access on a plain name (e.g. ``self.x = 2``).

    Returns None when the target matches none of them. When a single
    ``"name"`` was captured, the extraction dictionary is returned with a
    ``'type'`` entry added: a pair of the type obtained from metadata and
    ``INF_TYPE_ANNOT`` if that type is truthy, else ``UNK_TYPE_ANNOT``.
    For a tuple target, ``{'names': ...}`` is returned with the result of
    the multi-assignment helper.
    """

    def saved_identifier(key):
        # Capture a non-empty identifier string under `key`.
        return match.SaveMatchedNode(match.MatchRegex(r'(.)+'), key)

    # Single target
    single_target = match.Name(value=saved_identifier("name"))
    # Multi-target: keep the raw element sequence for later processing.
    tuple_target = match.Tuple(
        elements=match.SaveMatchedNode(match.DoNotCare(), "names"))
    # This extracts variables inside __init__ without type annotation
    # (e.g. self.x=2)
    attribute_target = match.Attribute(
        value=match.Name(value=saved_identifier("obj_name")),
        attr=match.Name(saved_identifier("name")),
    )
    captured = match.extract(
        node,
        match.AssignTarget(
            target=match.OneOf(single_target, tuple_target, attribute_target)),
    )

    if captured is None:
        return None

    if "name" in captured:
        inferred = self.__get_type_from_metadata(node.target)
        captured['type'] = (
            inferred, INF_TYPE_ANNOT if inferred else UNK_TYPE_ANNOT)
        return captured

    if "names" in captured:
        return {
            'names': self.__extract_names_multi_assign(
                list(captured['names']))
        }

    return captured
def test_extract_multiple(self) -> None:
    # Several SaveMatchedNode captures in one pattern all appear in the
    # resulting dictionary under their own names.
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    pattern = m.Tuple(
        elements=[
            m.Element(
                m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))
            ),
            m.Element(m.Call(func=m.SaveMatchedNode(m.Name(), "func"))),
        ]
    )
    captured = m.extract(expression, pattern)

    elements = cst.ensure_type(expression, cst.Tuple).elements
    expected_left = cst.ensure_type(
        elements[0].value, cst.BinaryOperation
    ).left
    expected_func = cst.ensure_type(elements[1].value, cst.Call).func
    self.assertEqual(
        captured, {"left": expected_left, "func": expected_func}
    )
def test_extract_sequence(self) -> None:
    # A SaveMatchedNode wrapping an entire sequence matcher saves the
    # sequence attribute it was matched against.
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    all_args = m.SaveMatchedNode([m.ZeroOrMore()], "args")
    pattern = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(m.Call(args=all_args)),
        ]
    )
    captured = m.extract(expression, pattern)

    call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )
    self.assertEqual(captured, {"args": call.args})
def test_extract_optional_wildcard(self) -> None:
    # The pattern still matches when the optional ZeroOrOne element is
    # absent; its SaveMatchedNode then contributes no key, so the result is
    # an empty dictionary rather than None.
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    optional_attribute_arg = m.ZeroOrOne(
        m.Arg(m.SaveMatchedNode(m.Attribute(), "arg"))
    )
    pattern = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(m.Call(args=[m.ZeroOrMore(), optional_attribute_arg])),
        ]
    )
    captured = m.extract(expression, pattern)
    self.assertEqual(captured, {})
def visit_Call(self, node: cst.Call) -> None:
    """
    Collect the first argument of calls to ``Extension`` or
    ``addMacExtension`` when that argument is a simple string literal.

    The literal's evaluated value is appended to ``self.extension_names``;
    any further arguments are accepted and ignored.
    """
    pattern = m.Call(
        func=m.OneOf(m.Name("Extension"), m.Name("addMacExtension")),
        args=(
            m.Arg(value=m.SaveMatchedNode(m.SimpleString(), "extension_name")),
            m.ZeroOrMore(m.DoNotCare()),
        ),
    )
    captured = m.extract(node, pattern)
    if not captured:
        return

    name_literal = captured["extension_name"]
    assert isinstance(name_literal, cst.SimpleString)
    self.extension_names.append(name_literal.evaluated_value)
def visit_If(self, node: cst.If):
    """
    Record the names referenced by an ``if`` statement's test expression.

    The names are appended to the class-level usage list when the outermost
    class's name is among them, appended to the function-level usage list
    when a function is on ``self.stack``, and always forwarded to the
    module-level usage check.
    """
    any_name = match.Name(
        value=match.SaveMatchedNode(match.DoNotCare(), 'name'))
    if_names = [found.value for found in match.findall(node.test, any_name)]

    if self.cls_stack and self.cls_stack[0].name in if_names:
        self.cls_may_vars_use.append(if_names)

    if self.stack:
        self.fn_may_args_var_use.append(if_names)

    self.__find_module_vars_use(if_names)
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine):
    """
    Record the names referenced anywhere in a simple statement line.

    The names are appended to the function-level usage list when a function
    is on ``self.stack``, appended to the class-level usage list when the
    outermost class's name is among them, and always forwarded to the
    module-level usage check.
    """
    any_name = match.Name(
        value=match.SaveMatchedNode(match.DoNotCare(), 'name'))
    smt_names = [found.value for found in match.findall(node, any_name)]

    if self.stack:
        self.fn_may_args_var_use.append(smt_names)

    if self.cls_stack and self.cls_stack[0].name in smt_names:
        self.cls_may_vars_use.append(smt_names)

    self.__find_module_vars_use(smt_names)
def test_extract_optional_wildcard_present(self) -> None:
    # When the optional ZeroOrOne element is present, the node saved inside
    # it is reported; here that is the value of the call's third argument.
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    optional_attribute_arg = m.ZeroOrOne(
        m.Arg(m.SaveMatchedNode(m.Attribute(), "arg"))
    )
    pattern = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(
                m.Call(
                    args=[
                        m.DoNotCare(),
                        m.DoNotCare(),
                        optional_attribute_arg,
                    ]
                )
            ),
        ]
    )
    captured = m.extract(expression, pattern)

    call = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )
    third_arg_value = call.args[2].value
    self.assertEqual(captured, {"arg": third_arg_value})