def test_list(self, _name: str, expression1: str, expression2: str) -> None:
    """Assert that a matcher built from ``expression1`` matches ``expression2``.

    The matcher is built with an empty placeholder mapping. ``_name`` is not
    used by the test body.
    """
    template = libcst.parse_expression(expression1)
    pattern = craftier.matcher.from_node(template, {})
    candidate = libcst.parse_expression(expression2)
    self.assertTrue(craftier.matcher.matches(candidate, pattern))
def test_generator_many_argument_function_call(self) -> None:
    """Expect parentheses on a generator that is one of several call arguments."""
    # Strip the parentheses the parser attached so the helper has to restore them.
    bare_generator = libcst.parse_expression("(x for x in foo)").with_changes(
        lpar=[], rpar=[]
    )
    previous = libcst.parse_expression("max((x for x in foo), foo)")
    result = parenthesize.parenthesize_using_previous(bare_generator, previous)
    self.assert_has_parentheses(result)
class ExpressionTest(UnitTest):
    """Tests for full-name lookup helpers and ``evaluated_value`` on literals."""

    @data_provider(
        (
            ("a string", "a string"),
            (cst.Name("a_name"), "a_name"),
            (cst.parse_expression("a.b.c"), "a.b.c"),
            (cst.parse_expression("a.b()"), "a.b"),
            (cst.parse_expression("a.b.c[i]"), "a.b.c"),
            (cst.parse_statement("def fun():  pass"), "fun"),
            (cst.parse_statement("class cls:  pass"), "cls"),
            (
                cst.Decorator(
                    ensure_type(cst.parse_expression("a.b.c.d"), cst.Attribute)
                ),
                "a.b.c.d",
            ),
            (cst.parse_statement("(a.b()).c()"), None),  # not a supported Node type
        )
    )
    def test_get_full_name_for_expression(
        self,
        input: Union[str, cst.CSTNode],
        output: Optional[str],
    ) -> None:
        """Both lookup helpers agree; the ``_or_raise`` variant raises on ``None``."""
        self.assertEqual(get_full_name_for_node(input), output)
        if output is not None:
            self.assertEqual(get_full_name_for_node_or_raise(input), output)
            return
        with self.assertRaises(Exception):
            get_full_name_for_node_or_raise(input)

    def _check_evaluated_value(self, source: str, node_type) -> None:
        """Parse ``source`` as ``node_type`` and compare against ``literal_eval``."""
        parsed = ensure_type(cst.parse_expression(source), node_type)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))

    def test_simplestring_evaluated_value(self) -> None:
        """A plain string literal evaluates like ``literal_eval``."""
        self._check_evaluated_value('"a string."', cst.SimpleString)

    def test_integer_evaluated_value(self) -> None:
        """An integer literal evaluates like ``literal_eval``."""
        self._check_evaluated_value("5", cst.Integer)

    def test_float_evaluated_value(self) -> None:
        """A float literal evaluates like ``literal_eval``."""
        self._check_evaluated_value("5.5", cst.Float)

    def test_complex_evaluated_value(self) -> None:
        """An imaginary literal evaluates like ``literal_eval``."""
        self._check_evaluated_value("5j", cst.Imaginary)
def test_concatenated_string_evaluated_value(self) -> None:
    """Check ``evaluated_value`` on implicitly concatenated string literals.

    Concatenations of plain ``str`` or ``bytes`` pieces are expected to
    evaluate to the joined value; concatenations that contain an f-string are
    expected to evaluate to ``None``.
    """
    evaluable_cases = (
        (
            '"This " "is " "a " "concatenated " "string."',
            "This is a concatenated string.",
        ),
        ('b"A concatenated" b" byte."', b"A concatenated byte."),
    )
    for code, expected in evaluable_cases:
        node = ensure_type(cst.parse_expression(code), cst.ConcatenatedString)
        self.assertEqual(node.evaluated_value, expected)

    # Concatenations that include an f-string piece.
    for code in ('"var=" f" {var}"', '"var" "=" f" {var}"'):
        node = ensure_type(cst.parse_expression(code), cst.ConcatenatedString)
        # assertIsNone checks identity and gives a clearer failure message
        # than comparing against None with assertEqual.
        self.assertIsNone(node.evaluated_value)
def test_fstring_with_placeholders(self) -> None:
    """Placeholders inside an f-string template are matched individually.

    ``x`` accepts anything, ``y`` must be an integer literal.
    """
    template = libcst.parse_expression("f'{x} and {y}'")
    placeholders = {
        "x": libcst.matchers.DoNotCare(),
        "y": libcst.matchers.Integer(),
    }
    pattern = craftier.matcher.from_node(template, placeholders)

    matching = libcst.parse_expression("f'{a + b} and {32}'")
    self.assertTrue(craftier.matcher.matches(matching, pattern))

    non_matching = libcst.parse_expression("f'{a + b} and {z}'")
    self.assertFalse(craftier.matcher.matches(non_matching, pattern))
def test_placeholders(
    self,
    _name: str,
    expression1: str,
    expression2: str,
    placeholders: Set[str],
) -> None:
    """Assert ``expression1`` matches ``expression2`` with wildcard placeholders.

    Every name in ``placeholders`` is mapped to a ``DoNotCare`` matcher.
    ``_name`` is not used by the test body.
    """
    template = libcst.parse_expression(expression1)
    wildcards = {}
    for placeholder in placeholders:
        wildcards[placeholder] = libcst.matchers.DoNotCare()
    pattern = craftier.matcher.from_node(template, wildcards)
    candidate = libcst.parse_expression(expression2)
    self.assertTrue(craftier.matcher.matches(candidate, pattern))
def test_extract_predicates(self) -> None:
    """``extract`` only records the alternative of an ``|`` matcher that matched."""
    # The same pattern is applied to both inputs: capture the left operand of
    # the first element and the callee of the second, saved under "func" when
    # it is a Name and under "attr" when it is an Attribute.
    pattern = m.Tuple(
        elements=[
            m.Element(m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))),
            m.Element(
                m.Call(
                    func=m.SaveMatchedNode(m.Name(), "func")
                    | m.SaveMatchedNode(m.Attribute(), "attr")
                )
            ),
        ]
    )

    def left_operand(tree):
        first = cst.ensure_type(tree, cst.Tuple).elements[0].value
        return cst.ensure_type(first, cst.BinaryOperation).left

    def callee(tree):
        second = cst.ensure_type(tree, cst.Tuple).elements[1].value
        return cst.ensure_type(second, cst.Call).func

    # Callee is a plain name.
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    nodes = m.extract(expression, pattern)
    self.assertEqual(
        nodes, {"left": left_operand(expression), "func": callee(expression)}
    )

    # Callee is an attribute access.
    expression = cst.parse_expression("a + b[c], d.z(e, f * g)")
    nodes = m.extract(expression, pattern)
    self.assertEqual(
        nodes, {"left": left_operand(expression), "attr": callee(expression)}
    )
def leave_ClassDef(
    self, original_node: cst.ClassDef, updated_node: cst.ClassDef
) -> cst.ClassDef:
    """Replace a NamedTuple base class with a ``dataclass(frozen=False)`` decorator.

    Returns ``updated_node`` untouched when none of the bases resolves to
    ``self.qualified_namedtuple``. Otherwise schedules the import changes and
    returns the class with the matching base removed and the decorator
    appended.
    """
    kept_bases: List[cst.Arg] = []
    namedtuple_base: Optional[cst.Arg] = None

    # The original node's bases are inspected because they are the ones tied
    # to import metadata.
    for base in original_node.bases:
        is_namedtuple = QualifiedNameProvider.has_name(
            self, base.value, self.qualified_namedtuple
        )
        if is_namedtuple:
            namedtuple_base = base
        else:
            # Every base that is not the NamedTuple base is kept.
            kept_bases.append(base)

    if namedtuple_base is None:
        # Still hand back the updated node: children may have been modified.
        return updated_node

    AddImportsVisitor.add_needed_import(self.context, "attr", "dataclass")
    AddImportsVisitor.add_needed_import(
        self.context, "pydantic.dataclasses", "dataclass"
    )
    RemoveImportsVisitor.remove_unused_import_by_node(
        self.context, namedtuple_base.value
    )

    decorator_call = cst.ensure_type(
        cst.parse_expression(
            "dataclass(frozen=False)", config=self.module.config_for_parsing
        ),
        cst.Call,
    )
    decorators = [*original_node.decorators, cst.Decorator(decorator=decorator_call)]
    return updated_node.with_changes(
        lpar=cst.MaybeSentinel.DEFAULT,
        rpar=cst.MaybeSentinel.DEFAULT,
        bases=kept_bases,
        decorators=decorators,
    )
def leave_SimpleString(
    self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
) -> Union[libcst.SimpleString, libcst.BaseExpression]:
    """Replace a string literal with the expression its contents parse to.

    Also registers ``from __future__ import annotations`` as a needed import.
    """
    AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
    # literal_eval turns the quoted source text into the string it denotes.
    inner_source = literal_eval(updated_node.value)
    return parse_expression(inner_source, config=self.module.config_for_parsing)
def test_extract_sequence_element(self) -> None:
    """A saved wildcard inside ``args`` captures the matched argument sequence."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    call_node = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )

    # Positive case: an unconstrained ZeroOrMore captures every argument.
    capture_all = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(m.Call(args=[m.SaveMatchedNode(m.ZeroOrMore(), "args")])),
        ]
    )
    nodes = m.extract(expression, capture_all)
    self.assertEqual(nodes, {"args": call_node.args})

    # Negative case: requiring every argument to be a Subscript cannot match.
    capture_subscripts = m.Tuple(
        elements=[
            m.DoNotCare(),
            m.Element(
                m.Call(
                    args=[
                        m.SaveMatchedNode(
                            m.ZeroOrMore(m.Arg(m.Subscript())), "args"
                        )
                    ]
                )
            ),
        ]
    )
    nodes = m.extract(expression, capture_subscripts)
    self.assertIsNone(nodes)
def test_extract_simple(self) -> None:
    """``extract`` returns the saved node on a match and ``None`` otherwise."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")

    def pattern_with_left(left_matcher):
        # Tuple of (binary operation with a captured left operand, any call).
        return m.Tuple(
            elements=[
                m.Element(
                    m.BinaryOperation(left=m.SaveMatchedNode(left_matcher, "left"))
                ),
                m.Element(m.Call()),
            ]
        )

    # Positive case: the left operand is a Name.
    nodes = m.extract(expression, pattern_with_left(m.Name()))
    first_element = cst.ensure_type(expression, cst.Tuple).elements[0].value
    expected_left = cst.ensure_type(first_element, cst.BinaryOperation).left
    self.assertEqual(nodes, {"left": expected_left})

    # Negative case: the left operand is not a Subscript.
    nodes = m.extract(expression, pattern_with_left(m.Subscript()))
    self.assertIsNone(nodes)
def leave_Attribute(
    self, original_node: cst.Attribute, updated_node: cst.Attribute
) -> Union[cst.Name, cst.Attribute]:
    """Rewrite an attribute access that refers to the name being renamed.

    Returns ``updated_node`` unchanged when the attribute neither resolves to
    ``self.old_name`` nor (inside an import statement) already maps to
    ``self.new_name``.

    Raises:
        Exception: if no full dotted name can be derived from the node.
    """
    dotted_name = get_full_name_for_node(original_node)
    if dotted_name is None:
        raise Exception("Could not parse full name for Attribute node.")
    replacement_name = self.gen_replacement(dotted_name)

    # A node with no associated QualifiedName means we are still inside an
    # import statement.
    qualified_names = self.get_metadata(QualifiedNameProvider, original_node, set())
    inside_import_statement: bool = not qualified_names

    should_rename = QualifiedNameProvider.has_name(
        self, original_node, self.old_name
    ) or (inside_import_statement and replacement_name == self.new_name)
    if not should_rename:
        return updated_node

    new_value, new_attr = self.new_module, self.new_mod_or_obj
    if not inside_import_statement:
        self.scheduled_removals.add(original_node.value)

    if replacement_name != self.new_name:
        return self.gen_name_or_attr_node(new_attr)

    return updated_node.with_changes(
        value=cst.parse_expression(new_value),
        attr=cst.Name(value=new_attr.rstrip(".")),
    )
def parse_template_expression(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.BaseExpression:
    """
    Accepts an expression template on a single line. Leading and trailing whitespace
    is not valid (there's nowhere to store it on the expression node). Any
    :class:`~libcst.CSTNode` provided as a keyword argument to this function will
    be inserted into the template at the appropriate location similar to an
    f-string expansion. For example::

      expression = parse_template_expression("x + {foo}", foo=Name("y"))

    The above code will parse to a :class:`~libcst.BinaryOperation` expression
    adding two names (``x`` and ``y``) together.

    Remember that if you are parsing a template as part of a substitution inside
    a transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.
    """
    # Iterating the keyword mapping yields its keys, i.e. the template
    # variable names; mangle them so the template parses as ordinary source.
    source = mangle_template(template, set(template_replacements))
    expression = cst.parse_expression(source, config)
    # Swap the mangled stand-ins for the caller-supplied replacements.
    new_expression = ensure_type(
        unmangle_nodes(expression, template_replacements), cst.BaseExpression
    )
    # Run the checker over the result with the same set of variable names.
    new_expression.visit(TemplateChecker(set(template_replacements)))
    return new_expression
def _get_clean_type(typeobj: object) -> str:
    """
    Turn a type object, as returned by dataclasses, into a sanitized type
    string suitable for the codegen that consumes it.

    Raises an ``Exception`` when the parsed type has a shape this function
    does not handle, so that new node types get triaged.
    """

    # Start from repr(), unwrapping "<class 'x.y.Z'>" down to "x.y.Z" so that
    # the result is a parseable expression.
    class_prefix, class_suffix = "<class '", "'>"
    type_source = repr(typeobj)
    if type_source.startswith(class_prefix) and type_source.endswith(class_suffix):
        type_source = type_source[len(class_prefix) : -len(class_suffix)]

    # Parse with LibCST and cleanse the fully-qualified names.
    name_cleanser = CleanseFullTypeNames()
    typecst = parse_expression(type_source)
    typecst = typecst.visit(name_cleanser)

    # Extend the type so that DoNotCareSentinel values are allowed.
    clean_type: Optional[cst.CSTNode] = None
    if isinstance(typecst, cst.Subscript):
        if typecst.value.deep_equals(cst.Name("Union")):
            # An existing Union can simply be extended in place.
            clean_type = typecst.with_changes(
                slice=[*typecst.slice, _get_do_not_care()]
            )
        elif typecst.value.deep_equals(
            cst.Name("Literal")
        ) or typecst.value.deep_equals(cst.Name("Sequence")):
            clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())

    # Anything we could not convert is a node type we have not seen before.
    if clean_type is None:
        raise Exception(f"Don't support {typecst}")

    # Post-processing passes, applied in this order:
    #  1. Add DoNotCareSentinel to all sequences, so a sequence can be defined
    #     partially with explicit DoNotCare() values for some slots.
    #  2. Double-quote any types we parsed and repr'd, for consistency.
    #  3. Insert OneOf/AllOf and MatchIfTrue into unions so their usage can be
    #     typechecked; OneOf[SomeType] or MatchIfTrue[cst.SomeType] become
    #     valid wherever SomeType was originally allowed.
    #  4. Insert AtMostN and AtLeastN into sequence unions. This relies on
    #     pass 3 to ensure every sequence we care about is
    #     Sequence[Union[<x>]].
    for transformer_class in (
        AddDoNotCareToSequences,
        DoubleQuoteStrings,
        AddLogicAndLambdaMatcherToUnions,
        AddWildcardsToSequenceUnions,
    ):
        clean_type = ensure_type(clean_type.visit(transformer_class()), cst.CSTNode)

    # Render the final node through a default, empty Module.
    return cst.Module(body=()).code_for_node(clean_type)
def leave_Return(
    self, original_node: cst.Return, updated_node: cst.Return
) -> Union[cst.Return, RemovalSentinel, FlattenSentinel[cst.BaseSmallStatement]]:
    """Expand a ``return`` into a ``print('returning')`` call followed by it."""
    announce = cst.Expr(parse_expression("print('returning')"))
    statements = [announce, updated_node]
    return FlattenSentinel(statements)
def test_generator_return(self) -> None:
    """Expect parentheses on a generator when the previous node is a return."""
    # Strip the parentheses the parser attached so the helper has to restore them.
    bare_generator = libcst.parse_expression("(x for x in foo)").with_changes(
        lpar=[], rpar=[]
    )
    previous = libcst.parse_statement("return (x for x in foo)")
    result = parenthesize.parenthesize_using_previous(bare_generator, previous)
    self.assert_has_parentheses(result)
def leave_SimpleString(
    self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
) -> Union[libcst.SimpleString, libcst.BaseExpression]:
    """Replace a string literal with the expression its contents parse to.

    Also registers ``from __future__ import annotations`` as a needed import.
    """
    AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
    # LibCST evaluates the literal for us; the evaluated text is then parsed
    # and inserted as the annotation.
    annotation_source = updated_node.evaluated_value
    return parse_expression(
        annotation_source, config=self.module.config_for_parsing
    )
def gen_name_or_attr_node(
    self, dotted_expression: str
) -> Union[cst.Attribute, cst.Name]:
    """Parse a dotted path into a ``Name`` or ``Attribute`` node.

    Raises:
        Exception: if the parsed expression is any other node type.
    """
    parsed: cst.BaseExpression = cst.parse_expression(dotted_expression)
    if isinstance(parsed, (cst.Name, cst.Attribute)):
        return parsed
    raise Exception(
        "`parse_expression()` on dotted path returned non-Attribute-or-Name."
    )
def extract_names_from_type_annot(type_annot: str):
    """
    Return the ``value`` of every ``Name`` node found in a type annotation.

    The annotation string is parsed as an expression and searched for Name
    nodes with any value.
    """
    annotation_tree = cst.parse_expression(type_annot)
    any_name = match.Name(value=match.DoNotCare())
    identifiers = []
    for name_node in match.findall(annotation_tree, any_name):
        identifiers.append(name_node.value)
    return identifiers
def test_extractall_simple(self) -> None:
    """``extractall`` yields one capture dict per argument that is not a Name."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    non_name_argument = m.Arg(m.SaveMatchedNode(~m.Name(), "expr"))
    matches = extractall(expression, non_name_argument)

    call_node = cst.ensure_type(
        cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
    )
    call_args = call_node.args
    # The first argument (``e``) is a Name, so only the 2nd and 3rd are expected.
    expected = [{"expr": call_args[index].value} for index in (1, 2)]
    self.assertEqual(matches, expected)
def parse_arg(arg_str: str) -> Arg:
    """Build an `Arg` instance from its string representation.

    Constructing an `Arg` by hand is cumbersome, so the argument text is
    wrapped in a throwaway function call, parsed, and the first argument is
    pulled back out of the resulting tree.

    Raises:
        AssertionError: if the parsed wrapper is not a `Call`.
    """
    wrapper = parse_expression(f"call({arg_str})")
    if not isinstance(wrapper, Call):
        raise AssertionError(f"Unexpected type for: {wrapper}")
    return wrapper.args[0]
def leave_ImportFrom(
    self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.ImportFrom:
    """Rename, keep, or schedule for removal the aliases of a ``from`` import.

    Returns ``updated_node`` unchanged when the import has no module, the
    module name cannot be resolved, or the imported names are not a sequence.
    Otherwise returns the node with a rebuilt list of names, or with both the
    module and the single name replaced when a lone import is renamed on the
    spot.
    """
    module = updated_node.module
    if module is None:
        return updated_node

    imported_module_name = get_full_name_for_node(module)
    names = original_node.names
    if imported_module_name is None or not isinstance(names, Sequence):
        return updated_node

    new_names = []
    for import_alias in names:
        alias_name = get_full_name_for_node(import_alias.name)
        if alias_name is None:
            # Aliases whose name cannot be resolved are not carried over.
            continue

        qual_name = f"{imported_module_name}.{alias_name}"
        if self.old_name != qual_name:
            if self.old_name.startswith(qual_name + "."):
                # This import might be in use elsewhere in the code, so
                # schedule a potential removal.
                self.scheduled_removals.add(original_node)
            new_names.append(import_alias)
            continue

        replacement_module = self.gen_replacement_module(imported_module_name)
        replacement_obj = self.gen_replacement(alias_name)
        if not replacement_obj:
            # The user has requested an `import` statement rather than a
            # `from ... import`. That is handled in `leave_Module`; in the
            # meantime, schedule for potential removal.
            new_names.append(import_alias)
            self.scheduled_removals.add(original_node)
            continue

        new_import_alias_name: Union[
            cst.Attribute, cst.Name
        ] = self.gen_name_or_attr_node(replacement_obj)

        if len(names) == 1:
            # Rename on the spot only if this is the only imported name under
            # the module.
            self.bypass_import = True
            return updated_node.with_changes(
                module=cst.parse_expression(replacement_module),
                names=(cst.ImportAlias(name=new_import_alias_name),),
            )
        if replacement_module == imported_module_name:
            # Or if the module name is to stay the same.
            self.bypass_import = True
            new_names.append(cst.ImportAlias(name=new_import_alias_name))

    return updated_node.with_changes(names=new_names)
def test_extract_optional_wildcard_tail(self) -> None:
    """Trailing saved wildcards with nothing left to match capture empty tuples."""
    expression = cst.parse_expression("[3]")
    pattern = m.List(
        elements=[
            m.Element(value=m.Integer(value="3")),
            m.SaveMatchedNode(m.ZeroOrMore(), "tail1"),
            m.SaveMatchedNode(m.ZeroOrMore(), "tail2"),
        ]
    )
    captured = m.extract(expression, pattern)
    self.assertEqual(captured, {"tail1": (), "tail2": ()})
def test_string_prefix_and_quotes(self) -> None:
    """
    Check the ``prefix``, ``quote`` and ``raw_value`` helpers on several
    string literals.
    """
    emptybytestring = cst.ensure_type(parse_expression('b""'), cst.SimpleString)
    bytestring = cst.ensure_type(parse_expression('b"abc"'), cst.SimpleString)
    multilinestring = cst.ensure_type(parse_expression('""""""'), cst.SimpleString)
    formatstring = cst.ensure_type(parse_expression('f""""""'), cst.FormattedString)

    # (node, expected prefix, expected quote, expected raw value)
    simple_string_cases = (
        (emptybytestring, "b", '"', ""),
        (bytestring, "b", '"', "abc"),
        (multilinestring, "", '"""', ""),
    )
    for node, prefix, quote, raw_value in simple_string_cases:
        self.assertEqual(node.prefix, prefix)
        self.assertEqual(node.quote, quote)
        self.assertEqual(node.raw_value, raw_value)

    # Only prefix and quote are checked for the formatted string.
    self.assertEqual(formatstring.prefix, "f")
    self.assertEqual(formatstring.quote, '"""')
def visit_SimpleString(self, node: cst.SimpleString) -> None:
    """Report a quoted annotation, offering its unquoted form as replacement.

    Nothing is reported unless the future-annotations import is present and
    the string sits inside an annotation but outside a literal.
    """
    if (
        not self.has_future_annotations_import
        or not self.in_annotation
        or self.in_literal
    ):
        return

    # This is not allowed past Python3.7 since it's no longer necessary.
    unquoted = cst.parse_expression(
        node.evaluated_value,
        config=self.context.wrapper.module.config_for_parsing,
    )
    self.report(node, replacement=unquoted)
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
    """Report a class with a NamedTuple base, suggesting a frozen dataclass.

    The suggested replacement drops that base and appends a
    ``dataclass(frozen=True)`` decorator. Nothing is reported when
    ``partition_bases`` finds no such base.
    """
    namedtuple_base, remaining_bases = self.partition_bases(original_node.bases)
    if namedtuple_base is None:
        return

    decorator_call = ensure_type(
        parse_expression("dataclass(frozen=True)"), cst.Call
    )
    decorators = [*original_node.decorators, cst.Decorator(decorator=decorator_call)]
    suggestion = original_node.with_changes(
        lpar=MaybeSentinel.DEFAULT,
        rpar=MaybeSentinel.DEFAULT,
        bases=remaining_bases,
        decorators=decorators,
    )
    self.report(original_node, replacement=suggestion)
def test_parsing_compilable_expression_strings(self, source_code: str) -> None:
    """Round-trip an expression through libCST and compare the ASTs.

    The expression counterpart of the statement test: the grammar's start
    production, the compile mode and the libCST parse function all change,
    while code generation works the same way as for statements.
    """
    self.reject_invalid_code(source_code, mode="eval")
    parsed = libcst.parse_expression(source_code)
    regenerated = libcst.Module([]).code_for_node(parsed)
    self.verify_identical_asts(source_code, regenerated, mode="eval")
def test_extract_tautology(self) -> None:
    """Saving the root matcher captures the whole expression itself."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    whole_tuple = m.Tuple(
        elements=[m.Element(m.BinaryOperation()), m.Element(m.Call())]
    )
    captured = m.extract(expression, m.SaveMatchedNode(whole_tuple, name="node"))
    self.assertEqual(captured, {"node": expression})
def test_extract_sequence(self) -> None:
    """Saving a whole sequence matcher captures the full ``args`` sequence."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    # Here SaveMatchedNode wraps the entire list matcher, not one element.
    saved_args = m.SaveMatchedNode([m.ZeroOrMore()], "args")
    pattern = m.Tuple(
        elements=[m.DoNotCare(), m.Element(m.Call(args=saved_args))]
    )
    captured = m.extract(expression, pattern)

    second_element = cst.ensure_type(expression, cst.Tuple).elements[1].value
    expected_args = cst.ensure_type(second_element, cst.Call).args
    self.assertEqual(captured, {"args": expected_args})
def update_call_args(self, node: Call) -> Sequence[Arg]:
    """Wrap the call's first argument, an integer of minutes, in a timedelta.

    Registers ``from datetime import timedelta`` as a needed import and
    returns the arguments with the first one replaced by
    ``timedelta(minutes=<value>)``; the remaining arguments are unchanged.

    Raises:
        AssertionError: if the first argument is not an integer literal.
    """
    AddImportsVisitor.add_needed_import(
        context=self.context,
        module="datetime",
        obj="timedelta",
    )
    first_arg, *remaining_args = node.args
    minutes_literal = first_arg.value
    if not isinstance(minutes_literal, Integer):
        raise AssertionError(f"Unexpected type for: {minutes_literal}")

    wrapped_value = parse_expression(f"timedelta(minutes={minutes_literal.value})")
    replaced_first_arg = first_arg.with_changes(value=wrapped_value)
    return (replaced_first_arg, *remaining_args)