def test_to_file(self) -> None:
    """
    Tests whether `file` constructs a file, and fills it with the right content

    The file is emitted twice into a temporary directory: once with
    `skip_black=True` and once with `skip_black=False`. The two outputs must
    differ as text yet parse to equivalent ASTs.
    """
    with TemporaryDirectory() as tempdir:
        filename = os.path.join(
            tempdir, "delete_me{extsep}py".format(extsep=extsep)
        )
        try:
            emit.file(class_ast, filename, skip_black=True)
            with open(filename, "rt") as f:
                ugly = f.read()

            os.remove(filename)

            emit.file(class_ast, filename, skip_black=False)
            with open(filename, "rt") as f:
                blacked = f.read()

            # The conditional must be parenthesised: without the parentheses
            # the whole left operand becomes `(ugly + "") if ... else "\t"`,
            # so the fallback branch compared a lone tab against `blacked`
            # and never looked at `ugly`.
            self.assertNotEqual(
                ugly + ("" if "black" in modules else "\t"), blacked
            )
            self.assertTrue(
                cmp_ast(ast.parse(ugly), ast.parse(blacked)),
                "Ugly AST doesn't match blacked AST",
            )
        finally:
            # Remove the file explicitly even though the directory is temporary
            if os.path.isfile(filename):
                os.remove(filename)
def doctrans(filename, docstring_format, type_annotations):
    """
    Rewrite the docstrings inside `filename` into `docstring_format`.

    The file is parsed, run through `DocTrans`, and written back only when
    the transformed AST differs from the original one.

    :param filename: Python file to convert docstrings within. Edited in place.
    :type filename: ```str```

    :param docstring_format: Format of docstring
    :type docstring_format: ```Literal['rest', 'numpydoc', 'google']```

    :param type_annotations: True to have type annotations (3.6+), False to place in docstring
    :type type_annotations: ```bool```
    """
    with open(filename, "rt") as source_file:
        source = source_file.read()

    tree = ast_parse(source, skip_docstring_remit=False)
    # Keep an untouched copy: the transformer receives it as `whole_ast`, and
    # it is the baseline for the "did anything change?" comparison below.
    pristine = deepcopy(tree)

    transformer = DocTrans(
        docstring_format=docstring_format,
        type_annotations=type_annotations,
        existing_type_annotations=has_type_annotations(tree),
        whole_ast=pristine,
    )
    transformed = transformer.visit(tree)

    if cmp_ast(transformed, pristine):
        # Nothing changed, so leave the file on disk alone
        return
    emit.file(transformed, filename, mode="wt", skip_black=True)
def test_get_at_root(self) -> None:
    """Checks that `get_at_root` returns the single root-level import of the `eval` mock"""
    mock_path = path.join(
        path.dirname(__file__), "mocks", "eval{extsep}py".format(extsep=extsep)
    )
    with open(mock_path) as mock_file:
        imports = get_at_root(ast.parse(mock_file.read()), (Import, ImportFrom))

    self.assertIsInstance(imports, list)
    self.assertEqual(len(imports), 1)

    # The one import the mock is expected to contain
    expected_alias = ast.alias(
        asname=None,
        name="cdd.tests.mocks",
        identifier=None,
        identifier_name=None,
    )
    expected_import = ast.Import(names=[expected_alias], alias=None)
    self.assertTrue(cmp_ast(imports[0], expected_import))
def test_parse_to_scalar(self) -> None:
    """Test various inputs and outputs for `parse_to_scalar`"""
    # Simple inputs paired with the value `parse_to_scalar` should give back
    for given, expected in (
        (5, 5),
        ("5", "5"),
        (set_value(5), 5),
        (ast.Expr(None), NoneStr),
    ):
        self.assertEqual(parse_to_scalar(given), expected)

    # A list expression statement: the result exposes `.elts` and compares
    # equal to a `List` node. A fresh parse is used for every call.
    self.assertEqual(
        get_value(parse_to_scalar(ast.parse("[5]").body[0]).elts[0]), 5
    )
    self.assertTrue(
        cmp_ast(
            parse_to_scalar(ast.parse("[5]").body[0]),
            List([set_value(5)], Load()),
        )
    )

    # A whole module is expected to come back as its source text
    self.assertEqual(parse_to_scalar(ast.parse("[5]")), "[5]")

    # Unsupported input types are rejected
    self.assertRaises(NotImplementedError, parse_to_scalar, memoryview(b""))
def _conform_filename(
    filename,
    search,
    emit_func,
    replacement_node_ir,
    type_wanted,
):
    """
    Conform the given file to the `intermediate_repr`

    Three outcomes are possible: the file does not exist and is created from
    the emitted node; the searched node is absent and the emitted node is
    appended; or the searched node is present and is rewritten when it
    differs from the emitted node.

    :param filename: Location of file
    :type filename: ```str```

    :param search: Search query, e.g., ['node_name', 'function_name', 'arg_name']
    :type search: ```List[str]```

    :param emit_func: Called with `replacement_node_ir` plus keyword options to build the replacement node
    :type emit_func: ```Callable```

    :param replacement_node_ir: Replace what is found with the contents of this param
    :type replacement_node_ir: ```dict```

    :param type_wanted: AST instance
    :type type_wanted: ```AST```

    :returns: filename, whether the file was modified
    :rtype: ```Tuple[str, bool]```
    """
    filename = path.realpath(path.expanduser(filename))

    # Case 1: nothing on disk yet, so write a brand new file
    if not path.isfile(filename):
        fresh_node = emit_func(
            replacement_node_ir,
            emit_default_doc=False,  # emit_func.__name__ == "class_"
        )
        emit.file(fresh_node, filename=filename, mode="wt", skip_black=False)
        return filename, True

    with open(filename, "rt") as source_file:
        module = ast_parse(source_file.read(), filename=filename)
    assert isinstance(module, Module)

    found = find_in_ast(search, module)
    emit_options = _default_options(
        node=found, search=search, type_wanted=type_wanted
    )()
    replacement = emit_func(replacement_node_ir, **emit_options)

    # Case 2: file exists but the searched node does not, so append it
    if found is None:
        emit.file(replacement, filename=filename, mode="a", skip_black=False)
        return filename, True

    assert len(search) > 0
    assert (
        type(replacement) == type_wanted
    ), "Expected {type_wanted!r} got {type_replacement_node!r}".format(
        type_wanted=type_wanted,
        type_replacement_node=type(replacement).__name__,
    )

    # Case 3: node exists; rewrite only when it differs from the replacement
    if cmp_ast(found, replacement):
        return filename, False

    rewriter = RewriteAtQuery(search=search, replacement_node=replacement)
    rewriter.visit(module)

    status = "modified" if rewriter.replaced else "unchanged"
    print(status, filename, sep="\t")
    if rewriter.replaced:
        emit.file(module, filename, mode="wt", skip_black=False)

    return filename, rewriter.replaced
def run_ast_test(test_case_instance, gen_ast, gold, skip_black=False):
    """
    Compares `gen_ast` with `gold` standard

    Three checks run in order: docstrings (when the node under test is a
    class or function), generated source code, and finally the ASTs
    themselves via `cmp_ast`.

    :param test_case_instance: instance of `TestCase`
    :type test_case_instance: ```unittest.TestCase```

    :param gen_ast: generated AST; a `str` is parsed and its first statement is used
    :type gen_ast: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param gold: mocked AST
    :type gold: ```Union[ast.Module, ast.ClassDef, ast.FunctionDef]```

    :param skip_black: Whether to skip formatting with black. Turned off for performance,
        turn on for pretty debug.
    :type skip_black: ```bool```
    """
    if isinstance(gen_ast, str):
        gen_ast = ast.parse(gen_ast).body[0]

    assert gen_ast is not None, "gen_ast is None"
    assert gold is not None, "gold is None"

    # Work on copies so the caller's nodes are left untouched
    gen_ast, gold = deepcopy(gen_ast), deepcopy(gold)

    # For a non-empty module, the docstring check looks at its first statement
    if isinstance(gen_ast, ast.Module) and gen_ast.body:
        gen_node, gold_node = gen_ast.body[0], gold.body[0]
    else:
        gen_node, gold_node = gen_ast, gold

    if isinstance(gen_node, (ast.ClassDef, ast.AsyncFunctionDef, ast.FunctionDef)):
        gen_doc = ast.get_docstring(gen_node, clean=False)
        gold_doc = ast.get_docstring(gold_node, clean=False)
        test_case_instance.assertEqual(gen_doc, gold_doc)

    if skip_black:
        formatter = identity
    else:
        formatter = partial(
            black.format_str,
            mode=black.Mode(
                target_versions=set(),
                line_length=60,
                is_pyi=False,
                string_normalization=False,
            ),
        )

    gen_source = formatter(source_transformer.to_code(gen_ast))
    gold_source = formatter(source_transformer.to_code(gold))
    test_case_instance.assertEqual(gen_source, gold_source)

    test_case_instance.assertTrue(
        cmp_ast(gen_ast, gold), "Generated AST doesn't match reference AST"
    )
def test_cmp_ast(self) -> None:
    """Checks that `cmp_ast` is falsy when `None` is compared against `5`"""
    mismatched = cmp_ast(None, 5)
    self.assertFalse(mismatched)