def gen_py_torch_functions(out, declarations, template_path):
    """Render ``python_torch_functions.cpp`` with the bindings of the ``torch`` module.

    The ``python_torch_functions.cpp`` template is loaded from
    ``template_path``. The torch-level functions are selected from
    ``declarations`` with ``get_py_torch_functions`` and turned into
    non-method bindings for module ``"torch"``. The resulting environment is
    handed to ``write`` together with ``out`` (presumably the output
    directory; ``write`` is defined elsewhere).
    """
    cpp_template = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    torch_functions = get_py_torch_functions(declarations)
    bindings_env = create_python_bindings(
        torch_functions,
        is_python_method=False,
        module="torch",
    )
    write(out, 'python_torch_functions.cpp', cpp_template, bindings_env)
def gen_py_variable_methods(out, declarations, template_path):
    """Render ``python_variable_methods.cpp`` with the Python bindings for Tensor methods.

    The ``python_variable_methods.cpp`` template is read from
    ``template_path``. Method declarations are picked out of
    ``declarations`` with ``get_py_variable_methods`` and bound as Python
    methods (``is_python_method=True``, no ``module``). The generated file is
    emitted through ``write`` with destination ``out``.
    """
    output_name = 'python_variable_methods.cpp'
    template = CodeTemplate.from_file(template_path + '/' + output_name)
    method_decls = get_py_variable_methods(declarations)
    env = create_python_bindings(method_decls, is_python_method=True, module=None)
    write(out, output_name, template, env)
def genCppH(hFilePath, cppFilePath, templateGlslPaths, tmpDirPath, env):
    """Embed rendered GLSL templates as C string constants in a .h/.cpp pair.

    Each path in ``templateGlslPaths`` is loaded as a ``CodeTemplate`` and
    rendered with ``env``. Every non-blank line of the result is emitted into
    ``cppFilePath`` as one C string literal (terminated by ``\\n``) of a
    ``const char* <name>`` definition, where ``<name>`` comes from
    ``getName``. A matching ``extern const char* <name>;`` is written to
    ``hFilePath``. Both files wrap their declarations in
    ``namespace at { namespace native { namespace vulkan {``.

    Args:
        hFilePath: header file to (over)write.
        cppFilePath: source file to (over)write; it includes the header
            named by ``H_NAME``.
        templateGlslPaths: GLSL template files to embed, in emission order.
        tmpDirPath: not used by this function; kept for call compatibility.
        env: substitution environment passed to ``CodeTemplate.substitute``.
    """

    def _escape_c_string(text):
        # Escape backslashes first so the ones added for quotes aren't doubled.
        # Without this a `"` in the shader ends the C literal early, and a
        # trailing `\` (macro continuation) swallows the appended `\n`.
        return text.replace("\\", "\\\\").replace('"', '\\"')

    print("hFilePath:{}".format(hFilePath))
    print("cppFilePath:{}".format(cppFilePath))

    nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n"
    nsend = "\n} } } //namespace at::native::vulkan\n"

    # Collect pieces and join once instead of repeated string concatenation.
    h_parts = ["#pragma once\n", nsbegin]
    cpp_parts = ["#include <ATen/native/vulkan/{}>".format(H_NAME), nsbegin]

    for templateGlslPath in templateGlslPaths:
        name = getName(templateGlslPath)
        h_parts.append("extern const char* " + name + ";\n")
        cpp_parts.append("const char* " + name + " = \n")

        content = CodeTemplate.from_file(templateGlslPath).substitute(env)
        literals = [
            '"' + _escape_c_string(line) + '\\n"\n'
            for line in content.split("\n")
            if line  # blank lines are dropped, as before
        ]
        # A template that renders to only blank lines would otherwise
        # produce `const char* name = ;`, which is not valid C++.
        cpp_parts.extend(literals or ['""\n'])
        cpp_parts.append(";\n")

    cpp_parts.append(nsend)
    h_parts.append(nsend)

    with open(hFilePath, "w") as f:
        f.write("".join(h_parts))
    with open(cppFilePath, "w") as f:
        f.write("".join(cpp_parts))
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the file at ``template_fn`` and return it as a ``CodeTemplate``."""
    template = CodeTemplate.from_file(template_fn)
    return template
parser.add_argument("--install_dir", default=".", help="where to put generated file") parser.add_argument("--aten_root", default="", help="root directory of aten") args, _ = parser.parse_known_args() if args.aten_root: if not os.path.exists(args.aten_root): raise ValueError('aten_root ({}) does not exist'.format( args.aten_root)) sys.path.append(os.path.join(args.aten_root, '..')) # TODO: fix this from tools.codegen.code_template import CodeTemplate as CT else: from tools.codegen.code_template import CodeTemplate as CT # type: ignore[import,no-redef] OP_TEMPLATE = CT.from_file( os.path.join(args.template_dir, 'aten_op_template.h')) try: # use faster C loader if available from yaml import CLoader as Loader except ImportError: from yaml import Loader # type: ignore[misc] def write(filename, s): with open(filename, "w") as f: f.write(s) def read(filename): with open(filename, "r") as f:
def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
    """Compile GLSL templates to SPIR-V and embed the binaries in a .h/.cpp pair.

    The steps are:

    1. Every ``*.glsl`` file found recursively under ``srcDirPath`` (in sorted
       order) is loaded as a ``CodeTemplate`` and rendered with ``env``.
    2. The rendered source is written to ``tmpDirPath/<name>.glsl``.
    3. It is compiled by ``glslcPath`` (compute stage, Vulkan 1.0,
       ``-Werror``) into ``tmpDirPath/<name>.spv``.
    4. Each SPIR-V binary is emitted into ``cppFilePath`` as
       ``const uint32_t <name>[]``, plus ``const uint32_t <name>_len``.
       Note that ``<name>_len`` holds the binary size in BYTES, not words.
    5. Matching ``extern`` declarations are written to ``hFilePath``.

    Everything is wrapped in
    ``namespace at { namespace native { namespace vulkan {``.

    Args:
        hFilePath: header file to (over)write.
        cppFilePath: source file to (over)write; it includes the header
            named by ``H_NAME``.
        srcDirPath: directory searched recursively for ``*.glsl`` templates.
        glslcPath: path of the ``glslc`` compiler executable.
        tmpDirPath: directory for rendered ``.glsl`` and compiled ``.spv`` files.
        env: substitution environment passed to ``CodeTemplate.substitute``.

    Raises:
        RuntimeError: if ``array`` typecode ``'I'`` is not 4 bytes on this
            platform, so SPIR-V words could not be read as ``uint32_t``.
        subprocess.CalledProcessError: if ``glslc`` fails on any shader.
    """
    print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format(
        hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))

    # SPIR-V is a stream of 32-bit words, but the size of typecode 'I' is
    # platform-dependent. Fail fast rather than emit a corrupt uint32_t array
    # and a wrong byte count.
    word_size = array.array('I').itemsize
    if word_size != 4:
        raise RuntimeError(
            "array typecode 'I' is {} bytes on this platform; "
            "SPIR-V embedding requires 4-byte words".format(word_size))

    # Sorted so the generated files are deterministic regardless of the
    # order in which the filesystem lists entries.
    templateSrcPaths = sorted(
        glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True))
    print("templateSrcPaths:{}".format(templateSrcPaths))

    spvPaths = []
    for templateSrcPath in templateSrcPaths:
        print("templateSrcPath {}".format(templateSrcPath))
        name = getName(templateSrcPath).replace("_glsl", "")
        print("name {}".format(name))

        codeTemplate = CodeTemplate.from_file(templateSrcPath)
        srcPath = tmpDirPath + "/" + name + ".glsl"
        content = codeTemplate.substitute(env)
        with open(srcPath, 'w') as f:
            f.write(content)

        spvPath = tmpDirPath + "/" + name + ".spv"
        print("spvPath {}".format(spvPath))
        cmd = [
            glslcPath, "-fshader-stage=compute",
            srcPath, "-o", spvPath,
            "--target-env=vulkan1.0",
            "-Werror"
        ]
        print("\nglslc cmd:", cmd)
        subprocess.check_call(cmd)
        spvPaths.append(spvPath)

    nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n"
    nsend = "\n} } } //namespace at::native::vulkan\n"

    # Collect pieces and join once; SPIR-V blobs yield many small pieces.
    h_parts = ["#pragma once\n", "#include <stdint.h>\n", nsbegin]
    cpp_parts = ["#include <ATen/native/vulkan/{}>".format(H_NAME), nsbegin]

    for spvPath in spvPaths:
        name = getName(spvPath)
        name_len = name + "_len"
        h_parts.append("extern const uint32_t {}[];\n".format(name))
        h_parts.append("extern const uint32_t {};\n".format(name_len))

        cpp_parts.append("const uint32_t " + name + "[] = {\n")
        print("spvPath:{}".format(spvPath))
        with open(spvPath, 'rb') as f:
            words = array.array('I', f.read())
        cpp_parts.extend("{},\n".format(word) for word in words)
        cpp_parts.append("};\n")
        # Byte count derived from what was actually read, not a magic 4.
        cpp_parts.append("const uint32_t {} = {};\n".format(
            name_len, len(words) * word_size))

    cpp_parts.append(nsend)
    h_parts.append(nsend)

    with open(hFilePath, "w") as f:
        f.write("".join(h_parts))
    with open(cppFilePath, "w") as f:
        f.write("".join(cpp_parts))