def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
    """Parse the dedented ``code`` and compare the result to ``expected_module``.

    The comparison is done with ``deep_equals``. When the two trees differ,
    the assertion message contains the ``repr`` of both the parsed module and
    the expected module.
    """
    actual_module = parse_module(dedent(code))
    modules_match = deep_equals(actual_module, expected_module)
    mismatch_text = f"\n{actual_module!r}\nis not deeply equal to \n{expected_module!r}"
    self.assertTrue(modules_match, msg=mismatch_text)
Example #2
0
 def test_parsers(
     self,
     parser: Callable[[Config, State], _T],
     config: Config,
     start_state: State,
     end_state: State,
     expected_node: _T,
 ) -> None:
     """Run ``parser`` and check both the produced node and the resulting state.

     ``parser`` is called with ``config`` and ``start_state``. Its result must
     be deeply equal to ``expected_node``, and afterwards ``start_state`` must
     compare equal to ``end_state`` (presumably because the parser advances the
     state object in place -- confirm against the parser implementations).
     """
     # `CSTNode.deep_equals` is deliberately not used: results may be sequences
     # of nodes, and the internal `deep_equals` function is the easiest way to
     # compare those.
     actual_node = parser(config, start_state)
     nodes_match = deep_equals(actual_node, expected_node)
     mismatch_text = f"\n{actual_node!r}\nis not deeply equal to \n{expected_node!r}"
     self.assertTrue(nodes_match, msg=mismatch_text)

     # The state handed to the parser must now match the expected end state.
     self.assertEqual(start_state, end_state)
Example #3
0
def _node_repr_recursive(  # noqa: C901
    node: object,
    *,
    indent: str = _DEFAULT_INDENT,
    show_defaults: bool = False,
    show_syntax: bool = False,
    show_whitespace: bool = False,
) -> List[str]:
    """Build the string pieces of a pretty-printed representation of ``node``.

    The returned pieces are meant to be concatenated with ``"".join(...)``.

    Args:
        node: The value to render. A ``CSTNode`` is rendered as
            ``ClassName(...)`` with one ``name=value`` entry per remaining
            dataclass field; any other value is rendered with ``repr``.
        indent: String prepended to every line of a nested field. It is
            forwarded to all recursive calls so every nesting level uses the
            same indent.
        show_defaults: When false, fields whose current value is deeply equal
            to the dataclass default (or to a fresh ``default_factory()``
            result) are omitted.
        show_syntax: When false, omits ``Module``'s ``encoding``,
            ``default_indent``, ``default_newline`` and
            ``has_trailing_newline`` fields, fields whose type mentions
            ``Sentinel`` (except ``star_arg``, ``star``, ``posonly_ind`` and
            names containing ``whitespace``), and fields whose type mentions
            ``Semicolon``, ``Colon``, ``Comma``, ``Dot`` or ``AssignEqual``.
        show_whitespace: When false, omits fields whose name contains
            ``whitespace``, ``leading_lines`` or ``lines_after_decorators``,
            the ``header``/``footer`` of ``IndentedBlock`` and ``Module``, and
            the ``indent`` of ``IndentedBlock``.

    Returns:
        A list of string fragments. Fields whose name starts with ``_`` are
        always hidden.
    """
    if not isinstance(node, CSTNode):
        # This is a python value, just return the repr
        return [repr(node)]

    def _render(child: object) -> List[str]:
        # Single place that recurses, so every option (including `indent`)
        # is guaranteed to reach nested nodes.
        return _node_repr_recursive(
            child,
            indent=indent,
            show_defaults=show_defaults,
            show_syntax=show_syntax,
            show_whitespace=show_whitespace,
        )

    def _indent_lines(text: str) -> str:
        # Prefix every line of `text` (including empty ones) with `indent`.
        return "\n".join(f"{indent}{line}" for line in text.split("\n"))

    # This is a CSTNode, we must pretty-print it.
    tokens: List[str] = [node.__class__.__name__]

    # Hide all fields prefixed with "_"
    fields: Sequence["dataclasses.Field[object]"] = [
        f for f in dataclasses.fields(node) if not f.name.startswith("_")
    ]

    # Filter whitespace nodes if needed
    if not show_whitespace:

        def _is_whitespace(field: "dataclasses.Field[object]") -> bool:
            if any(
                marker in field.name
                for marker in ("whitespace", "leading_lines", "lines_after_decorators")
            ):
                return True
            if isinstance(node, (IndentedBlock, Module)) and field.name in (
                "header",
                "footer",
            ):
                return True
            return isinstance(node, IndentedBlock) and field.name == "indent"

        fields = [f for f in fields if not _is_whitespace(f)]

    # Filter values which aren't changed from their defaults
    if not show_defaults:

        def _get_default(fld: "dataclasses.Field[object]") -> object:
            # A field without any default yields `dataclasses.MISSING` here.
            if fld.default_factory is not dataclasses.MISSING:
                return fld.default_factory()
            return fld.default

        fields = [
            f
            for f in fields
            if not deep_equals(getattr(node, f.name), _get_default(f))
        ]

    # Filter out values which aren't interesting if needed
    if not show_syntax:

        def _is_syntax(field: "dataclasses.Field[object]") -> bool:
            if isinstance(node, Module) and field.name in (
                "encoding",
                "default_indent",
                "default_newline",
                "has_trailing_newline",
            ):
                return True
            type_str = repr(field.type)
            if (
                "Sentinel" in type_str
                and field.name not in ("star_arg", "star", "posonly_ind")
                and "whitespace" not in field.name
            ):
                # This is a value that can optionally be specified, so its
                # definitely syntax.
                return True
            # These are all nodes that exist for separation syntax
            return any(
                name in type_str
                for name in ("Semicolon", "Colon", "Comma", "Dot", "AssignEqual")
            )

        fields = [f for f in fields if not _is_syntax(f)]

    if not fields:
        tokens.append("()")
        return tokens

    tokens.append("(\n")
    for field in fields:
        child_tokens: List[str] = [field.name, "="]
        value = getattr(node, field.name)

        if isinstance(value, (str, bytes)) or not isinstance(value, Sequence):
            # Render out the node contents
            child_tokens.extend(_render(value))
        elif len(value) > 0:
            # Render out a list of individual nodes: one element per line,
            # each followed by a comma (the last one included).
            rendered_items = ",\n".join("".join(_render(v)) for v in value) + ","
            child_tokens.append("[\n")
            child_tokens.append(_indent_lines(rendered_items))
            child_tokens.append("\n]")
        else:
            child_tokens.append("[]")

        # Handle indentation and trailing comma.
        tokens.append(_indent_lines("".join(child_tokens)))
        tokens.append(",\n")

    tokens.append(")")
    return tokens