示例#1
0
    def run(self, cpp_parser: parser.CPPParser, writer: opgenwriter.SourceWriter) -> None:
        """Generate the complete C++ source for all mapped operations.

        Writes the file prelude, then one C++ function per mapped function
        declaration parsed from ``cpp_parser`` (consecutive functions sharing an
        ``op_namespace`` are emitted inside the same C++ namespace), then the
        registration code and finally the file postlude.

        Every parsed function is removed from ``self._mapped_ops`` as it is
        emitted, so any entries left over afterwards are operations that were
        requested for mapping but never produced by the parser.

        Args:
            cpp_parser: Parser that supplies the C++ function declarations.
            writer: Destination for the generated source text.

        Raises:
            KeyError: if a parsed function's ``mapped_op_name`` is not (or is no
                longer) a key of ``self._mapped_ops``.
            Exception: if some entries of ``self._mapped_ops`` were never
                consumed. This is raised only after the whole file, including
                registrations and postlude, has already been written.
        """
        self._write_file_prelude(writer)

        generated_funcs = []
        # Namespace currently open in ``writer``; None means no namespace is open.
        current_ns = None

        for mapped_func in self._parse_mapped_function_decls(cpp_parser):
            # Consume the op so that leftovers can be reported at the end.
            del self._mapped_ops[mapped_func.mapped_op_name]
            generated_funcs.append(mapped_func)

            ns = mapped_func.op_namespace
            if ns != current_ns:
                # Compare against None rather than relying on truthiness, so that
                # every namespace pushed (even a falsy one such as "") is also
                # popped again and the generated braces stay balanced.
                if current_ns is not None:
                    writer.pop_namespace()
                current_ns = ns
                if ns is not None:
                    writer.writeline()
                    writer.push_namespace(ns)

            writer.writeline()

            self._write_function_signature(writer, mapped_func.cpp_func)
            if mapped_func.signature_only:
                # Emit only a declaration, without a body.
                writer.writeline(";")
            else:
                writer.writeline(" {")
                writer.push_indent()
                self._write_function_body(writer, mapped_func)
                writer.pop_indent()
                writer.writeline("}")

        # Close whichever namespace the last function left open.
        if current_ns is not None:
            writer.pop_namespace()

        if self._custom_ops:
            self._write_custom_ops_registrations(writer, generated_funcs)
        else:
            self._write_function_registrations(writer, generated_funcs)
        self._write_file_postlude(writer)

        if self._mapped_ops:
            unparsed = ", ".join(f"'{op_name}'" for op_name in self._mapped_ops)
            raise Exception(
                "Torch operation(s) could not be parsed for mapping: " + unparsed
            )
示例#2
0
 def _write_file_prelude(self, writer: opgenwriter.SourceWriter):
     """Write the header of the generated C++ source file.

     Emits an auto-generated-code banner (including the command line that
     produced the file), the required ``#include`` directives, opens the
     ``torch_ort`` and ``eager`` namespaces, and adds ``using`` declarations.

     NOTE(review): nothing visible here closes the two namespaces pushed
     below; presumably ``_write_file_postlude`` pops them — confirm.

     Args:
         writer: Destination for the generated source text.
     """
     writer.writeline("// AUTO-GENERATED CODE! - DO NOT EDIT!")
     # Record the exact command line that regenerated this file.
     writer.writeline(f'// $ python {" ".join(sys.argv)}')
     writer.writeline()
     writer.writeline('#include "python/onnxruntime_pybind_state_common.h"')
     writer.writeline()
     writer.writeline("#include <torch/extension.h>")
     writer.writeline("#include <ATen/native/CPUFallback.h>")
     writer.writeline()
     writer.writeline("#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>")
     writer.writeline()
     writer.writeline('#include "ort_tensor.h"')
     writer.writeline('#include "ort_aten.h"')
     writer.writeline('#include "ort_log.h"')
     writer.writeline()
     # Open torch_ort::eager for all code generated after the prelude.
     writer.push_namespace("torch_ort")
     writer.push_namespace("eager")
     writer.writeline()
     writer.writeline("using namespace at;")
     writer.writeline("using NodeAttributes = onnxruntime::NodeAttributes;")