def transpile(
    in_path: str, out_path: str, backend: str = default_backend, dump_sir: bool = False
) -> None:
    """Transpile every stencil in a Python source file and write the generated code.

    Args:
        in_path: path of the Python source file containing the stencils; also
            recorded as the file name in the SIR.
        out_path: path the generated code is written to (overwritten).
        backend: key into ``backend_map`` selecting the code-generation backend.
        dump_sir: if true, additionally write the SIR as JSON to ``out_path``
            with its extension replaced by ``.json``.

    Raises:
        SyntaxError: if ``in_path`` does not contain valid Python.
        KeyError: if ``backend`` is not a key of ``backend_map``.
    """
    # Read raw bytes and let ``ast.parse`` decode them: Python source is UTF-8 by
    # default (PEP 3120) and may declare its own encoding (PEP 263), whereas
    # ``open(..., "r")`` would decode with the platform's locale encoding.
    with open(in_path, "rb") as in_file:
        in_source = in_file.read()
    in_ast = ast.parse(in_source, filename=in_path, type_comments=True)

    grammar = Grammar()
    # TODO: handle errors in different stencils separately
    stencils = [grammar.stencil(node) for node in iter_stencils(in_ast)]
    sir = make_sir(in_path, GridType.Value("Unstructured"), stencils)

    if dump_sir:
        out_name = os.path.splitext(out_path)[0]
        # Only written, never read back, so plain "w" suffices.
        with open(out_name + ".json", "w", encoding="utf-8") as f:
            f.write(sir_to_json(sir))

    # NOTE(review): ``compile`` here is the module-level code generator in scope
    # (it takes a ``backend=`` keyword), not the builtin — confirm its import.
    out_code = compile(sir, backend=backend_map[backend])
    with open(out_path, "w", encoding="utf-8") as out_file:
        out_file.write(out_code)
def pyast_to_sir(stencils: List[ast.FunctionDef], filename: str = "<unknown>") -> SIR:
    """Convert already-parsed stencil function definitions into a single SIR.

    Args:
        stencils: function definitions; each must satisfy ``Grammar.is_stencil``.
        filename: file name recorded in the resulting SIR.

    Returns:
        The SIR produced by ``make_sir`` for an unstructured grid.

    Raises:
        ValueError: if any element of ``stencils`` is not a stencil definition.
    """
    grammar = Grammar()
    # Explicit check rather than ``assert`` so the validation is not stripped
    # when Python runs with ``-O``.
    for stencil in stencils:
        if not grammar.is_stencil(stencil):
            name = getattr(stencil, "name", type(stencil).__name__)
            raise ValueError(f"{name!r} is not a stencil definition")
    # TODO: handle errors in different stencils separately
    stencil_irs = [grammar.stencil(stencil) for stencil in stencils]
    return make_sir(filename, GridType.Value("Unstructured"), stencil_irs)
def transpile_and_validate(stencil: Callable) -> None:
    """Lower a single stencil callable to SIR and run the SIR optimizer on it.

    The callable's source is recovered with ``getsource``, parsed, checked to
    consist of exactly one stencil definition, converted with ``Grammar`` and
    passed, serialized, to ``run_optimizer_sir``.
    """
    import textwrap

    # Dedent so stencils defined inside a function or class body (whose
    # recovered source is indented) still parse as a standalone module.
    source = textwrap.dedent(getsource(stencil))
    # type_comments=True matches how the other parsers in this module read
    # stencil source, so type comments reach the Grammar here too.
    module = ast.parse(source, type_comments=True)
    assert isinstance(module, ast.Module)
    assert len(module.body) == 1
    definition = module.body[0]
    assert Grammar.is_stencil(definition)
    # NOTE(review): ``__file__`` is this module's path, not the stencil's source
    # file; it is only used as the file name recorded in the SIR.
    sir = make_sir(__file__, GridType.Value("Unstructured"), [Grammar().stencil(definition)])
    run_optimizer_sir(sir.SerializeToString())
def callable_to_pyast(stencil: Callable, filename: str = "<unknown>") -> List[ast.FunctionDef]:
    """Parse the source of a single stencil callable into a one-element AST list.

    Line numbers in the returned AST are shifted to match where the callable is
    defined in its file, as reported by ``inspect.getsourcelines``.

    Args:
        stencil: a callable whose source consists of exactly one stencil
            function definition (including its decorators).
        filename: file name attached to the parse, used in syntax-error reports.

    Returns:
        A list containing the stencil's ``ast.FunctionDef``.
    """
    import inspect
    import textwrap

    # Dedent so stencils nested in a function or class body still parse.
    # NOTE(review): column offsets remain relative to the dedented text.
    source = textwrap.dedent(getsource(stencil))
    stencil_ast = ast.parse(source, filename=filename, type_comments=True)
    assert isinstance(stencil_ast, ast.Module)
    assert len(stencil_ast.body) == 1
    assert Grammar.is_stencil(stencil_ast.body[0])
    # The snippet was parsed starting at line 1; shift every node so line
    # numbers point at the real location of the definition in its file.
    first_line = inspect.getsourcelines(stencil)[1]
    if first_line > 1:
        ast.increment_lineno(stencil_ast, first_line - 1)
    return [stencil_ast.body[0]]
def str_to_pyast(source: str, filename: str = "<unknown>") -> List[ast.FunctionDef]:
    """Parse ``source`` and collect its top-level stencil function definitions.

    Only direct children of the module body are considered. Anything that is
    not a plain ``ast.FunctionDef`` is skipped, as are functions that
    ``Grammar.is_stencil`` rejects. Definitions are returned in source order.
    ``ast.parse`` raises ``SyntaxError`` if ``source`` is not valid Python.
    """
    module = ast.parse(source, filename=filename, type_comments=True)
    assert isinstance(module, ast.Module)
    found = []
    for node in module.body:
        if not isinstance(node, ast.FunctionDef):
            continue
        if Grammar.is_stencil(node):
            found.append(node)
    return found
def iter_stencils(module: ast.Module) -> Iterator[ast.AST]:
    """Lazily yield each top-level function in ``module`` that is a stencil.

    Only direct children of ``module.body`` are inspected; nested functions and
    methods are not visited. ``Grammar.is_stencil`` is consulted only for plain
    ``ast.FunctionDef`` nodes.
    """
    yield from (
        node
        for node in module.body
        if isinstance(node, ast.FunctionDef) and Grammar.is_stencil(node)
    )