def run(self, cpp_parser: parser.CPPParser, writer: opgenwriter.SourceWriter):
    """Emit the complete generated C++ source for every mapped torch op.

    Writes the file prelude, then one C++ function per declaration yielded by
    ``self._parse_mapped_function_decls(cpp_parser)``. Consecutive functions
    that share an ``op_namespace`` are placed inside one namespace block.
    Afterwards it writes the op registrations (the custom-op variant when
    ``self._custom_ops`` is truthy) and the file postlude.

    Every emitted op is removed from ``self._mapped_ops``. Any entry still
    present at the end was never matched by a parsed declaration.

    Raises:
        Exception: if ``self._mapped_ops`` is non-empty once generation ends.
    """
    self._write_file_prelude(writer)

    emitted = []
    open_ns = None  # namespace currently pushed on the writer, if any

    for func in self._parse_mapped_function_decls(cpp_parser):
        # Mark this op as handled; leftovers are reported after the loop.
        del self._mapped_ops[func.mapped_op_name]
        emitted.append(func)

        wanted_ns = func.op_namespace
        if wanted_ns != open_ns:
            # Close the previous namespace, but only if a real one is open...
            if open_ns:
                writer.pop_namespace()
                open_ns = None
            # ...then open the new one, unless the op's namespace is None.
            if wanted_ns != open_ns:
                open_ns = wanted_ns
                writer.writeline()
                writer.push_namespace(wanted_ns)
                writer.writeline()

        self._write_function_signature(writer, func.cpp_func)
        if func.signature_only:
            # Declaration only: terminate the signature and move on.
            writer.writeline(";")
            continue

        writer.writeline(" {")
        writer.push_indent()
        self._write_function_body(writer, func)
        writer.pop_indent()
        writer.writeline("}")

    # Close the namespace left open by the last emitted function.
    if open_ns:
        writer.pop_namespace()

    register = (
        self._write_custom_ops_registrations
        if self._custom_ops
        else self._write_function_registrations
    )
    register(writer, emitted)

    self._write_file_postlude(writer)

    if self._mapped_ops:
        unparsed = ", ".join(f"'{name}'" for name in self._mapped_ops)
        raise Exception("Torch operation(s) could not be parsed for mapping: " + unparsed)
def _write_file_prelude(self, writer: opgenwriter.SourceWriter):
    """Write the fixed header of the generated C++ source file.

    Emits an auto-generated banner and a comment recording the command line
    that produced the file. It then writes the ``#include`` directives the
    generated code uses, pushes the ``torch_ort`` and ``eager`` namespaces,
    and adds the ``using`` declarations.

    The two namespaces pushed here are not popped by this method.
    NOTE(review): presumably the file postlude closes them; confirm.
    """
    import shlex  # local import keeps this block self-contained

    # Quote each argument so the recorded command can be pasted back into a
    # shell even when an argument contains spaces or shell metacharacters.
    # A plain " ".join(sys.argv) loses the argument boundaries.
    command_line = " ".join(shlex.quote(arg) for arg in sys.argv)

    writer.writeline("// AUTO-GENERATED CODE! - DO NOT EDIT!")
    writer.writeline(f"// $ python {command_line}")
    writer.writeline()
    writer.writeline('#include "python/onnxruntime_pybind_state_common.h"')
    writer.writeline()
    writer.writeline("#include <torch/extension.h>")
    writer.writeline("#include <ATen/native/CPUFallback.h>")
    writer.writeline()
    writer.writeline("#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>")
    writer.writeline()
    writer.writeline('#include "ort_tensor.h"')
    writer.writeline('#include "ort_aten.h"')
    writer.writeline('#include "ort_log.h"')
    writer.writeline()
    writer.push_namespace("torch_ort")
    writer.push_namespace("eager")
    writer.writeline()
    writer.writeline("using namespace at;")
    writer.writeline("using NodeAttributes = onnxruntime::NodeAttributes;")