Exemple #1
0
    def run(self, cpp_parser: parser.CPPParser, writer: writer.SourceWriter):
        """Generate the complete C++ source for all mapped torch operations.

        Writes the file prelude, then a C++ function (or bare signature) for
        every mapped function yielded by
        ``self._parse_mapped_function_decls(cpp_parser)``, grouped into
        namespaces by ``op_namespace``. It then writes either the regular or
        the custom-op registrations (depending on ``self._custom_ops``) and
        finally the file postlude.

        Args:
            cpp_parser: parser passed through to
                ``_parse_mapped_function_decls``.
            writer: destination for the generated source text.

        Raises:
            Exception: if any entry of ``self._mapped_ops`` was never matched
                by a parsed declaration. This check happens only after the
                whole file (including the postlude) has been written.
        """
        self._write_file_prelude(writer)

        generated_funcs = []
        # Namespace currently open on the writer, or None when none is open.
        # Compared with ``is not None`` rather than truthiness so that any
        # namespace value that was pushed (even a falsy one such as '') is
        # also popped again, keeping push/pop balanced.
        current_ns = None

        for mapped_func in self._parse_mapped_function_decls(cpp_parser):
            # Remove the op from the pending set; whatever remains at the end
            # is reported as unparsed.
            # NOTE(review): a second declaration with the same op name would
            # raise KeyError here -- confirm the parser never yields duplicates.
            del self._mapped_ops[mapped_func.mapped_op_name]
            generated_funcs.append(mapped_func)

            # Fallthrough ops still go into generated_funcs (and thus the
            # registrations) but get no C++ definition here.
            if mapped_func.make_fallthrough:
                continue

            # Close the open namespace when the op's namespace changes, then
            # open the new one (a None namespace means "no namespace").
            ns = mapped_func.op_namespace
            if current_ns is not None and current_ns != ns:
                current_ns = None
                writer.pop_namespace()
            if ns != current_ns:
                current_ns = ns
                writer.writeline()
                writer.push_namespace(ns)

            writer.writeline()
            torch_func = mapped_func.cpp_func.torch_func
            if torch_func:
                # Echo the originating torch schema as a comment.
                writer.writeline(f'// {torch_func.torch_schema}')

            self._write_function_signature(writer, mapped_func.cpp_func)
            if mapped_func.signature_only:
                writer.writeline(';')
            else:
                writer.writeline(' {')
                writer.push_indent()
                self._write_function_body(writer, mapped_func)
                writer.pop_indent()
                writer.writeline('}')

        # Close the last namespace left open by the loop, if any.
        if current_ns is not None:
            writer.pop_namespace()

        if not self._custom_ops:
            self._write_function_registrations(writer, generated_funcs)
        else:
            self._write_custom_ops_registrations(writer, generated_funcs)
        self._write_file_postlude(writer)

        if self._mapped_ops:
            unparsed = ', '.join(f"'{op}'" for op in self._mapped_ops)
            raise Exception(
                'Torch operation(s) could not be parsed for mapping: ' + unparsed)
Exemple #2
0
 def _write_file_prelude(self, writer: writer.SourceWriter):
     """Emit the header of the generated C++ file.

     Writes an auto-generated warning banner plus the command line (taken
     from ``sys.argv``) that produced the file, then the include directives.
     Each group of lines is followed by one blank line. Finally it opens the
     ``torch_ort`` and ``eager`` namespaces and adds the ``using``
     declarations. The namespaces are left open; this method never pops them.
     """
     invocation = " ".join(sys.argv)
     # Groups of header lines; a blank line is written after every group.
     sections = (
         ('// AUTO-GENERATED CODE! - DO NOT EDIT!',
          f'// $ python {invocation}'),
         ('#include <torch/extension.h>',),
         ('#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>',),
         ('#include "ort_tensor.h"',
          '#include "ort_aten.h"',
          '#include "ort_log.h"'),
     )
     for section in sections:
         for text in section:
             writer.writeline(text)
         writer.writeline()

     for namespace in ('torch_ort', 'eager'):
         writer.push_namespace(namespace)
     writer.writeline()
     writer.writeline('using namespace at;')
     writer.writeline('using NodeAttributes = onnxruntime::NodeAttributes;')
Exemple #3
0
 def _write_file_prelude(self, writer: writer.SourceWriter):
     """Emit the header of the generated C++ file.

     Writes an auto-generated warning banner plus the command line (taken
     from ``sys.argv``) that produced the file. It then writes the include
     directives, which here also cover the ORT pybind state header and
     ATen's CPU fallback. Each group of lines is followed by one blank line.
     Finally it opens the ``torch_ort`` and ``eager`` namespaces and adds the
     ``using`` declarations. The namespaces are left open; this method never
     pops them.
     """
     invocation = " ".join(sys.argv)
     # Groups of header lines; a blank line is written after every group.
     sections = [
         ["// AUTO-GENERATED CODE! - DO NOT EDIT!",
          f'// $ python {invocation}'],
         ['#include "python/onnxruntime_pybind_state_common.h"'],
         ["#include <torch/extension.h>",
          "#include <ATen/native/CPUFallback.h>"],
         ["#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>"],
         ['#include "ort_tensor.h"',
          '#include "ort_aten.h"',
          '#include "ort_log.h"'],
     ]
     for section in sections:
         for text in section:
             writer.writeline(text)
         writer.writeline()

     for namespace in ("torch_ort", "eager"):
         writer.push_namespace(namespace)
     writer.writeline()
     writer.writeline("using namespace at;")
     writer.writeline("using NodeAttributes = onnxruntime::NodeAttributes;")