Example #1
0
 def exec(self,
          etree: tp.Union[cst.ClassDef, cst.FunctionDef],
          stree: tp.Union[cst.ClassDef, cst.FunctionDef],
          env: SymbolTable):
     """Render both trees to source and run them through ``exec_str_in_file``.

     Each tree is wrapped in a single-statement ``cst.Module`` to obtain its
     source text.  The rendered ``etree`` is passed to ``exec_str_in_file``
     along with ``env``, ``self.path``, ``self.file_name`` and the rendered
     ``stree``.

     Returns:
         The entry of the ``exec_str_in_file`` result stored under the name
         of ``etree`` (``etree.name.value``).
     """
     etree_source = cst.Module(body=(etree,)).code
     stree_source = cst.Module(body=(stree,)).code
     result_table = exec_str_in_file(
         etree_source, env, self.path, self.file_name, stree_source)
     defined_name = etree.name.value
     return result_table[defined_name]
Example #2
0
def to_module(node: _T) -> cst.Module:
    """Coerce ``node`` into a ``cst.Module``.

    * ``SimpleStatementSuite`` / ``IndentedBlock``: their ``body`` becomes the
      module body.
    * ``cst.Module``: returned unchanged.
    * ``BaseExpression``: wrapped in ``cst.Expr`` and then treated as a
      statement.
    * ``BaseStatement`` / ``BaseSmallStatement``: becomes the only element of
      the module body.

    Raises:
        TypeError: if ``node`` is none of the kinds listed above.
    """
    if isinstance(node, (cst.SimpleStatementSuite, cst.IndentedBlock)):
        return cst.Module(body=node.body)

    if isinstance(node, cst.Module):
        return node

    # Expressions need an ``Expr`` wrapper before they can act as statements.
    if isinstance(node, cst.BaseExpression):
        statement = cst.Expr(value=node)
    else:
        statement = node

    if isinstance(statement, (cst.BaseStatement, cst.BaseSmallStatement)):
        return cst.Module(body=(statement, ))

    raise TypeError(f'{node} :: {type(node)} cannot be cast to Module')
Example #3
0
def _get_expression_transform(
        before: Callable[..., Any],
        after: Callable[..., Any]) -> ExpressionTransform:
    """Build an ``ExpressionTransform`` from a ``before``/``after`` pair.

    The first element returned by ``function_parser.parse(before)`` is turned
    into a matcher (using the matchers derived from ``before``'s arguments).
    The first element returned by ``function_parser.parse(after)`` is placed
    inside a one-line module that is wrapped in a ``MetadataWrapper``; the
    node taken back out of the wrapper's module is used as the replacement.

    Raises:
        Exception: if the resulting matcher, or its ``matcher`` attribute, is
            a ``DoNotCareSentinel``.
    """
    before_expression = function_parser.parse(before)[0]
    argument_matchers = function_parser.args_to_matchers(before)
    matcher = craftier.matcher.from_node(before_expression, argument_matchers)

    # The sentinel may be the matcher itself or sit one level down in a
    # wrapper object exposing it as ``.matcher``.
    top_level_candidates = (matcher, getattr(matcher, "matcher", None))
    if any(
            isinstance(candidate, libcst.matchers.DoNotCareSentinel)
            for candidate in top_level_candidates):
        raise Exception(
            f"DoNotCare matcher is forbidden at top level in `{before.__name__}`"
        )

    after_expression = function_parser.parse(after)[0]
    # Technically this is not correct, as some expressions (function calls,
    # binary operations, names, etc.) are wrapped in an `Expr` node.
    statement_line = libcst.SimpleStatementLine(
        body=[cast(libcst.BaseSmallStatement, after_expression)])
    wrapper = libcst.metadata.MetadataWrapper(
        libcst.Module(body=[statement_line]))
    wrapped_line = cast(libcst.SimpleStatementLine, wrapper.module.body[0])

    return ExpressionTransform(before=matcher,
                               after=wrapped_line.body[0],
                               wrapper=wrapper)
Example #4
0
    def __assert_codegen(
        self,
        node: cst.CSTNode,
        expected: str,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Checks that ``node`` renders to ``expected`` and, when
        ``expected_position`` is given, that the position computed for
        ``node`` during codegen equals it.
        """
        empty_module = cst.Module([])
        self.assertEqual(empty_module.code_for_node(node), expected)

        if expected_position is None:
            return

        # This relies on internal APIs, because we only want to compute the
        # position for the node under test, not for a whole module.
        #
        # Normally that is a nonsense operation (how can a node have a
        # position if it's not in a module?), which is why it's not supported,
        # but it makes sense in the context of these node tests.
        position_provider = PositionProvider()
        node._codegen(
            PositionProvidingCodegenState(
                default_indent=empty_module.default_indent,
                default_newline=empty_module.default_newline,
                provider=position_provider,
            )
        )
        self.assertEqual(position_provider._computed[node], expected_position)
Example #5
0
    def source(self) -> str:
        """Return the code generated by a ``cst.Module`` built from ``self.body``."""
        return cst.Module(list(self.body)).code
Example #6
0
def _get_clean_type(typeobj: object) -> str:
    """
    Turn a type object (as returned by dataclasses) into the type string used
    by the codegen below.

    The ``repr`` of ``typeobj`` is parsed with LibCST, cleansed, widened to
    accept a DoNotCareSentinel, run through a fixed series of transformer
    passes and finally rendered back to source.

    Raises an ``Exception`` when the top-level node is not a supported shape.
    """

    # First, get the type as a parseable expression; ``<class 'x.Y'>`` reprs
    # are reduced to the dotted name between the quotes.
    type_repr = repr(typeobj)
    class_prefix, class_suffix = "<class '", "'>"
    if type_repr.startswith(class_prefix) and type_repr.endswith(class_suffix):
        type_repr = type_repr[len(class_prefix):-len(class_suffix)]

    # Now, parse the expression with LibCST.
    cleanser = CleanseFullTypeNames()
    typecst = parse_expression(type_repr).visit(cleanser)

    # Now, convert the type to allow for DoNotCareSentinel values.
    clean_type: Optional[cst.CSTNode] = None
    if isinstance(typecst, cst.Subscript):
        container = typecst.value
        if container.deep_equals(cst.Name("Union")):
            # An existing Union can simply take one more member.
            clean_type = typecst.with_changes(
                slice=[*typecst.slice, _get_do_not_care()]
            )
        elif container.deep_equals(cst.Name("Literal")) or container.deep_equals(
            cst.Name("Sequence")
        ):
            clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())

    # If for some reason we encounter a new node type, raise so we can triage.
    if clean_type is None:
        raise Exception(f"Don't support {typecst}")

    # Passes are applied in this exact order:
    #  1. add DoNotCareSentinel to all sequences, so a sequence can be defined
    #     partially with explicit DoNotCare() values for some slots;
    #  2. double-quote any types we parsed and repr'd, for consistency;
    #  3. insert OneOf/AllOf and MatchIfTrue into unions so their usage
    #     typechecks anywhere a SomeType was originally allowed;
    #  4. insert AtMostN and AtLeastN into sequence unions; this relies on
    #     pass 3 so that all sequences we care about are
    #     Sequence[Union[<x>]].
    transformer_passes = (
        AddDoNotCareToSequences,
        DoubleQuoteStrings,
        AddLogicAndLambdaMatcherToUnions,
        AddWildcardsToSequenceUnions,
    )
    for transformer_cls in transformer_passes:
        clean_type = ensure_type(clean_type.visit(transformer_cls()), cst.CSTNode)

    # Finally, generate the code given a default Module so we can spit it out.
    return cst.Module(body=()).code_for_node(clean_type)
Example #7
0
    def __derive_children_from_codegen(
            self, node: cst.CSTNode) -> Sequence[cst.CSTNode]:
        """
        Patches all subclasses of `CSTNode` exported by the `cst` module to track which
        `_codegen` methods get called, generating a list of children.

        Because all children must be rendered out into lexical order, this should be
        equivalent to `node.children`.

        `node.children` uses `_visit_and_replace_children` under the hood, not
        `_codegen`, so this helps us verify that the behaviors of those two methods
        are in sync.

        Args:
            node: The node whose code is generated while the patches are active.

        Returns:
            The nodes whose `_codegen` was entered while exactly one other node was
            on the tracking stack, in call order.
        """

        # One patch target per `CSTNode` subclass found in the `cst` namespace that
        # has a `_codegen` attribute; the current `_codegen` is remembered so the
        # override can delegate to it.
        patch_targets: Iterable[_CSTCodegenPatchTarget] = [
            _CSTCodegenPatchTarget(type=v, name=k, old_codegen=v._codegen)
            for (k, v) in cst.__dict__.items() if isinstance(v, type)
            and issubclass(v, cst.CSTNode) and hasattr(v, "_codegen")
        ]

        # `children` is the result; `codegen_stack` holds the nodes whose patched
        # `_codegen` is currently executing, outermost first.
        children: List[cst.CSTNode] = []
        codegen_stack: List[cst.CSTNode] = []

        def _get_codegen_override(
                target: _CSTCodegenPatchTarget) -> Callable[..., None]:
            # Builds the replacement `_codegen` for one patch target. It records
            # the call on the stack and then delegates to the saved `old_codegen`.
            def _codegen_impl(self: CSTNodeT, *args: Any,
                              **kwargs: Any) -> None:
                should_pop = False
                # Don't stick duplicates in the stack. This is needed so that we don't
                # track calls to `super()._codegen()`.
                if len(codegen_stack) == 0 or codegen_stack[-1] is not self:
                    # Check the stack to see that we're a direct child, not the root or
                    # a transitive child.
                    if len(codegen_stack) == 1:
                        children.append(self)
                    codegen_stack.append(self)
                    should_pop = True
                target.old_codegen(self, *args, **kwargs)
                # only pop if we pushed something to the stack earlier
                if should_pop:
                    codegen_stack.pop()

            return _codegen_impl

        # All patches are entered on one ExitStack so they are undone together
        # when the `with` block exits.
        with ExitStack() as patch_stack:
            for t in patch_targets:
                patch_stack.enter_context(
                    # pyre-ignore Incompatible parameter type [6]: Expected
                    # pyre-ignore `typing.ContextManager[Variable[contextlib._T]]`
                    # pyre-ignore for 1st anonymous parameter to call
                    # pyre-ignore `contextlib.ExitStack.enter_context` but got
                    # pyre-ignore `unittest.mock._patch`.
                    patch(f"libcst.{t.name}._codegen",
                          _get_codegen_override(t)))
            # Execute `node._codegen()`
            cst.Module([]).code_for_node(node)

        return children
Example #8
0
    def __convert_node_to_code(self, node) -> str:
        """
        Returns the code of an artificial module whose body is just ``node``.
        """
        return cst.Module([node]).code
Example #9
0
 def test_is_sortable(self) -> None:
     """Plain imports are reported sortable; an ``isort: skip`` import is not."""
     sorter = ImportSorter(module=cst.Module([]),
                           path=Path(),
                           config=Config())
     for sortable_source in ("import a", "from a import b"):
         sortable_import = parse_import(sortable_source)
         self.assertTrue(sorter.is_sortable_import(sortable_import))

     skipped_import = parse_import("import a  # isort: skip")
     self.assertFalse(sorter.is_sortable_import(skipped_import))
Example #10
0
    def __str__(self):
        """Render the tree header, the interleaved statements/results, and the footer.

        Falsy entries among the statements and results are skipped.  When
        ``self.terminal`` is truthy and Pygments can be imported, the text is
        syntax-highlighted for a 256-colour terminal; otherwise it is returned
        plain.  The final text is stripped of surrounding whitespace.
        """
        header_text = cst.Module(body=[], header=self.tree.header).code.strip()
        interleaved = chain.from_iterable(zip(self.statements, self.results))
        body_text = "\n".join(str(entry) for entry in interleaved if entry)
        footer_text = cst.Module(body=[], footer=self.tree.footer).code.strip()

        rendered = "\n".join([header_text, body_text, footer_text])
        if self.terminal:
            try:
                from pygments import highlight
                from pygments.formatters import Terminal256Formatter
                from pygments.lexers import PythonLexer

                formatter = Terminal256Formatter(style="friendly")
                rendered = highlight(rendered, PythonLexer(), formatter)
            except ImportError:
                # Highlighting is optional: fall back to the plain text.
                pass
        return rendered.strip()
Example #11
0
def _wrap_clean_type(aliases: List[Alias], name: Optional[str],
                     value: cst.Subscript) -> cst.BaseExpression:
    """Widen ``value`` with a do-not-care member, via an alias when possible.

    With a ``name``, an ``Alias`` carrying the rendered source of ``value`` is
    appended to ``aliases`` and the returned type is the alias name wrapped
    together with do-not-care.  Without one, do-not-care is appended directly
    to the slice of ``value``.
    """
    if name is None:
        # Couldn't name the alias: fall back to regular node creation and add
        # do-not-care to the type we widened.
        return value.with_changes(slice=[*value.slice, _get_do_not_care()])

    # We created an alias, so use that, wrapping the alias in a do-not-care.
    alias_source = cst.Module(body=()).code_for_node(value)
    aliases.append(Alias(name=name, type=alias_source))
    return _get_wrapped_union_type(cst.Name(name), _get_do_not_care())
Example #12
0
 def visit_AnnAssign(self, node):
   """Records the annotation of an annotated assignment that has a value.

   The annotation is rendered to source, line breaks (with any trailing
   comment and surrounding whitespace) are removed from it, and the result is
   appended to `self.variable_annotations` with the node's start and end
   lines.
   """
   if not node.value:
     # TODO(b/167613685): Stop discarding annotations without values.
     return
   position = self._get_position(node)
   # Gets a string representation of the annotation.
   annotation_source = libcst.Module([node.annotation.annotation]).code
   annotation = re.sub(r"\s*(#.*)?\n\s*", "", annotation_source)
   record = _VariableAnnotation(
       position.start.line, position.end.line, annotation)
   self.variable_annotations.append(record)
Example #13
0
    def test_deprecated_non_element_construction(self) -> None:
        """A ``Subscript`` whose ``slice`` is a bare ``Index`` renders as ``foo[1]``."""
        subscript = cst.Subscript(
            value=cst.Name(value="foo"),
            slice=cst.Index(value=cst.Integer(value="1")),
        )
        statement = cst.SimpleStatementLine(body=[cst.Expr(value=subscript)])
        module = cst.Module(body=[statement])

        self.assertEqual(module.code, "foo[1]\n")
Example #14
0
    def leave_Module(self, original_node, updated_node):
        """Prepend the sorted imports from ``self.imports`` to the module body.

        The nodes in ``self.imports`` are rendered one per line, sorted with
        ``SortImports``, parsed back into statements and placed ahead of the
        body produced by the parent class's ``leave_Module``.

        Returns:
            The parent's result with its ``body`` replaced by the sorted
            imports followed by the existing statements.
        """
        final_node = super().leave_Module(original_node, updated_node)
        imports_str = cst.Module(
            body=[cst.SimpleStatementLine([i]) for i in self.imports]).code
        sorted_imports = cst.parse_module(
            SortImports(file_contents=imports_str).output)

        # Add imports back to the top of the module.  ``Module.body`` is only
        # declared as a ``Sequence``, so unpack both sides instead of using
        # ``+``, which raises TypeError when one side is a tuple and the other
        # a list.
        new_body = [*sorted_imports.body, *final_node.body]

        return final_node.with_changes(body=new_body)
Example #15
0
    def test_parsing_compilable_expression_strings(self,
                                                   source_code: str) -> None:
        """Round-trip an expression through libCST and compare ASTs.

        Same idea as the statement test, but using the expression start
        production, the ``eval`` compile mode and ``parse_expression``;
        codegen goes through ``Module.code_for_node``.
        """
        self.reject_invalid_code(source_code, mode="eval")
        expression = libcst.parse_expression(source_code)
        regenerated = libcst.Module([]).code_for_node(expression)
        self.verify_identical_asts(source_code, regenerated, mode="eval")
Example #16
0
    def test_parsing_compilable_statement_strings(self,
                                                  source_code: str) -> None:
        """Round-trip a single statement through libCST and compare ASTs.

        Uses the ``single`` compile mode and ``parse_statement``; codegen goes
        through ``Module.code_for_node``.
        """
        self.reject_invalid_code(source_code, mode="single")
        statement = libcst.parse_statement(source_code)
        regenerated = libcst.Module([]).code_for_node(statement)
        self.verify_identical_asts(source_code, regenerated, mode="single")
Example #17
0
 def leave_Assert(self, _, updated_node):  # noqa
     """Fold ``assert`` statements whose test is a Python literal.

     The test expression is rendered to source and evaluated with
     ``literal_eval``:

     * not a literal (evaluation fails): the assert is returned unchanged;
     * truthy literal: the assert is removed;
     * falsy literal: the assert is replaced by ``raise AssertionError``,
       passing the original message as the argument when there is one.
     """
     # The module is only a carrier for codegen defaults; give it an empty
     # statement list rather than a string, which is not a statement sequence.
     test_code = cst.Module([]).code_for_node(updated_node.test)
     try:
         test_literal = literal_eval(test_code)
     except Exception:
         # Deliberately broad: any failure just means "not a foldable
         # literal", so the assert is left alone.
         return updated_node
     if test_literal:
         return cst.RemovalSentinel.REMOVE
     if updated_node.msg is None:
         return cst.Raise(cst.Name("AssertionError"))
     return cst.Raise(
         cst.Call(cst.Name("AssertionError"),
                  args=[cst.Arg(updated_node.msg)]))
Example #18
0
    def __str__(self) -> str:
        """Return the source of ``self.stmt``, reformatted with black if ``self.style`` is set.

        A single trailing newline, if present, is removed from the result.

        Raises:
            ImportError: if restyling is requested but black is not installed.
        """
        rendered = cst.Module(body=[self.stmt]).code
        if self.style:
            try:
                from black import Mode, format_str
            except ImportError:
                raise ImportError("Must install black to restyle code.")

            rendered = format_str(rendered, mode=Mode())

        # Drop exactly one trailing newline so deliberate ones survive.
        return rendered[:-1] if rendered.endswith("\n") else rendered
Example #19
0
def from_node(node: Type[libcst.CSTNode] = libcst.Module,
              *,
              auto_target: bool = True) -> st.SearchStrategy[str]:
    """Generate syntactically-valid Python source code for a LibCST node type.

    You can pass any subtype of `libcst.CSTNode`.  Alternatively, you can use
    Hypothesis' built-in
    `from_type(node_type).map(lambda n: libcst.Module([n]).code)`,
    after Hypothesmith has registered the required strategies.  However, this
    does not include automatic targeting, and limitations of LibCST may lead
    to invalid code being generated.
    """
    assert issubclass(node, libcst.CSTNode)
    node_instances = st.from_type(node)
    rendered_sources = node_instances.map(lambda n: libcst.Module([n]).code)
    code = rendered_sources.filter(compilable)
    if auto_target:
        return code.map(record_targets)
    return code
Example #20
0
    def test_circular_dependency(self) -> None:
        """
        A provider that lists itself as a dependency must be reported as a
        circular dependency when a visitor depending on it is run.
        """
        class ProviderA(VisitorMetadataProvider[str]):
            pass

        # The class can only refer to itself once it exists.
        ProviderA.METADATA_DEPENDENCIES = (ProviderA, )

        class BadVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderA, )

        empty_module = cst.Module([])
        expected_message = "Detected circular dependencies in ProviderA"
        with self.assertRaisesRegex(MetadataException, expected_message):
            empty_module.visit(BadVisitor())
Example #21
0
    def test_unset_metadata(self) -> None:
        """
        Reading metadata that a declared provider never set raises ``KeyError``.
        """
        class ProviderA(VisitorMetadataProvider[bool]):
            pass

        class AVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderA, )

            def on_visit(self, node: cst.CSTNode) -> bool:
                self.get_metadata(ProviderA, node)
                return True

        empty_module = cst.Module([])
        with self.assertRaises(KeyError):
            empty_module.visit(AVisitor())
Example #22
0
    def __assert_codegen(
        self,
        node: cst.CSTNode,
        expected: str,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Checks that ``node`` renders to ``expected``; when ``expected_position``
        is given, also checks the position the provider computed for ``node``.
        """
        # A provider is only created when a position needs checking.
        if expected_position is not None:
            provider = SyntacticPositionProvider()
        else:
            provider = None

        rendered = cst.Module([]).code_for_node(node, provider=provider)
        self.assertEqual(rendered, expected)

        if provider is None:
            return
        self.assertEqual(provider._computed[node], expected_position)
Example #23
0
    def test_deep_replace_identity(self) -> None:
        """``deep_replace`` called with the module itself as the target yields the replacement."""
        old_code = """
            pass
        """
        new_code = """
            break
        """

        module = cst.parse_module(dedent(old_code))
        break_line = cst.SimpleStatementLine(body=(cst.Break(), ))
        replacement = cst.Module(
            header=(cst.EmptyLine(), ),
            body=(break_line, ),
        )
        new_module = module.deep_replace(module, replacement)
        self.assertEqual(new_module.code, dedent(new_code))
Example #24
0
def test_source_code_from_libcst_node_type(node, data):
    """Draw an instance of ``node`` and check that code can be generated for it.

    The test is skipped for draws that fail with a ``NameError``, for two
    specific ``TypeError`` messages, and when codegen raises
    ``CSTCodegenError``; any other ``TypeError`` propagates.
    """
    try:
        val = data.draw(st.from_type(node))
    except NameError:
        pytest.skip("NameError, probably a forward reference")
    except TypeError as e:
        message = str(e)
        if message.startswith("super"):
            pytest.skip("something weird here, back later")
        if message.startswith("Can't instantiate"):
            pytest.skip("abstract classes somehow leaking into builds()")
        raise
    note(val)

    # Anything that is not already a module is wrapped in one for codegen.
    module = val if isinstance(val, libcst.Module) else libcst.Module([val])
    try:
        code = module.code
    except libcst._nodes.base.CSTCodegenError:
        pytest.skip("codegen not supported yet, e.g. Annotation")
    note(code)
Example #25
0
 def _get_assert_replacement(self, node: cst.Assert):
     message = node.msg or str(cst.Module(body=[node]).code)
     return cst.If(
         test=cst.UnaryOperation(
             operator=cst.Not(),
             expression=node.test,  # Todo: parenthesize?
         ),
         body=cst.IndentedBlock(body=[
             cst.SimpleStatementLine(body=[
                 cst.Raise(exc=cst.Call(
                     func=cst.Name(value="AssertionError", ),
                     args=[
                         cst.Arg(value=cst.SimpleString(value=repr(message),
                                                        ), ),
                     ],
                 ), ),
             ]),
         ], ),
     )
Example #26
0
    def test_self_metadata(self) -> None:
        """
        A provider can read back, in ``on_leave``, the metadata it set for the
        same node in ``on_visit``.
        """
        outer_test = self

        class ProviderA(VisitorMetadataProvider[bool]):
            def on_visit(self, node: cst.CSTNode) -> bool:
                self.set_metadata(node, True)
                return True

            def on_leave(self, original_node: cst.CSTNode) -> None:
                stored = self.get_metadata(type(self), original_node)
                outer_test.assertEqual(stored, True)

        class AVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderA, )

        empty_module = cst.Module([])
        empty_module.visit(AVisitor())
Example #27
0
 def test_adding_parens(self) -> None:
     """A ``With`` given explicit ``lpar``/``rpar`` renders its items inside parentheses."""
     foo_item = cst.WithItem(
         cst.Call(cst.Name("foo")),
         comma=cst.Comma(whitespace_after=cst.ParenthesizedWhitespace(), ),
     )
     bar_item = cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma())
     node = cst.With(
         (foo_item, bar_item),
         cst.SimpleStatementSuite((cst.Pass(), )),
         lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
         rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
     )
     rendered = cst.Module([]).code_for_node(node)
     self.assertEqual(
         rendered,
         ("with ( foo(),\n"
          "bar(), ): pass\n")  # noqa
     )
Example #28
0
    def test_undeclared_metadata(self) -> None:
        """
        Reading metadata from a provider the visitor did not declare raises a
        ``KeyError`` naming both the provider and the visitor.
        """
        class ProviderA(VisitorMetadataProvider[bool]):
            pass

        class ProviderB(VisitorMetadataProvider[bool]):
            pass

        class AVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderA, )

            def on_visit(self, node: cst.CSTNode) -> bool:
                # ProviderA is declared and read with a default;
                # ProviderB is not declared at all.
                self.get_metadata(ProviderA, node, True)
                self.get_metadata(ProviderB, node)
                return True

        empty_module = cst.Module([])
        expected_message = (
            "ProviderB is not declared as a dependency from AVisitor")
        with self.assertRaisesRegex(KeyError, expected_message):
            empty_module.visit(AVisitor())
Example #29
0
def _get_clean_type_and_aliases(
    typeobj: object,
) -> Tuple[str, List[Alias]]:  # noqa: C901
    """
    Turn a type object (as returned by dataclasses) into the type string used
    by the codegen below, together with any aliases collected on the way.

    Raises an ``Exception`` when the parsed top-level node is neither a
    ``Subscript`` nor a ``Name``/``SimpleString``.
    """

    # First, get the type as a parseable expression; ``<class 'x.Y'>`` reprs
    # are reduced to the dotted name between the quotes.
    type_repr = repr(typeobj)
    class_prefix, class_suffix = "<class '", "'>"
    if type_repr.startswith(class_prefix) and type_repr.endswith(class_suffix):
        type_repr = type_repr[len(class_prefix):-len(class_suffix)]

    # Now, parse the expression with LibCST.
    cleanser = CleanseFullTypeNames()
    typecst = parse_expression(type_repr).visit(cleanser)
    aliases: List[Alias] = []

    # Now, convert the type to allow for MetadataMatchType and MatchIfTrue
    # values; which helper does that depends on the top-level node kind.
    if isinstance(typecst, cst.Subscript):
        build_clean_type = _get_clean_type_from_subscript
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        build_clean_type = _get_clean_type_from_expression
    else:
        raise Exception("Logic error, unexpected top level type!")
    clean_type = build_clean_type(aliases, typecst)

    # Passes are applied in this exact order:
    #  1. insert OneOf/AllOf and MatchIfTrue into unions so their usage
    #     typechecks anywhere a SomeType was originally allowed;
    #  2. insert AtMostN and AtLeastN into sequence unions; this relies on
    #     pass 1 so that all sequences we care about are
    #     Sequence[Union[<x>]].
    for transformer_cls in (AddLogicMatchersToUnions, AddWildcardsToSequenceUnions):
        clean_type = ensure_type(clean_type.visit(transformer_cls()), cst.CSTNode)

    # Finally, generate the code given a default Module so we can spit it out.
    return cst.Module(body=()).code_for_node(clean_type), aliases
Example #30
0
def _add_error_to_line_break_block(lines: List[str],
                                   errors: List[List[str]]) -> None:
    """Rewrite the trailing line-break block of ``lines`` and interleave ``errors``.

    The last ``len(errors)`` entries are popped off ``lines``, stripped of the
    first entry's indentation, parsed as one statement and passed through
    ``LineBreakTransformer``.  The regenerated lines are re-indented and
    appended back to ``lines``, each one preceded by the corresponding entry
    of ``errors``.  ``lines`` is modified in place.
    """
    # Gather unbroken lines: popping walks backwards, so restore file order.
    block_size = len(errors)
    line_break_block = [lines.pop() for _ in range(block_size)][::-1]

    # Transform line break block to use parenthesis.
    first_line = line_break_block[0]
    indent = len(first_line) - len(first_line.lstrip())
    statement = "\n".join(line[indent:] for line in line_break_block)
    transformed_tree = cast(
        libcst.CSTNode,
        libcst.parse_statement(statement).visit(LineBreakTransformer()),
    )
    transformed_statement = libcst.Module([]).code_for_node(transformed_tree)
    padding = " " * indent
    transformed_lines = [
        padding + line for line in transformed_statement.split("\n")
    ]

    # Add to lines.
    for line, comment in zip(transformed_lines, errors):
        lines.extend(comment)
        lines.append(line)