def run(self, cpp_parser: parser.CPPParser, writer: writer.SourceWriter):
    """Generate the complete C++ source for all mapped torch operations.

    Writes the file prelude, then one declaration (``signature_only``) or
    definition per mapped function produced by
    ``self._parse_mapped_function_decls(cpp_parser)``. Consecutive functions
    sharing the same ``op_namespace`` are emitted inside a single pushed
    namespace. Finally it writes the registration section and the file
    postlude.

    Args:
        cpp_parser: Passed through unchanged to
            ``self._parse_mapped_function_decls``.
        writer: Destination for the generated source text.

    Raises:
        KeyError: If a parsed function's ``mapped_op_name`` is not (or no
            longer) a key of ``self._mapped_ops``.
        Exception: If entries remain in ``self._mapped_ops`` after all parsed
            functions have been consumed, i.e. some requested operations
            could not be parsed for mapping. The whole file has already been
            written to ``writer`` when this is raised.
    """
    self._write_file_prelude(writer)

    generated_funcs = []
    # Namespace currently open on the writer; None means "none pushed".
    # Checked with `is not None` rather than truthiness so that a falsy
    # namespace that *was* pushed (e.g. '' -> an unnamed C++ namespace) is
    # also popped again, keeping push/pop balanced.
    current_ns = None

    for mapped_func in self._parse_mapped_function_decls(cpp_parser):
        # Mark the op as handled; whatever is left is reported at the end.
        del self._mapped_ops[mapped_func.mapped_op_name]
        # Fall-through functions get no generated code here, but they are
        # still handed to the registration step below.
        generated_funcs.append(mapped_func)
        if mapped_func.make_fallthrough:
            continue

        ns = mapped_func.op_namespace
        if current_ns is not None and current_ns != ns:
            writer.pop_namespace()
            current_ns = None
        if ns != current_ns:
            # A None namespace never reaches this point, because
            # current_ns is None whenever ns could equal it here.
            writer.writeline()
            writer.push_namespace(ns)
            current_ns = ns

        writer.writeline()
        torch_func = mapped_func.cpp_func.torch_func
        if torch_func:
            writer.writeline(f'// {torch_func.torch_schema}')

        self._write_function_signature(writer, mapped_func.cpp_func)
        if mapped_func.signature_only:
            writer.writeline(';')
        else:
            writer.writeline(' {')
            writer.push_indent()
            self._write_function_body(writer, mapped_func)
            writer.pop_indent()
            writer.writeline('}')

    # Close the namespace left open by the last emitted function, if any.
    if current_ns is not None:
        writer.pop_namespace()

    if self._custom_ops:
        self._write_custom_ops_registrations(writer, generated_funcs)
    else:
        self._write_function_registrations(writer, generated_funcs)
    self._write_file_postlude(writer)

    if self._mapped_ops:
        unparsed = ', '.join(f"'{op}'" for op in self._mapped_ops)
        raise Exception(
            'Torch operation(s) could not be parsed for mapping: ' + unparsed)
def _write_file_prelude(self, writer: writer.SourceWriter):
    """Write the fixed header of the generated C++ source file.

    Emits an auto-generated banner that records the generator's command
    line (``sys.argv``), the ``#include`` directives, then pushes the
    ``torch_ort`` and ``eager`` namespaces (they are not popped here) and
    writes two using-declarations.

    NOTE(review): this source contains a second ``_write_file_prelude``
    definition. If both live in the same class, the later one replaces this
    one when the class is created, and this body is never used. Confirm and
    delete the stale copy.
    """
    # Each group is written line by line and followed by one blank line
    # (an argument-less writeline()).
    header_groups = (
        (
            '// AUTO-GENERATED CODE! - DO NOT EDIT!',
            f'// $ python {" ".join(sys.argv)}',
        ),
        ('#include <torch/extension.h>',),
        ('#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>',),
        (
            '#include "ort_tensor.h"',
            '#include "ort_aten.h"',
            '#include "ort_log.h"',
        ),
    )
    for group in header_groups:
        for text in group:
            writer.writeline(text)
        writer.writeline()

    for namespace_name in ('torch_ort', 'eager'):
        writer.push_namespace(namespace_name)
    writer.writeline()

    using_decls = (
        'using namespace at;',
        'using NodeAttributes = onnxruntime::NodeAttributes;',
    )
    for decl in using_decls:
        writer.writeline(decl)
def _write_file_prelude(self, writer: writer.SourceWriter):
    """Write the fixed header of the generated C++ source file.

    Emits an auto-generated banner that records the generator's command
    line (``sys.argv``), then the ``#include`` directives: the pybind state
    header, torch/ATen (including ``ATen/native/CPUFallback.h``), the DML
    attribute helpers, and the local ``ort_*`` headers. It then pushes the
    ``torch_ort`` and ``eager`` namespaces (they are not popped here) and
    writes two using-declarations.

    NOTE(review): an earlier ``_write_file_prelude`` definition with the
    same name appears in this source. If both are in the same class, this
    definition silently replaces it. Confirm that is intended.
    """
    # Each group is written line by line and followed by one blank line
    # (an argument-less writeline()).
    header_groups = (
        (
            "// AUTO-GENERATED CODE! - DO NOT EDIT!",
            f'// $ python {" ".join(sys.argv)}',
        ),
        ('#include "python/onnxruntime_pybind_state_common.h"',),
        (
            "#include <torch/extension.h>",
            "#include <ATen/native/CPUFallback.h>",
        ),
        ("#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>",),
        (
            '#include "ort_tensor.h"',
            '#include "ort_aten.h"',
            '#include "ort_log.h"',
        ),
    )
    for group in header_groups:
        for text in group:
            writer.writeline(text)
        writer.writeline()

    for namespace_name in ("torch_ort", "eager"):
        writer.push_namespace(namespace_name)
    writer.writeline()

    using_decls = (
        "using namespace at;",
        "using NodeAttributes = onnxruntime::NodeAttributes;",
    )
    for decl in using_decls:
        writer.writeline(decl)