示例#1
0
# Command-line driver: load the ops module, parse the ATen / PyTorch declaration
# header and emit the generated C++ to --output_file (or stdout by default).
arg_parser = argparse.ArgumentParser(description="Generate ORT ATen operations")
arg_parser.add_argument(
    "--ops_module",
    type=str,
    required=True,
    help="Python module containing the Onnx Operation signature and list of ops to map",
)
arg_parser.add_argument("--output_file", default=None, type=str, help="Output file [default to std out]")
arg_parser.add_argument(
    "--header_file",
    type=str,
    required=True,
    help="Header file which contains ATen / Pytorch operation signature",
)
arg_parser.add_argument(
    "--custom_ops", action="store_true", help="Whether we are generating code for custom ops or native operation"
)

args = arg_parser.parse_args()

# Execute the ops module straight from its file path. The generator below reads
# its `ops`, `type_promotion_ops` and `aten_output_type` attributes.
loader = SourceFileLoader("", args.ops_module)
ops_module = types.ModuleType(loader.name)
loader.exec_module(ops_module)

ortgen = ORTGen(
    ops_module.ops,
    type_promotion_ops=ops_module.type_promotion_ops,
    custom_ops=args.custom_ops,
    aten_output_type=ops_module.aten_output_type,
)

regdecs_path = args.header_file
# Diagnostics go to stderr: when --output_file is omitted, stdout *is* the
# generated C++ stream and must not contain anything else.
print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}", file=sys.stderr)

output = open(args.output_file, "wt") if args.output_file else sys.stdout
try:
    # `cpp_parser` (not `parser`) so the argparse parser name is not rebound.
    with CPPParser(regdecs_path) as cpp_parser, SourceWriter(output) as writer:
        ortgen.run(cpp_parser, writer)
finally:
    # Close only a file we opened ourselves; never close sys.stdout.
    if output is not sys.stdout:
        output.close()
示例#2
0
    def _write_function_registrations(self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]):
        """Emit the ``TORCH_LIBRARY_IMPL(aten, ORT, m)`` block for the generated kernels.

        Each generated function gets one ``m.impl("<torch name>", TORCH_FN(<cpp name>));``
        line. The C++ name is prefixed with ``<op_namespace>::`` when the mapped
        function carries a non-empty namespace.

        Args:
            writer: destination source writer.
            generated_funcs: mapped functions to register, in emission order.
        """
        writer.writeline()
        writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {")
        writer.push_indent()

        for entry in generated_funcs:
            native_decl = entry.cpp_func
            schema_owner = native_decl.torch_func

            qualifier = f"{entry.op_namespace}::" if entry.op_namespace else ""
            kernel_ref = f"TORCH_FN({qualifier}{native_decl.identifier.value})"

            writer.write("m.impl(")
            writer.writeline(f'"{schema_owner.identifier.value}", {kernel_ref});')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()
示例#3
0
    def _write_custom_ops_registrations(
        self, writer: opgenwriter.SourceWriter, generated_funcs: List[MappedOpFunction]
    ):
        """Emit the ``TORCH_LIBRARY(ort, m)`` block defining the generated custom ops.

        Each generated function gets one ``m.def("<name>", &<qualified name>);`` line.
        The schema name is the bare C++ identifier. The function address is qualified
        with ``<op_namespace>::`` when the mapped function has a non-empty namespace,
        matching ``_write_function_registrations``.

        Args:
            writer: destination source writer.
            generated_funcs: mapped functions to register, in emission order.
        """
        writer.writeline()
        writer.writeline("TORCH_LIBRARY(ort, m) {")
        writer.push_indent()

        for mapped_func in generated_funcs:
            func_name = mapped_func.cpp_func.identifier.value
            # run() writes each function inside its op_namespace and closes that
            # namespace before registrations are emitted, so an unqualified address
            # would not resolve for a namespaced op.
            if mapped_func.op_namespace:
                func_ref = f"{mapped_func.op_namespace}::{func_name}"
            else:
                func_ref = func_name
            writer.write("m.def(")
            writer.writeline(f'"{func_name}", &{func_ref});')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()
示例#4
0
    def _write_cpu_fall_back(self, writer: opgenwriter.SourceWriter, mapped_func: MappedOpFunction):
        """Emit a statement that forwards the call to PyTorch's CPU fallback.

        The generated C++ has the shape::

            return native::call_fallback_fn<
              &native::cpu_fallback,
              ATEN_OP(<cpp name>)>::call(<arg>, <arg>, ...);

        ``native::`` is written unqualified. It presumably resolves through the
        ``using namespace at;`` line the file prelude emits (i.e. ``at::native``).

        Args:
            writer: destination source writer.
            mapped_func: the mapped op whose C++ declaration supplies the op name and
                the forwarded argument names.
        """
        cpp_func = mapped_func.cpp_func
        writer.writeline("return native::call_fallback_fn<")
        writer.push_indent()
        writer.writeline("&native::cpu_fallback,")
        writer.write("ATEN_OP(")
        writer.write(cpp_func.identifier.value)
        writer.write(")>::call(")

        # Forward the named parameters in declaration order; parameters without an
        # identifier are skipped.
        params = ", ".join(p.member.identifier.value for p in cpp_func.parameters if p.member.identifier)
        writer.write(params)
        writer.writeline(");")
        writer.pop_indent()
示例#5
0
 def _write_function_signature(self, writer: opgenwriter.SourceWriter, cpp_func: ast.FunctionDecl):
     """Write the C++ declarator ``<return type> <name>(<params>)`` for *cpp_func*.

     No terminating ``;`` or body is written; the caller appends one. If the
     declaration is linked to a torch function, its schema is first written as a
     ``//`` comment line. Every parameter starts on its own line inside an extra
     indent level. Keyword-argument sentinel parameters are prefixed with ``// ``
     so they are commented out.
     """
     linked_torch = cpp_func.torch_func
     if linked_torch:
         writer.writeline(f"// {linked_torch.torch_schema}")

     cpp_func.return_type.write(writer)
     writer.write(f" {cpp_func.identifier.value}(")

     writer.push_indent()
     for entry in cpp_func.parameters:
         writer.writeline()
         is_kwargs_marker = isinstance(entry.member.parameter_type, ast.KWArgsSentinelType)
         if is_kwargs_marker:
             writer.write("// ")
         entry.write(writer)
     writer.pop_indent()

     writer.write(")")
示例#6
0
 def _write_file_postlude(self, writer: opgenwriter.SourceWriter):
     """Finish the generated file by calling ``writer.pop_namespaces()``.

     NOTE(review): presumably this closes every namespace still open on the writer,
     including the ones the prelude pushes. Confirm against SourceWriter.
     """
     writer.pop_namespaces()
示例#7
0
 def _write_file_prelude(self, writer: opgenwriter.SourceWriter):
     """Write the generated file's banner, ``#include`` directives and opening scope.

     The banner records the exact command line (``sys.argv``) that produced the
     file. After the includes, the ``torch_ort`` and ``eager`` namespaces are
     opened, followed by ``using namespace at;`` and a ``NodeAttributes`` alias.
     """
     # None marks a blank line, written with a bare writeline() call.
     header_lines = (
         "// AUTO-GENERATED CODE! - DO NOT EDIT!",
         f'// $ python {" ".join(sys.argv)}',
         None,
         '#include "python/onnxruntime_pybind_state_common.h"',
         None,
         "#include <torch/extension.h>",
         "#include <ATen/native/CPUFallback.h>",
         None,
         "#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>",
         None,
         '#include "ort_tensor.h"',
         '#include "ort_aten.h"',
         '#include "ort_log.h"',
         None,
     )
     for text in header_lines:
         if text is None:
             writer.writeline()
         else:
             writer.writeline(text)

     for namespace_name in ("torch_ort", "eager"):
         writer.push_namespace(namespace_name)

     writer.writeline()
     writer.writeline("using namespace at;")
     writer.writeline("using NodeAttributes = onnxruntime::NodeAttributes;")
示例#8
0
    def run(self, cpp_parser: parser.CPPParser, writer: opgenwriter.SourceWriter):
        """Generate the complete C++ source for all mapped ops into *writer*.

        Output order:

        1. the file prelude;
        2. each mapped function, grouped under its ``op_namespace``: a declaration
           only when ``signature_only`` is set, otherwise the signature plus a body
           from ``_write_function_body``;
        3. the registrations: ``TORCH_LIBRARY`` when ``self._custom_ops`` is set,
           otherwise ``TORCH_LIBRARY_IMPL``;
        4. the file postlude.

        Every parsed function is removed from ``self._mapped_ops``.

        Raises:
            Exception: if any entry of ``self._mapped_ops`` was never matched by a
                parsed declaration. This is checked only after the whole file has
                been written.
        """
        self._write_file_prelude(writer)

        emitted = []
        open_ns = None

        for mapped_func in self._parse_mapped_function_decls(cpp_parser):
            del self._mapped_ops[mapped_func.mapped_op_name]
            emitted.append(mapped_func)

            # Leave the current namespace as soon as the op namespace changes, then
            # open the new one if needed; consecutive ops share one namespace block.
            wanted_ns = mapped_func.op_namespace
            if open_ns and open_ns != wanted_ns:
                writer.pop_namespace()
                open_ns = None
            if wanted_ns != open_ns:
                writer.writeline()
                writer.push_namespace(wanted_ns)
                open_ns = wanted_ns

            writer.writeline()
            self._write_function_signature(writer, mapped_func.cpp_func)

            if not mapped_func.signature_only:
                writer.writeline(" {")
                writer.push_indent()
                self._write_function_body(writer, mapped_func)
                writer.pop_indent()
                writer.writeline("}")
            else:
                writer.writeline(";")

        if open_ns:
            writer.pop_namespace()

        if self._custom_ops:
            self._write_custom_ops_registrations(writer, emitted)
        else:
            self._write_function_registrations(writer, emitted)
        self._write_file_postlude(writer)

        if self._mapped_ops:
            unmatched = ", ".join(f"'{name}'" for name in self._mapped_ops.keys())
            raise Exception("Torch operation(s) could not be parsed for mapping: " + unmatched)