Exemple #1
0
 def test_list(self, _name: str, expression1: str, expression2: str) -> None:
     """A matcher built from ``expression1`` with no placeholders matches ``expression2``."""
     template = libcst.parse_expression(expression1)
     matcher = craftier.matcher.from_node(template, {})
     candidate = libcst.parse_expression(expression2)
     self.assertTrue(craftier.matcher.matches(candidate, matcher))
Exemple #2
0
 def test_generator_many_argument_function_call(self) -> None:
     """A bare generator that sat among several call arguments regains its parentheses."""
     bare_generator = libcst.parse_expression("(x for x in foo)").with_changes(
         lpar=[], rpar=[]
     )
     previous_call = libcst.parse_expression("max((x for x in foo), foo)")
     result = parenthesize.parenthesize_using_previous(bare_generator, previous_call)
     self.assert_has_parentheses(result)
Exemple #3
0
class ExpressionTest(UnitTest):
    """Tests for the full-name helpers and for literal ``evaluated_value``."""

    @data_provider((
        ("a string", "a string"),
        (cst.Name("a_name"), "a_name"),
        (cst.parse_expression("a.b.c"), "a.b.c"),
        (cst.parse_expression("a.b()"), "a.b"),
        (cst.parse_expression("a.b.c[i]"), "a.b.c"),
        (cst.parse_statement("def fun():  pass"), "fun"),
        (cst.parse_statement("class cls:  pass"), "cls"),
        (
            cst.Decorator(
                ensure_type(cst.parse_expression("a.b.c.d"), cst.Attribute)),
            "a.b.c.d",
        ),
        (cst.parse_statement("(a.b()).c()"),
         None),  # not a supported Node type
    ))
    def test_get_full_name_for_expression(
        self,
        input: Union[str, cst.CSTNode],
        output: Optional[str],
    ) -> None:
        # The non-raising helper returns None for unsupported inputs; the
        # raising variant must raise in exactly those cases.
        self.assertEqual(get_full_name_for_node(input), output)
        if output is not None:
            self.assertEqual(get_full_name_for_node_or_raise(input), output)
        else:
            with self.assertRaises(Exception):
                get_full_name_for_node_or_raise(input)

    def _assert_literal_round_trip(self, raw: str, node_type: type) -> None:
        """Parse ``raw`` as ``node_type``, then check ``value`` and ``evaluated_value``."""
        parsed = ensure_type(cst.parse_expression(raw), node_type)
        self.assertEqual(parsed.value, raw)
        self.assertEqual(parsed.evaluated_value, literal_eval(raw))

    def test_simplestring_evaluated_value(self) -> None:
        self._assert_literal_round_trip('"a string."', cst.SimpleString)

    def test_integer_evaluated_value(self) -> None:
        self._assert_literal_round_trip("5", cst.Integer)

    def test_float_evaluated_value(self) -> None:
        self._assert_literal_round_trip("5.5", cst.Float)

    def test_complex_evaluated_value(self) -> None:
        self._assert_literal_round_trip("5j", cst.Imaginary)
Exemple #4
0
 def test_concatenated_string_evaluated_value(self) -> None:
     """Plain str/bytes concatenations are joined; any f-string part yields None."""
     cases = (
         ('"This " "is " "a " "concatenated " "string."', "This is a concatenated string."),
         ('b"A concatenated" b" byte."', b"A concatenated byte."),
         ('"var=" f" {var}"', None),
         ('"var" "=" f" {var}"', None),
     )
     for code, expected in cases:
         node = ensure_type(cst.parse_expression(code), cst.ConcatenatedString)
         self.assertEqual(node.evaluated_value, expected)
Exemple #5
0
 def test_fstring_with_placeholders(self) -> None:
     """Placeholders inside f-string replacement fields are checked by their matchers."""
     placeholder_matchers = {
         "x": libcst.matchers.DoNotCare(),
         "y": libcst.matchers.Integer(),
     }
     template = libcst.parse_expression("f'{x} and {y}'")
     matcher = craftier.matcher.from_node(template, placeholder_matchers)

     # "y" must be an integer literal: 32 matches, the name z does not.
     with_integer = libcst.parse_expression("f'{a + b} and {32}'")
     self.assertTrue(craftier.matcher.matches(with_integer, matcher))
     with_name = libcst.parse_expression("f'{a + b} and {z}'")
     self.assertFalse(craftier.matcher.matches(with_name, matcher))
Exemple #6
0
 def test_placeholders(
     self,
     _name: str,
     expression1: str,
     expression2: str,
     placeholders: Set[str],
 ) -> None:
     """``expression1`` with DoNotCare placeholders matches ``expression2``."""
     template = libcst.parse_expression(expression1)
     wildcard_by_name = {}
     for placeholder in placeholders:
         wildcard_by_name[placeholder] = libcst.matchers.DoNotCare()
     matcher = craftier.matcher.from_node(template, wildcard_by_name)
     concrete = libcst.parse_expression(expression2)
     self.assertTrue(craftier.matcher.matches(concrete, matcher))
Exemple #7
0
    def test_extract_predicates(self) -> None:
        """Inside a OneOf (``|``), only the alternative that matched saves its node."""
        pattern = m.Tuple(elements=[
            m.Element(
                m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))),
            m.Element(
                m.Call(func=m.SaveMatchedNode(m.Name(), "func")
                       | m.SaveMatchedNode(m.Attribute(), "attr"))),
        ])

        def left_and_callee(expr: cst.BaseExpression) -> tuple:
            # Pull the binary operation's left operand and the call's func node.
            elements = cst.ensure_type(expr, cst.Tuple).elements
            left = cst.ensure_type(elements[0].value, cst.BinaryOperation).left
            callee = cst.ensure_type(elements[1].value, cst.Call).func
            return left, callee

        # Plain-name callee: saved under "func".
        expression = cst.parse_expression("a + b[c], d(e, f * g)")
        nodes = m.extract(expression, pattern)
        left, callee = left_and_callee(expression)
        self.assertEqual(nodes, {"left": left, "func": callee})

        # Attribute callee: saved under "attr".
        expression = cst.parse_expression("a + b[c], d.z(e, f * g)")
        nodes = m.extract(expression, pattern)
        left, callee = left_and_callee(expression)
        self.assertEqual(nodes, {"left": left, "attr": callee})
Exemple #8
0
    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        """Replace a ``typing.NamedTuple`` base with a ``@dataclass(...)`` decorator.

        Classes without a base whose qualified name is ``self.qualified_namedtuple``
        are returned as ``updated_node`` unchanged. Otherwise the NamedTuple base is
        removed, the ``dataclass`` import is requested, the NamedTuple import is
        scheduled for removal if unused, and a ``dataclass`` decorator is appended.
        """
        new_bases: List[cst.Arg] = []
        namedtuple_base: Optional[cst.Arg] = None

        # Need to examine the original node's bases since they are directly tied to import metadata
        for base_class in original_node.bases:
            # Compare the base class's qualified name against the expected typing.NamedTuple
            if QualifiedNameProvider.has_name(self, base_class.value, self.qualified_namedtuple):
                namedtuple_base = base_class
            else:
                # Keep all bases that are not of type typing.NamedTuple
                new_bases.append(base_class)

        # We still want to return the updated node in case some of its children have been modified
        if namedtuple_base is None:
            return updated_node

        # Request exactly one ``dataclass`` import. Previously both
        # ``from attr import dataclass`` and ``from pydantic.dataclasses import
        # dataclass`` were requested, binding the same name twice in the output.
        AddImportsVisitor.add_needed_import(self.context, "pydantic.dataclasses", "dataclass")
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

        # NOTE(review): NamedTuple instances are immutable, but frozen=False yields a
        # mutable dataclass — confirm this is intended.
        call = cst.ensure_type(
            cst.parse_expression("dataclass(frozen=False)", config=self.module.config_for_parsing),
            cst.Call,
        )
        return updated_node.with_changes(
            lpar=cst.MaybeSentinel.DEFAULT,
            rpar=cst.MaybeSentinel.DEFAULT,
            bases=new_bases,
            # Use the updated decorators so changes made by child transforms are kept.
            decorators=[*updated_node.decorators, cst.Decorator(decorator=call)],
        )
 def leave_SimpleString(
     self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
 ) -> Union[libcst.SimpleString, libcst.BaseExpression]:
     """Replace a quoted string with the expression its contents spell out.

     Also requests ``from __future__ import annotations`` in the module
     (presumably so the unquoted form is not evaluated eagerly).
     """
     AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
     unquoted_source = literal_eval(updated_node.value)
     parser_config = self.module.config_for_parsing
     return parse_expression(unquoted_source, config=parser_config)
Exemple #10
0
    def test_extract_sequence_element(self) -> None:
        """A SaveMatchedNode around a ZeroOrMore element captures the matched args."""
        expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")

        def pattern_saving_args(args_matcher):
            # Ignore the first tuple element; save the call's args under "args".
            return m.Tuple(elements=[
                m.DoNotCare(),
                m.Element(m.Call(args=[m.SaveMatchedNode(args_matcher, "args")])),
            ])

        # Verify true behavior: every argument is captured as a sequence.
        nodes = m.extract(expression, pattern_saving_args(m.ZeroOrMore()))
        call_args = cst.ensure_type(
            cst.ensure_type(expression, cst.Tuple).elements[1].value,
            cst.Call).args
        self.assertEqual(nodes, {"args": call_args})

        # Verify false behavior: not every argument is a Subscript.
        nodes = m.extract(
            expression, pattern_saving_args(m.ZeroOrMore(m.Arg(m.Subscript()))))
        self.assertIsNone(nodes)
Exemple #11
0
    def test_extract_simple(self) -> None:
        """m.extract returns the saved nodes on a match and None otherwise."""
        expression = cst.parse_expression("a + b[c], d(e, f * g)")

        def tuple_pattern(left_matcher):
            # First element: a binary operation whose left operand is saved as
            # "left"; second element: any call.
            return m.Tuple(elements=[
                m.Element(
                    m.BinaryOperation(left=m.SaveMatchedNode(left_matcher, "left"))),
                m.Element(m.Call()),
            ])

        # Verify true behavior: the left operand is a Name, so it is captured.
        nodes = m.extract(expression, tuple_pattern(m.Name()))
        left_operand = cst.ensure_type(
            cst.ensure_type(expression, cst.Tuple).elements[0].value,
            cst.BinaryOperation,
        ).left
        self.assertEqual(nodes, {"left": left_operand})

        # Verify false behavior: the left operand is not a Subscript.
        self.assertIsNone(m.extract(expression, tuple_pattern(m.Subscript())))
Exemple #12
0
    def leave_Attribute(
            self, original_node: cst.Attribute,
            updated_node: cst.Attribute) -> Union[cst.Name, cst.Attribute]:
        """Rewrite a dotted attribute access that refers to the name being renamed.

        The node is rewritten when ``original_node`` has qualified name
        ``self.old_name``, or when it is inside an import statement and its
        replacement name equals ``self.new_name``. Otherwise ``updated_node`` is
        returned unchanged.

        Raises:
            Exception: if no full dotted name can be computed for ``original_node``.
        """
        full_name_for_node = get_full_name_for_node(original_node)
        if full_name_for_node is None:
            raise Exception("Could not parse full name for Attribute node.")
        full_replacement_name = self.gen_replacement(full_name_for_node)

        # If a node has no associated QualifiedName, we are still inside an import statement.
        inside_import_statement: bool = not self.get_metadata(
            QualifiedNameProvider, original_node, set())
        if (QualifiedNameProvider.has_name(
                self,
                original_node,
                self.old_name,
        ) or (inside_import_statement
              and full_replacement_name == self.new_name)):
            new_value, new_attr = self.new_module, self.new_mod_or_obj
            if not inside_import_statement:
                # Record the old base expression; presumably it is considered for
                # removal later if it ends up unused.
                self.scheduled_removals.add(original_node.value)
            if full_replacement_name == self.new_name:
                # Keep the Attribute shape: the new module becomes the value and
                # the new object name (trailing dots stripped) becomes the attr.
                return updated_node.with_changes(
                    value=cst.parse_expression(new_value),
                    attr=cst.Name(value=new_attr.rstrip(".")),
                )

            # Otherwise build a fresh Name/Attribute node from the new object path.
            return self.gen_name_or_attr_node(new_attr)

        return updated_node
def parse_template_expression(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.BaseExpression:
    """
    Accepts an expression template on a single line. Leading and trailing whitespace
    is not valid (there's nowhere to store it on the expression node). Any
    :class:`~libcst.CSTNode` provided as a keyword argument to this function will
    be inserted into the template at the appropriate location similar to an
    f-string expansion. For example::

      expression = parse_template_expression("x + {foo}", foo=Name("y"))

    The above code will parse to a :class:`~libcst.BinaryOperation` expression
    adding two names (``x`` and ``y``) together.

    Remember that if you are parsing a template as part of a substitution inside
    a transform, it's considered :ref:`best practice <libcst-config_best_practice>`
    to pass in a ``config`` from the current module under transformation.

    Returns the parsed expression with every placeholder replaced by the node
    passed under the same keyword name.
    """
    # The placeholder names are needed twice: to mangle the template before
    # parsing, and for TemplateChecker to validate the substituted tree.
    replacement_names = {name for name in template_replacements}

    source = mangle_template(template, replacement_names)
    expression = cst.parse_expression(source, config)
    new_expression = ensure_type(
        unmangle_nodes(expression, template_replacements), cst.BaseExpression)
    new_expression.visit(TemplateChecker(replacement_names))
    return new_expression
def _get_clean_type(typeobj: object) -> str:
    """
    Convert a type object reported by dataclasses into the type string used by
    the codegen below. DoNotCareSentinel is added as an allowed alternative,
    and the resulting type is then post-processed by a fixed sequence of
    visitors.
    """
    # repr() of a plain class looks like "<class 'pkg.Name'>"; keep only the path.
    class_prefix = "<class '"
    class_suffix = "'>"
    typestr = repr(typeobj)
    if typestr.startswith(class_prefix) and typestr.endswith(class_suffix):
        typestr = typestr[len(class_prefix):-len(class_suffix)]

    # Parse the type string with LibCST and strip fully-qualified names.
    cleanser = CleanseFullTypeNames()
    typecst = parse_expression(typestr).visit(cleanser)

    # Widen the type so that DoNotCareSentinel values are accepted.
    clean_type: Optional[cst.CSTNode] = None
    if isinstance(typecst, cst.Subscript):
        subscripted = typecst.value
        if subscripted.deep_equals(cst.Name("Union")):
            # Already a Union: append the sentinel as one more member.
            clean_type = typecst.with_changes(
                slice=[*typecst.slice, _get_do_not_care()]
            )
        elif subscripted.deep_equals(cst.Name("Literal")) or subscripted.deep_equals(
            cst.Name("Sequence")
        ):
            clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())

    # Unknown node shapes are surfaced loudly so they can be triaged.
    if clean_type is None:
        raise Exception(f"Don't support {typecst}")

    # Post-processing passes, applied in this exact order:
    #  1. DoNotCareSentinel in every sequence, so sequences can be partially
    #     specified with explicit DoNotCare() slots.
    #  2. Double-quote any types we parsed and repr'd, for consistency.
    #  3. OneOf/AllOf and MatchIfTrue in unions, so they typecheck wherever a
    #     SomeType was originally allowed.
    #  4. AtMostN/AtLeastN in sequence unions; relies on step 3 having made
    #     every relevant sequence a Sequence[Union[<x>]].
    for visitor_cls in (
        AddDoNotCareToSequences,
        DoubleQuoteStrings,
        AddLogicAndLambdaMatcherToUnions,
        AddWildcardsToSequenceUnions,
    ):
        clean_type = ensure_type(clean_type.visit(visitor_cls()), cst.CSTNode)

    # Render the final node using a default, empty Module.
    return cst.Module(body=()).code_for_node(clean_type)
 def leave_Return(
     self, original_node: cst.Return, updated_node: cst.Return
 ) -> Union[cst.Return, RemovalSentinel,
            FlattenSentinel[cst.BaseSmallStatement]]:
     """Insert a ``print('returning')`` statement immediately before each return."""
     log_statement = cst.Expr(parse_expression("print('returning')"))
     return FlattenSentinel([log_statement, updated_node])
Exemple #16
0
 def test_generator_return(self) -> None:
     """A bare generator that was parenthesized in a ``return`` regains its parentheses."""
     bare_generator = libcst.parse_expression("(x for x in foo)").with_changes(
         lpar=[], rpar=[]
     )
     previous_statement = libcst.parse_statement("return (x for x in foo)")
     result = parenthesize.parenthesize_using_previous(bare_generator, previous_statement)
     self.assert_has_parentheses(result)
Exemple #17
0
 def leave_SimpleString(
     self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
 ) -> Union[libcst.SimpleString, libcst.BaseExpression]:
     """Swap a string annotation for the expression written inside it.

     Also requests ``from __future__ import annotations`` in the module.
     """
     AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
     # LibCST evaluates the string literal; its text is re-parsed as an
     # expression using the module's own parser configuration.
     annotation_source = updated_node.evaluated_value
     return parse_expression(annotation_source, config=self.module.config_for_parsing)
Exemple #18
0
 def gen_name_or_attr_node(
         self, dotted_expression: str) -> Union[cst.Attribute, cst.Name]:
     """Parse a dotted path such as ``a.b.c`` into a Name or Attribute node.

     Raises:
         Exception: if the parsed expression is neither a Name nor an Attribute.
     """
     parsed = cst.parse_expression(dotted_expression)
     if isinstance(parsed, (cst.Name, cst.Attribute)):
         return parsed
     raise Exception(
         "`parse_expression()` on dotted path returned non-Attribute-or-Name."
     )
Exemple #19
0
def extract_names_from_type_annot(type_annot: str):
    """
    Parse a type annotation string and return the identifier of every Name node
    in it, in the order ``match.findall`` yields them.
    """
    annotation = cst.parse_expression(type_annot)
    any_name = match.Name(value=match.DoNotCare())
    return [name_node.value for name_node in match.findall(annotation, any_name)]
Exemple #20
0
 def test_extractall_simple(self) -> None:
     """extractall yields one capture per call argument that is not a bare Name."""
     expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
     found = extractall(expression, m.Arg(m.SaveMatchedNode(~m.Name(), "expr")))
     call = cst.ensure_type(
         cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
     )
     # The first argument, ``e``, is a Name and so is skipped.
     expected = [{"expr": argument.value} for argument in call.args[1:]]
     self.assertEqual(found, expected)
Exemple #21
0
def parse_arg(arg_str: str) -> Arg:
    """Build an `Arg` node from its source representation.

    Constructing an `Arg` by hand is cumbersome, so the text is wrapped in a
    dummy call, parsed, and the call's first argument is returned.
    """
    parsed = parse_expression(f"call({arg_str})")
    if not isinstance(parsed, Call):
        raise AssertionError(f"Unexpected type for: {parsed}")
    return parsed.args[0]
Exemple #22
0
    def leave_ImportFrom(self, original_node: cst.ImportFrom,
                         updated_node: cst.ImportFrom) -> cst.ImportFrom:
        """Update ``from <module> import ...`` statements that import ``self.old_name``.

        Statements are returned unchanged when they have no module, when the
        module's full name cannot be computed, or when ``names`` is not a
        sequence (presumably ``import *``). Aliases that do not match
        ``self.old_name`` are kept as-is.
        """
        module = updated_node.module
        if module is None:
            return updated_node
        imported_module_name = get_full_name_for_node(module)
        names = original_node.names

        if imported_module_name is None or not isinstance(names, Sequence):
            return updated_node

        new_names = []
        for import_alias in names:
            alias_name = get_full_name_for_node(import_alias.name)
            if alias_name is None:
                # NOTE(review): such aliases are omitted from the rewritten statement.
                continue
            qual_name = f"{imported_module_name}.{alias_name}"
            if self.old_name == qual_name:
                replacement_module = self.gen_replacement_module(
                    imported_module_name)
                replacement_obj = self.gen_replacement(alias_name)
                if not replacement_obj:
                    # The user has requested an `import` statement rather than an `from ... import`.
                    # This will be taken care of in `leave_Module`, in the meantime, schedule for potential removal.
                    new_names.append(import_alias)
                    self.scheduled_removals.add(original_node)
                    continue

                new_import_alias_name: Union[
                    cst.Attribute,
                    cst.Name] = self.gen_name_or_attr_node(replacement_obj)
                # Rename on the spot only if this is the only imported name under the module.
                if len(names) == 1:
                    self.bypass_import = True
                    return updated_node.with_changes(
                        module=cst.parse_expression(replacement_module),
                        names=(cst.ImportAlias(name=new_import_alias_name), ),
                    )
                # Or if the module name is to stay the same.
                if replacement_module == imported_module_name:
                    self.bypass_import = True
                    new_names.append(
                        cst.ImportAlias(name=new_import_alias_name))
                # NOTE(review): otherwise the alias is dropped from this statement;
                # confirm the replacement import is added elsewhere.
            else:
                if self.old_name.startswith(qual_name + "."):
                    # This import might be in use elsewhere in the code, so schedule a potential removal.
                    self.scheduled_removals.add(original_node)
                new_names.append(import_alias)

        return updated_node.with_changes(names=new_names)
Exemple #23
0
 def test_extract_optional_wildcard_tail(self) -> None:
     """Trailing ZeroOrMore captures that consume nothing are saved as empty tuples."""
     expression = cst.parse_expression("[3]")
     pattern = m.List(elements=[
         m.Element(value=m.Integer(value="3")),
         m.SaveMatchedNode(m.ZeroOrMore(), "tail1"),
         m.SaveMatchedNode(m.ZeroOrMore(), "tail2"),
     ])
     self.assertEqual(m.extract(expression, pattern), {"tail1": (), "tail2": ()})
Exemple #24
0
    def test_string_prefix_and_quotes(self) -> None:
        """
        Check the prefix, quote and raw_value helpers on several string forms.
        """
        # (source, expected node type, prefix, quote, raw_value or None to skip)
        cases = (
            ('b""', cst.SimpleString, "b", '"', ""),
            ('b"abc"', cst.SimpleString, "b", '"', "abc"),
            ('""""""', cst.SimpleString, "", '"""', ""),
            ('f""""""', cst.FormattedString, "f", '"""', None),
        )
        parsed_nodes = [
            cst.ensure_type(parse_expression(source), node_type)
            for source, node_type, *_ in cases
        ]
        for node, (_, _, prefix, quote, raw_value) in zip(parsed_nodes, cases):
            self.assertEqual(node.prefix, prefix)
            self.assertEqual(node.quote, quote)
            if raw_value is not None:
                self.assertEqual(node.raw_value, raw_value)
 def visit_SimpleString(self, node: cst.SimpleString) -> None:
     """Report quoted annotations when ``from __future__ import annotations`` is present.

     Strings inside a Literal are skipped. The suggested replacement is the
     expression the string literal contains.
     """
     should_report = (
         self.has_future_annotations_import
         and self.in_annotation
         and not self.in_literal
     )
     if not should_report:
         return
     # With the __future__ import in place, quoting the annotation is unnecessary.
     replacement = cst.parse_expression(
         node.evaluated_value,
         config=self.context.wrapper.module.config_for_parsing,
     )
     self.report(node, replacement=replacement)
Exemple #26
0
    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Report a replacement that swaps the NamedTuple base for ``@dataclass(frozen=True)``.

        Nothing is reported when ``partition_bases`` finds no NamedTuple base.
        """
        namedtuple_base, new_bases = self.partition_bases(original_node.bases)
        if namedtuple_base is None:
            return
        frozen_dataclass = ensure_type(parse_expression("dataclass(frozen=True)"), cst.Call)
        self.report(
            original_node,
            replacement=original_node.with_changes(
                lpar=MaybeSentinel.DEFAULT,
                rpar=MaybeSentinel.DEFAULT,
                bases=new_bases,
                decorators=[
                    *original_node.decorators,
                    cst.Decorator(decorator=frozen_dataclass),
                ],
            ),
        )
Exemple #27
0
    def test_parsing_compilable_expression_strings(self,
                                                   source_code: str) -> None:
        """Round-trip expressions through LibCST, mirroring the statement test.

        Only the grammar start production, the compile mode and the LibCST
        parse function differ; code generation is the same as for statements.
        """
        self.reject_invalid_code(source_code, mode="eval")
        parsed = libcst.parse_expression(source_code)
        regenerated = libcst.Module([]).code_for_node(parsed)
        self.verify_identical_asts(source_code, regenerated, mode="eval")
Exemple #28
0
 def test_extract_tautology(self) -> None:
     """Saving the root matcher captures the entire parsed expression."""
     expression = cst.parse_expression("a + b[c], d(e, f * g)")
     whole_tuple = m.Tuple(elements=[
         m.Element(m.BinaryOperation()),
         m.Element(m.Call()),
     ])
     nodes = m.extract(expression, m.SaveMatchedNode(whole_tuple, name="node"))
     self.assertEqual(nodes, {"node": expression})
Exemple #29
0
 def test_extract_sequence(self) -> None:
     """SaveMatchedNode wrapped around a whole args list captures every argument."""
     expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
     call_pattern = m.Call(args=m.SaveMatchedNode([m.ZeroOrMore()], "args"))
     nodes = m.extract(
         expression,
         m.Tuple(elements=[m.DoNotCare(), m.Element(call_pattern)]),
     )
     call = cst.ensure_type(
         cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
     )
     self.assertEqual(nodes, {"args": call.args})
Exemple #30
0
 def update_call_args(self, node: Call) -> Sequence[Arg]:
     """Replace the leading integer-minutes argument with ``timedelta(minutes=...)``.

     Also requests ``from datetime import timedelta`` for the module. The
     remaining arguments are returned unchanged.

     Raises:
         AssertionError: if the first argument is not an Integer literal.
     """
     AddImportsVisitor.add_needed_import(
         context=self.context,
         module="datetime",
         obj="timedelta",
     )
     first_arg, *rest = node.args
     minutes = first_arg.value
     if not isinstance(minutes, Integer):
         raise AssertionError(f"Unexpected type for: {minutes}")
     replacement = first_arg.with_changes(
         value=parse_expression(f"timedelta(minutes={minutes.value})")
     )
     return (replacement, *rest)