def test_multiline_alias(self, init, indent):
    """Check that `_multiline_alias` indents the line after the comma.

    The whitespace following the alias comma is expected to end with the
    transformer's base indentation plus four extra spaces.
    """
    # `init` presumably mocks `ImportTransformer.__init__`; a mocked
    # `__init__` must return None or instantiation fails.
    init.return_value = None
    tfm = transform.ImportTransformer(None, None)
    # Set by hand, presumably because the mocked constructor never does.
    tfm._indentation = indent

    plain_alias = cst.ImportAlias(name=cst.Name("x"))
    converted = tfm._multiline_alias(plain_alias)

    after_comma = converted.comma.whitespace_after
    assert after_comma.last_line.value == indent + " " * 4
def _assert_import_equal(
    self, impt_stmnt: str, endlineno: int, used_names: set, expec_impt: str
):
    """Run `ImportTransformer` over `impt_stmnt` and compare the output code.

    Args:
        impt_stmnt: Source code of the import statement(s) to transform.
        endlineno: Last line of the import; it always starts at (1, 0).
        used_names: Names handed to the transformer as `used_names`.
        expec_impt: Exact code expected after the transformation.
    """
    import_span = NodeLocation((1, 0), endlineno)
    visitor = transform.ImportTransformer(used_names, import_span)
    module = cst.parse_module(impt_stmnt)
    rewritten = module.visit(visitor)
    assert rewritten.code == expec_impt
def test_stylize(self, code, endlineno, ismultiline):
    """Check how `_stylize` rewrites an import node.

    When the result has a truthy `rpar` and the case is multiline, both
    parentheses must differ from the originals. In every case, the last
    imported name must have no explicit trailing comma.
    """
    span = NodeLocation((1, 0), endlineno)
    # First simple statement line of the module -> its first small statement.
    import_node = cst.parse_module(code).body[0].body[0]
    tfm = transform.ImportTransformer({""}, span)
    styled = tfm._stylize(import_node, import_node.names, False)

    closing_paren = getattr(styled, "rpar", None)
    if closing_paren and ismultiline:
        assert styled.rpar != import_node.rpar
        assert styled.lpar != import_node.lpar
    final_alias = styled.names[-1]
    assert final_alias.comma == cst.MaybeSentinel.DEFAULT
def test_get_alias_name(self, init, name):
    """Check that `_get_alias_name` turns a (dotted) name node back into `name`."""
    # `init` presumably mocks `ImportTransformer.__init__`; it must return None.
    init.return_value = None

    def build_name_node(dotted: str) -> Union[cst.Name, cst.Attribute]:
        # Inverse of `_get_alias_name`: "a.b.c" becomes
        # Attribute(value=Attribute(value=Name("a"), attr=Name("b")), attr=Name("c")).
        first, *rest = dotted.split(".")
        built: Union[cst.Name, cst.Attribute] = cst.Name(first)
        for segment in rest:
            built = cst.Attribute(value=built, attr=cst.Name(segment))  # type: ignore
        return built

    name_node = build_name_node(name)
    tfm = transform.ImportTransformer(None, None)
    assert tfm._get_alias_name(name_node) == name
def test_multiline_rpar(self, init, indent):
    """Check that `_multiline_rpar` places the ")" at the base indentation."""
    # `init` presumably mocks `ImportTransformer.__init__`; it must return None.
    init.return_value = None
    tfm = transform.ImportTransformer(None, None)
    tfm._indentation = indent

    closing = tfm._multiline_rpar()
    leading_ws = closing.whitespace_before
    assert leading_ws.last_line.value == indent
def test_multiline_parenthesized_whitespace(self, init, indent):
    """Check that the built whitespace's last line equals the given `indent`."""
    # `init` presumably mocks `ImportTransformer.__init__`; it must return None.
    init.return_value = None
    tfm = transform.ImportTransformer(None, None)
    whitespace = tfm._multiline_parenthesized_whitespace(indent)
    final_line = whitespace.last_line
    assert final_line.value == indent
def test_init(self, used_names, location, expec_err):
    """Check that constructing `ImportTransformer` raises `expec_err`.

    If construction succeeds, `sysu.Pass` is raised instead. Such a case
    passes only when `expec_err` matches `sysu.Pass`. Presumably this is how
    the parametrization encodes "no error expected" (TODO confirm).
    """

    def construct_then_signal():
        transform.ImportTransformer(used_names, location)
        # Reached only when the constructor did not raise.
        raise sysu.Pass()

    with pytest.raises(expec_err):
        construct_then_signal()