# Command-line entry point: load an ops description module, parse the ATen
# declarations header, and emit the generated ORT C++ source.
parser = argparse.ArgumentParser(description="Generate ORT ATen operations")
parser.add_argument(
    "--ops_module",
    type=str,
    required=True,
    help="Python module containing the Onnx Operation signature and list of ops to map",
)
parser.add_argument("--output_file", default=None, type=str, help="Output file [default to std out]")
parser.add_argument(
    "--header_file",
    type=str,
    required=True,
    help="Header file which contains ATen / Pytorch operation signature",
)
parser.add_argument(
    "--custom_ops", action="store_true", help="Whether we are generating code for custom ops or native operation"
)
args = parser.parse_args()

# Load the ops description module directly from its file path; it must define
# `ops`, `type_promotion_ops` and `aten_output_type`.
loader = SourceFileLoader("", args.ops_module)
ops_module = types.ModuleType(loader.name)
loader.exec_module(ops_module)

ortgen = ORTGen(
    ops_module.ops,
    type_promotion_ops=ops_module.type_promotion_ops,
    custom_ops=args.custom_ops,
    aten_output_type=ops_module.aten_output_type,
)

regdecs_path = args.header_file
# Diagnostics go to stderr: when --output_file is omitted the generated C++ is
# written to stdout, and an INFO line there would corrupt the generated source.
print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}", file=sys.stderr)

output = open(args.output_file, "w", encoding="utf-8") if args.output_file else sys.stdout
try:
    # `cpp_parser` (not `parser`) so the argparse parser above is not shadowed.
    with CPPParser(regdecs_path) as cpp_parser, SourceWriter(output) as writer:
        ortgen.run(cpp_parser, writer)
finally:
    # Close the file we opened ourselves (a repeated close is a no-op);
    # never close the process's stdout.
    if output is not sys.stdout:
        output.close()
def _write_function_registrations(self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]):
    """Emit the ``TORCH_LIBRARY_IMPL(aten, ORT, m)`` block for the generated functions.

    Each function produces one ``m.impl("<torch op name>", TORCH_FN(<cpp name>));``
    line. The C++ name is prefixed with ``<op_namespace>::`` when the mapped
    function carries a namespace, because run() declares it inside that namespace.

    Args:
        writer: destination for the generated C++.
        generated_funcs: the mapped functions emitted by run(), in order.
    """
    writer.writeline()
    writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {")
    writer.push_indent()
    for func in generated_funcs:
        cpp_decl = func.cpp_func
        torch_decl = cpp_decl.torch_func
        qualifier = f"{func.op_namespace}::" if func.op_namespace else ""
        target = "TORCH_FN(" + qualifier + cpp_decl.identifier.value + ")"
        writer.write("m.impl(")
        writer.writeline(f'"{torch_decl.identifier.value}", {target});')
    writer.pop_indent()
    writer.writeline("}")
    writer.writeline()
def _write_custom_ops_registrations(
    self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]
):
    """Emit the ``TORCH_LIBRARY(ort, m)`` block that defines the custom ops.

    Each function produces one ``m.def("<name>", &<qualified name>);`` line.

    Args:
        writer: destination for the generated C++.
        generated_funcs: the mapped functions emitted by run(), in order.
    """
    writer.writeline()
    writer.writeline("TORCH_LIBRARY(ort, m) {")
    writer.push_indent()
    for mapped_func in generated_funcs:
        func_name = mapped_func.cpp_func.identifier.value
        # run() declares each function inside its op namespace (when set) and
        # closes that namespace before registrations are written, so qualify the
        # address the same way _write_function_registrations does.
        if mapped_func.op_namespace:
            func_ref = f"{mapped_func.op_namespace}::{func_name}"
        else:
            func_ref = func_name
        writer.write("m.def(")
        writer.writeline(f'"{func_name}", &{func_ref});')
    writer.pop_indent()
    writer.writeline("}")
    writer.writeline()
def _write_cpu_fall_back(self, writer: opgenwriter.SourceWriter, mapped_func: MappedOpFunction):
    """Emit a function body that forwards the call to PyTorch's CPU fallback.

    The emitted C++ has the form (the ``at::`` qualifier is omitted; the file
    prelude emits ``using namespace at;``)::

        return native::call_fallback_fn<
          &native::cpu_fallback,
          ATEN_OP(<cpp function name>)>::call(<param>, <param>, ...);

    Args:
        writer: destination for the generated C++.
        mapped_func: the op being generated; only its ``cpp_func`` is used.
    """
    cpp_func = mapped_func.cpp_func
    writer.writeline("return native::call_fallback_fn<")
    writer.push_indent()
    writer.writeline("&native::cpu_fallback,")
    writer.write("ATEN_OP(")
    writer.write(cpp_func.identifier.value)
    writer.write(")>::call(")
    # Forward every parameter that has a name, in declaration order. Members
    # without an identifier are skipped (presumably the kwargs sentinel -- TODO confirm).
    writer.write(", ".join(p.member.identifier.value for p in cpp_func.parameters if p.member.identifier))
    writer.writeline(");")
    writer.pop_indent()
def _write_function_signature(self, writer: opgenwriter.SourceWriter, cpp_func: ast.FunctionDecl):
    """Emit the C++ declarator of *cpp_func*, one parameter per line.

    When the function maps to a torch op, its schema is written first as a
    ``//`` comment. Keyword-argument sentinel parameters are written commented
    out. Only the closing ``)`` is written; the caller appends ``;`` or a body.
    """
    torch_decl = cpp_func.torch_func
    if torch_decl:
        writer.writeline(f"// {torch_decl.torch_schema}")
    cpp_func.return_type.write(writer)
    writer.write(f" {cpp_func.identifier.value}(")
    writer.push_indent()
    for entry in cpp_func.parameters:
        writer.writeline()
        is_kwargs_marker = isinstance(entry.member.parameter_type, ast.KWArgsSentinelType)
        if is_kwargs_marker:
            writer.write("// ")
        entry.write(writer)
    writer.pop_indent()
    writer.write(")")
def _write_file_postlude(self, writer: opgenwriter.SourceWriter):
    """Finish the generated file by popping the namespaces pushed on *writer*.

    This is the counterpart of ``_write_file_prelude``, which pushes ``torch_ort``
    and ``eager``. It presumably closes all namespaces that remain open (TODO
    confirm against SourceWriter.pop_namespaces).
    """
    close_namespaces = writer.pop_namespaces
    close_namespaces()
def _write_file_prelude(self, writer: opgenwriter.SourceWriter):
    """Emit the generated file's header: banner, includes, and opening namespaces.

    The banner records the exact command line used to generate the file. The
    includes are written in blank-line-separated groups. The block then opens
    ``torch_ort`` and ``eager`` and adds the ``using`` declarations that the
    generated bodies rely on.
    """
    banner = (
        "// AUTO-GENERATED CODE! - DO NOT EDIT!",
        f'// $ python {" ".join(sys.argv)}',
    )
    include_groups = (
        ('#include "python/onnxruntime_pybind_state_common.h"',),
        ("#include <torch/extension.h>", "#include <ATen/native/CPUFallback.h>"),
        ("#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>",),
        ('#include "ort_tensor.h"', '#include "ort_aten.h"', '#include "ort_log.h"'),
    )

    for text in banner:
        writer.writeline(text)
    writer.writeline()

    # Each include group is followed by a single blank line.
    for group in include_groups:
        for text in group:
            writer.writeline(text)
        writer.writeline()

    for ns in ("torch_ort", "eager"):
        writer.push_namespace(ns)
    writer.writeline()
    writer.writeline("using namespace at;")
    writer.writeline("using NodeAttributes = onnxruntime::NodeAttributes;")
def run(self, cpp_parser: parser.CPPParser, writer: opgenwriter.SourceWriter):
    """Generate the complete C++ source for every mapped op into *writer*.

    Steps, in order:

    - write the file prelude;
    - for each parsed mapped function, write its signature (plus a body unless
      it is signature-only) inside its op namespace;
    - write the registration block (custom-op or ATen flavour, depending on
      ``self._custom_ops``);
    - write the postlude.

    Each emitted op is deleted from ``self._mapped_ops``. NOTE(review): this
    consumes instance state, so a second run() on the same instance would fail
    on the first ``del``.

    Raises:
        Exception: after all output is written, if any entry of
            ``self._mapped_ops`` was never matched by a parsed declaration.
    """
    self._write_file_prelude(writer)

    emitted = []
    open_ns = None
    for func in self._parse_mapped_function_decls(cpp_parser):
        # Cross the op off the pending set; leftovers are reported at the end.
        del self._mapped_ops[func.mapped_op_name]
        emitted.append(func)

        wanted_ns = func.op_namespace
        # Leave the currently open namespace when this op lives elsewhere.
        if open_ns and open_ns != wanted_ns:
            writer.pop_namespace()
            open_ns = None
        # Enter the op's namespace if it is not the one already open.
        if wanted_ns != open_ns:
            writer.writeline()
            writer.push_namespace(wanted_ns)
            writer.writeline()
            open_ns = wanted_ns

        self._write_function_signature(writer, func.cpp_func)
        if func.signature_only:
            writer.writeline(";")
            continue
        writer.writeline(" {")
        writer.push_indent()
        self._write_function_body(writer, func)
        writer.pop_indent()
        writer.writeline("}")

    if open_ns:
        writer.pop_namespace()

    if self._custom_ops:
        write_registrations = self._write_custom_ops_registrations
    else:
        write_registrations = self._write_function_registrations
    write_registrations(writer, emitted)

    self._write_file_postlude(writer)

    if self._mapped_ops:
        unparsed = ", ".join(f"'{name}'" for name in self._mapped_ops.keys())
        raise Exception("Torch operation(s) could not be parsed for mapping: " + unparsed)