Example #1
0
  def _write_function_registrations(
    self,
    writer: writer.SourceWriter,
    generated_funcs: List[MappedOpFunction]):
    """Emit the TORCH_LIBRARY_IMPL(aten, ORT, m) block that registers
    every generated function (or a fallthrough) with Torch."""
    writer.writeline()
    writer.writeline('TORCH_LIBRARY_IMPL(aten, ORT, m) {')
    writer.push_indent()

    for mapped_func in generated_funcs:
      cpp_func = mapped_func.cpp_func
      torch_func = cpp_func.torch_func

      if mapped_func.make_fallthrough:
        # Fallthrough ops delegate to the next registered backend.
        registration = 'torch::CppFunction::makeFallthrough()'
      else:
        # Qualify the C++ function with its namespace, if any, and wrap
        # it in TORCH_FN for registration.
        qualifier = f'{mapped_func.op_namespace}::' if mapped_func.op_namespace else ''
        registration = f'TORCH_FN({qualifier}{cpp_func.identifier.value})'

      writer.write('m.impl(')
      writer.writeline(f'"{torch_func.identifier.value}", {registration});')

    writer.pop_indent()
    writer.writeline('}')
    writer.writeline()
Example #2
0
 def _write_function_signature(self, writer: writer.SourceWriter,
                               cpp_func: ast.FunctionDecl):
     """Emit the C++ signature of cpp_func: return type, identifier, and
     the parameter list (one parameter per line)."""
     cpp_func.return_type.write(writer)
     writer.write(f" {cpp_func.identifier.value}(")
     writer.push_indent()
     for entry in cpp_func.parameters:
         writer.writeline()
         # A kwargs sentinel has no C++ equivalent; emit it commented-out.
         is_sentinel = isinstance(entry.member.parameter_type,
                                  ast.KWArgsSentinelType)
         if is_sentinel:
             writer.write("// ")
         entry.write(writer)
     writer.pop_indent()
     writer.write(")")
Example #3
0
    def _write_function_registrations(self, writer: writer.SourceWriter,
                                      generated_funcs: List[MappedOpFunction]):
        """Emit the TORCH_LIBRARY_IMPL(aten, ORT, m) block that registers
        each generated function with Torch via TORCH_FN."""
        writer.writeline()
        writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {")
        writer.push_indent()

        for mapped_func in generated_funcs:
            cpp_func = mapped_func.cpp_func
            torch_func = cpp_func.torch_func

            # Qualify the C++ function with its namespace, if any.
            prefix = f"{mapped_func.op_namespace}::" if mapped_func.op_namespace else ""
            wrapped = f"TORCH_FN({prefix}{cpp_func.identifier.value})"

            writer.write("m.impl(")
            writer.writeline(f'"{torch_func.identifier.value}", {wrapped});')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()
Example #4
0
    def _write_cpu_fall_back(self, writer: writer.SourceWriter,
                             mapped_func: MappedOpFunction):
        """Emit a CPU-fallback call for an op ORT does not handle natively.

        Generated code has the shape:
            return at::native::call_fallback_fn<
              &at::native::cpu_fallback,
              ATEN_OP(eq_Tensor)>::call(self, other);
        """
        # Only the C++ function is needed here; the previous unused
        # `onnx_op` binding has been dropped.
        cpp_func = mapped_func.cpp_func
        writer.writeline('return native::call_fallback_fn<')
        writer.push_indent()
        writer.writeline('&native::cpu_fallback,')
        writer.write('ATEN_OP(')
        writer.write(cpp_func.identifier.value)
        writer.write(')>::call(')

        # Forward every named parameter of the C++ function verbatim.
        params = ', '.join(p.member.identifier.value
                           for p in cpp_func.parameters
                           if p.member.identifier)
        writer.write(params)
        writer.writeline(');')
        writer.pop_indent()
Example #5
0
    def _write_custom_ops_registrations(
            self, writer: writer.SourceWriter,
            generated_funcs: List[MappedOpFunction]):
        """Emit a GenerateCustomOpsBindings function that binds every
        generated function into a pybind11 module."""
        writer.writeline()
        writer.writeline(
            "void GenerateCustomOpsBindings(pybind11::module_ m) {")
        writer.push_indent()
        writer.writeline('ORT_LOG_INFO << "GenerateCustomOpsBindings init";')

        for mapped_func in generated_funcs:
            # Each op is exposed under its C++ identifier.
            name = mapped_func.cpp_func.identifier.value
            writer.write("m.def(")
            writer.writeline(f'"{name}", &{name});')

        writer.pop_indent()
        writer.writeline("}")
        writer.writeline()
Example #6
0
    def _write_function_body(self, writer: writer.SourceWriter,
                             mapped_func: MappedOpFunction):
        """Emit the C++ body of a mapped op.

        Steps: debug logging, optional whole-op CPU fallback, ONNX graph
        evaluation, runtime type checks (with CPU fallback), optional type
        promotion, Torch -> ORT input/attribute marshalling, the chain of
        invoker.Invoke() calls, and finally the return statement (fresh
        tensor, vector of tensors, or in-place aliased parameter(s)).

        Raises FunctionGenerationError when the signature or an attribute
        cannot be mapped.
        """
        onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func

        assert len(cpp_func.parameters) > 0

        # Debug Logging
        log_params = ", ".join([
            p.member.identifier.value for p in cpp_func.parameters
            if p.member.identifier
        ])
        writer.writeline(f"ORT_LOG_FN({log_params});")
        writer.writeline()

        # Ops marked for fallback are delegated wholesale to the CPU backend.
        if mapped_func.make_torch_fallback:
            return self._write_cpu_fall_back(writer, mapped_func)

        # Eval the outer ONNX op to produce a topologically ordered list of ops
        ctx = ONNXOpEvalContext()
        onnx_op.eval(ctx)
        ctx.prepare_outputs()

        # Fetch the ORT invoker from an at::Tensor.device()
        # FIXME: find the first at::Tensor param anywhere in the signature
        # instead of simply the first parameter?
        first_param = cpp_func.parameters[0].member
        # Check if the first parameter is tensorlist and if yes it's size should be > 0
        if first_param.parameter_type.desugar(
        ).identifier_tokens[0].value == "TensorList":
            writer.write("assert(")
            writer.write(first_param.identifier.value)
            writer.writeline(".size()>0);")

        # Generate the type check: if any real (non-Outputs) input exists,
        # emit an IsSupportedType guard that falls back to CPU on failure.
        need_type_check = False
        if not self._custom_ops:
            for onnx_op_index, onnx_op in enumerate(ctx.ops):
                for op_input in onnx_op.inputs:
                    if not isinstance(op_input, Outputs):
                        need_type_check = True
                        break
        if need_type_check:
            writer.write("if (")
            i = 0
            for onnx_op_index, onnx_op in enumerate(ctx.ops):
                for idx, op_input in enumerate(onnx_op.inputs):
                    if isinstance(op_input, Outputs):
                        continue
                    writer.writeline(" || " if i > 0 else "")
                    if i == 0:
                        writer.push_indent()
                    cpp_param = cpp_func.get_parameter(op_input)
                    supported_types = ",".join(
                        [type for type in onnx_op.input_types[idx]])
                    writer.write("!IsSupportedType(%s, {%s})" %
                                 (cpp_param.identifier.value, supported_types))
                    i += 1
            writer.writeline(") {")
            self._write_cpu_fall_back(writer, mapped_func)
            writer.pop_indent()
            writer.writeline("}")

        if (not isinstance(first_param.parameter_type.desugar(),
                           ast.ConcreteType) or "Tensor" not in first_param.
                parameter_type.desugar().identifier_tokens[0].value):
            raise FunctionGenerationError(
                cpp_func, "First parameter must be an at::Tensor")

        writer.write("auto& invoker = GetORTInvoker(")
        writer.write(first_param.identifier.value)
        if first_param.parameter_type.desugar(
        ).identifier_tokens[0].value == "TensorList":
            writer.write("[0]")
        writer.writeline(".device());")
        writer.writeline()

        # FIXME: warn if we have not consumed all torch parameters (either as
        # an ORT input or ORT attribute).

        # Perform kernel fission on the ATen op to yield a chain of ORT Invokes
        # e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y))

        # whether need type promotion
        need_type_promotion = False
        if mapped_func.mapped_op_name in self.type_promotion_ops:
            types_from_tensor = []
            types_from_scalar = []
            for onnx_op_index, onnx_op in enumerate(ctx.ops):
                for op_input in onnx_op.inputs:
                    if isinstance(op_input, Outputs):
                        continue
                    # Bug fix: this lookup previously sat outside the inner
                    # loop (dedented), so only the last input contributed a
                    # type to promotion; every input must be considered.
                    cpp_param = cpp_func.get_parameter(op_input)
                    if cpp_param:
                        if cpp_param.parameter_type.desugar(
                        ).identifier_tokens[0].value == "Tensor":
                            types_from_tensor.append(
                                f"{op_input}.scalar_type()")
                        elif cpp_param.parameter_type.desugar(
                        ).identifier_tokens[0].value == "Scalar":
                            types_from_scalar.append(f"{op_input}.type()")
            if len(types_from_tensor) > 0 or len(types_from_scalar) > 0:
                need_type_promotion = True
                writer.writeline(
                    "auto promoted_type = PromoteScalarTypesWithCategory({%s}, {%s});"
                    %
                    (",".join(types_from_tensor), ",".join(types_from_scalar)))
                writer.writeline()

        for onnx_op_index, onnx_op in enumerate(ctx.ops):
            # Torch -> ORT inputs
            for op_input in onnx_op.inputs:
                if isinstance(op_input, Outputs):
                    continue
                cpp_param = cpp_func.get_parameter(op_input)
                writer.write(f"auto ort_input_{op_input} = ")
                writer.writeline(f"create_ort_value(invoker, {op_input});")
                if need_type_promotion:
                    # Scalars report their type via type(); tensors via
                    # scalar_type().
                    type_func_str = ("type()" if cpp_param.parameter_type.
                                     desugar().identifier_tokens[0].value
                                     == "Scalar" else "scalar_type()")
                    writer.write(
                        f"if ({op_input}.{type_func_str} != *promoted_type)")
                    writer.writeline("{")
                    writer.push_indent()
                    writer.writeline(
                        f"ort_input_{op_input} = CastToType(invoker, ort_input_{op_input}, *promoted_type);"
                    )
                    writer.pop_indent()
                    writer.writeline("}")

            # Torch kwargs -> ORT attributes
            attrs = {
                k: v
                for k, v in onnx_op.attributes.items() if v and v.value
            }
            if len(attrs) > 0:
                attrs_arg = "attrs"
                writer.writeline()
                writer.writeline(f"NodeAttributes {attrs_arg}({len(attrs)});")

                for attr_name, attr in attrs.items():
                    writer.write(f'{attrs_arg}["{attr_name}"] = ')
                    writer.writeline("create_ort_attribute(")
                    writer.push_indent()
                    writer.write(f'"{attr_name}", {attr.value}')
                    # ScalarType attributes pass the type explicitly, tensor
                    # attributes pass true, strings need no extra argument;
                    # anything else is unsupported.
                    if attr.type.startswith("at::ScalarType::"):
                        writer.write(f", {attr.type}")
                    elif attr.type == AttrType.TENSOR:
                        writer.write(", true")
                    elif attr.type != AttrType.STRING:
                        raise FunctionGenerationError(
                            cpp_func,
                            f'Unsure how to map ONNX op "{onnx_op.name}" attribute '
                            +
                            f'"{attr_name}" of type "{attr.type}" to a call to '
                            +
                            "create_ort_attribute. Please teach generator.py.",
                        )
                    writer.writeline(");")
                    writer.pop_indent()
                attrs_arg = f"&{attrs_arg}"
            else:
                attrs_arg = "nullptr"

            # Outputs vector
            writer.writeline()
            writer.write(f"std::vector<OrtValue> {onnx_op.outputs}")
            writer.writeline(f"({onnx_op.outputs.count});")

            return_info = cpp_func.torch_func.return_type if cpp_func.torch_func else None
            in_place_params = {}

            if return_info:
                # Route writable aliased inputs directly into the output
                # slots so the op runs in place.
                for input_index, op_input in enumerate(onnx_op.inputs):
                    if isinstance(op_input, Outputs):
                        continue

                    # See if this input is aliased as an in-place tensor
                    cpp_param = cpp_func.get_parameter(op_input)
                    if cpp_param:
                        for torch_p in cpp_param.torch_param:
                            if isinstance(return_info, ast.TupleType):
                                for output_index, output_param in enumerate(
                                        return_info.elements):
                                    assert isinstance(
                                        output_param.member,
                                        ast.TupleMemberType
                                    ), "output_param.member must be of TupleMemberType"
                                    output_alias = self._get_alias_info(
                                        output_param.member.element_type)
                                    if (output_alias
                                            and self._get_alias_info(torch_p)
                                            == output_alias
                                            and output_alias.is_writable):
                                        writer.writeline(
                                            f"{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op.inputs[input_index]};"
                                        )
                                        in_place_params[
                                            output_index] = cpp_param.identifier.value
                                        break
                            else:
                                output_alias = self._get_alias_info(
                                    return_info)
                                if (output_alias
                                        and self._get_alias_info(torch_p)
                                        == output_alias
                                        and output_alias.is_writable):
                                    writer.writeline(
                                        f"{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[input_index]};"
                                    )
                                    in_place_params[
                                        0] = cpp_param.identifier.value
                                    break

                if len(in_place_params) != 0 and len(in_place_params) != (len(
                        return_info.elements) if isinstance(
                            return_info, ast.TupleType) else 1):
                    raise Exception(
                        f"Cannot mix and match inplace with non-inplace parameters - function: {cpp_func.identifier.value} "
                        +
                        f"in_place_params={in_place_params}, return_elements={return_info.elements}"
                    )

            # Perform the invocation
            writer.writeline()
            if onnx_op_index == 0:
                writer.write("auto ")
            writer.writeline(f'status = invoker.Invoke("{onnx_op.name}", {{')
            writer.push_indent()
            for op_input in onnx_op.inputs:
                if isinstance(op_input, Outputs):
                    if op_input.count != 1:
                        raise FunctionGenerationError(
                            cpp_func, "multiple outputs not supported")
                    op_input = f"{op_input}[0]"
                else:
                    op_input = f"ort_input_{op_input}"
                writer.writeline(f"std::move({op_input}),")
            writer.pop_indent()
            writer.write(f"}}, {onnx_op.outputs}, {attrs_arg}")
            if onnx_op.domain:
                writer.write(f", {onnx_op.domain}")
            writer.writeline(");")
            writer.writeline()

            # Assert invocation
            writer.writeline("if (!status.IsOK())")
            writer.push_indent()
            writer.writeline("throw std::runtime_error(")
            writer.push_indent()
            writer.writeline(
                '"ORT return failure status:" + status.ErrorMessage());')
            writer.pop_indent()
            writer.pop_indent()
            writer.writeline()

            # We'll potentially return back to Torch from this op
            return_outputs = onnx_op.outputs

        # TODO: Pick the right "out" Torch parameter; do not assume the first one
        # TODO: Handle multiple results
        # TODO: Assert return type

        if len(in_place_params) == 0:
            # tensor options
            writer.write(
                f"at::TensorOptions tensor_options = {first_param.identifier.value}"
            )
            if first_param.parameter_type.desugar(
            ).identifier_tokens[0].value == "TensorList":
                writer.write("[0]")
            writer.write(".options()")
            if need_type_promotion:
                writer.write(".dtype(*promoted_type)")
            writer.writeline(";")

            writer.writeline("return aten_tensor_from_ort(")
            writer.push_indent()
            if (isinstance(cpp_func.return_type, ast.TemplateType)
                    and cpp_func.return_type.identifier_tokens[-1].value
                    == "std::vector"):
                writer.writeline(f"{return_outputs},")
                writer.writeline("tensor_options);")
            else:
                writer.writeline(f"std::move({return_outputs}[0]),")
                writer.writeline("tensor_options);")
            writer.pop_indent()
            return
        else:
            if len(in_place_params) == 1:
                writer.writeline(f"return {in_place_params[0]};")
            else:
                # Multiple in-place results must come back as a std::tuple
                # of Tensor references. (Previously raised an empty-message
                # Exception here.)
                if not (isinstance(cpp_func.return_type, ast.TemplateType)
                        and cpp_func.return_type.identifier_tokens[-1].value
                        == "std::tuple"):
                    raise Exception(
                        "Multiple in-place parameters require a std::tuple "
                        f"return type - function: {cpp_func.identifier.value}")
                tensorRef = "Tensor&," * len(in_place_params)
                tensorRef = tensorRef[:len(tensorRef) - 1]
                writer.write(f"return std::tuple<{tensorRef}>(")
                for index, key in enumerate(sorted(in_place_params)):
                    if index > 0:
                        writer.write(", ")
                    writer.write(in_place_params[key])
                writer.writeline(");")
Example #7
0
 def _write_file_prelude(self, writer: writer.SourceWriter):
     """Emit the auto-generated banner, #include lines, and the
     torch_ort::eager namespace openers with common using-declarations."""
     writer.writeline("// AUTO-GENERATED CODE! - DO NOT EDIT!")
     # Record the exact command line that produced this file.
     writer.writeline(f'// $ python {" ".join(sys.argv)}')
     writer.writeline()
     writer.writeline('#include "python/onnxruntime_pybind_state_common.h"')
     writer.writeline()
     writer.writeline("#include <torch/extension.h>")
     writer.writeline("#include <ATen/native/CPUFallback.h>")
     writer.writeline()
     writer.writeline(
         "#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>")
     writer.writeline()
     # Project-local headers.
     for header in ("ort_tensor.h", "ort_aten.h", "ort_log.h"):
         writer.writeline(f'#include "{header}"')
     writer.writeline()
     writer.push_namespace("torch_ort")
     writer.push_namespace("eager")
     writer.writeline()
     writer.writeline("using namespace at;")
     writer.writeline("using NodeAttributes = onnxruntime::NodeAttributes;")
Example #8
0
    def run(self, cpp_parser: parser.CPPParser, writer: writer.SourceWriter):
        """Drive generation: file prelude, each mapped function (grouped by
        namespace), registrations, and the postlude. Raises if any mapped
        op was never matched to a declaration."""
        self._write_file_prelude(writer)

        generated_funcs = []
        open_ns = None

        for mapped_func in self._parse_mapped_function_decls(cpp_parser):
            # Mark this op as handled; leftovers are reported at the end.
            del self._mapped_ops[mapped_func.mapped_op_name]
            generated_funcs.append(mapped_func)

            # Close the previously opened namespace and open this op's,
            # as needed.
            ns = mapped_func.op_namespace
            if open_ns and open_ns != ns:
                writer.pop_namespace()
                open_ns = None
            if ns != open_ns:
                writer.writeline()
                writer.push_namespace(ns)
                open_ns = ns

            writer.writeline()
            torch_func = mapped_func.cpp_func.torch_func
            if torch_func:
                writer.writeline(f"// {torch_func.torch_schema}")

            self._write_function_signature(writer, mapped_func.cpp_func)
            if mapped_func.signature_only:
                writer.writeline(";")
            else:
                writer.writeline(" {")
                writer.push_indent()
                self._write_function_body(writer, mapped_func)
                writer.pop_indent()
                writer.writeline("}")

        if open_ns:
            writer.pop_namespace()

        if self._custom_ops:
            self._write_custom_ops_registrations(writer, generated_funcs)
        else:
            self._write_function_registrations(writer, generated_funcs)
        self._write_file_postlude(writer)

        if len(self._mapped_ops) > 0:
            raise Exception(
                "Torch operation(s) could not be parsed for mapping: " +
                ", ".join([f"'{o}'" for o in self._mapped_ops.keys()]))
Example #9
0
  def _write_function_body(
    self,
    writer: writer.SourceWriter,
    mapped_func: MappedOpFunction):
    """Emit the C++ body of a mapped op: debug logging, ORT invoker lookup,
    Torch -> ORT input/attribute marshalling, the chain of invoker.Invoke()
    calls, and the return statement (fresh tensor or in-place alias).

    Raises FunctionGenerationError when the signature or an attribute
    cannot be mapped.
    """
    onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func

    assert(len(cpp_func.parameters) > 0)

    # Alias info on the return type signals an in-place op; only a
    # writable alias counts, so drop read-only ones.
    return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
    if return_alias_info and not return_alias_info.is_writable:
      return_alias_info = None
    in_place_param: ast.ParameterDecl = None

    # Eval the outer ONNX op to produce a topologically ordered list of ops
    ctx = ONNXOpEvalContext()
    onnx_op.eval(ctx)
    ctx.prepare_outputs()

    # Debug Logging
    log_params = ', '.join([p.member.identifier.value for p \
      in cpp_func.parameters if p.member.identifier])
    writer.writeline(f'ORT_LOG_FN({log_params});')
    writer.writeline()

    # Fetch the ORT invoker from an at::Tensor.device()
    # FIXME: find the first at::Tensor param anywhere in the signature
    # instead of simply the first parameter?
    first_param = cpp_func.parameters[0].member
    if not isinstance(
      first_param.parameter_type.desugar(),
      ast.ConcreteType) or 'Tensor' not in first_param.parameter_type.desugar().identifier_tokens[0].value:
      raise FunctionGenerationError(
        cpp_func,
        'First parameter must be an at::Tensor')

    writer.write('auto& invoker = GetORTInvoker(')
    writer.write(first_param.identifier.value)
    writer.writeline('.device());')
    writer.writeline()

    # FIXME: warn if we have not consumed all torch parameters (either as
    # an ORT input or ORT attribute).

    # Perform kernel fission on the ATen op to yield a chain of ORT Invokes
    # e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y))
    for onnx_op_index, onnx_op in enumerate(ctx.ops):
      # Torch -> ORT inputs
      for op_input in onnx_op.inputs:
        if isinstance(op_input, Outputs):
          continue
        # See if this input is aliased as an in-place tensor
        cpp_param = cpp_func.get_parameter(op_input)
        if return_alias_info and cpp_param and \
          len(cpp_param.torch_param) == 1 and \
          self._get_alias_info(cpp_param.torch_param[0]) == return_alias_info:
          in_place_param = cpp_param

        writer.write(f'auto ort_input_{op_input} = ')
        writer.writeline(f'create_ort_value(invoker, {op_input});')

      # Torch kwargs -> ORT attributes
      attrs = { k:v for k, v in onnx_op.attributes.items() if v and v.value }
      if len(attrs) > 0:
        attrs_arg = 'attrs'
        writer.writeline()
        writer.writeline(f'NodeAttributes {attrs_arg}({len(attrs)});')

        for attr_name, attr in attrs.items():
          writer.write(f'{attrs_arg}["{attr_name}"] = ')
          writer.writeline('create_ort_attribute(')
          writer.push_indent()
          writer.write(f'"{attr_name}", {attr.value}')
          # ScalarType attributes pass the type explicitly; strings need
          # no extra argument; anything else is unsupported here.
          if attr.type.startswith('at::ScalarType::'):
            writer.write(f', {attr.type}')
          elif attr.type != AttrType.STRING:
            raise FunctionGenerationError(
              cpp_func,
              f'Unsure how how to map ONNX op "{onnx_op.name}" attribute ' + 
              f'"{attr_name}" of type "{attr.type}" to a call to ' +
              'create_ort_attribute. Please teach generator.py.')
          writer.writeline(');')
          writer.pop_indent()
        attrs_arg = f'&{attrs_arg}'
      else:
        attrs_arg = 'nullptr'

      # Outputs vector
      writer.writeline()
      writer.write(f'std::vector<OrtValue> {onnx_op.outputs}')
      writer.writeline(f'({onnx_op.outputs.count});')

      # Perform the invocation
      writer.writeline()
      if onnx_op_index == 0:
        writer.write('auto ')
      writer.writeline(f'status = invoker.Invoke("{onnx_op.name}", {{')
      writer.push_indent()
      for op_input in onnx_op.inputs:
        if isinstance(op_input, Outputs):
          if op_input.count != 1:
            raise FunctionGenerationError(
              cpp_func,
              'multiple outputs not supported')
          op_input = f'{op_input}[0]'
        else:
          op_input = f'ort_input_{op_input}'
        writer.writeline(f'std::move({op_input}),')
      writer.pop_indent()
      writer.write(f'}}, {onnx_op.outputs}, {attrs_arg}')
      if onnx_op.domain:
        writer.write(f', {onnx_op.domain}')
      writer.writeline(');')
      writer.writeline()

      # Assert invocation
      writer.writeline('if (!status.IsOK())')
      writer.push_indent()
      writer.writeline('throw std::runtime_error(')
      writer.push_indent()
      writer.writeline('"ORT return failure status:" + status.ErrorMessage());')
      writer.pop_indent()
      writer.pop_indent()
      writer.writeline()

      # We'll potentially return back to Torch from this op
      return_outputs = onnx_op.outputs

    # TODO: Pick the right "out" Torch parameter; do not assume the first one
    # TODO: Handle multiple results
    # TODO: Assert return type

    # No writable return alias: wrap the final ORT output(s) in a fresh
    # at::Tensor (or vector thereof) carrying the first param's options.
    if not return_alias_info:
      writer.writeline('return aten_tensor_from_ort(')
      writer.push_indent()
      if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector':
        writer.writeline(f'{return_outputs},')
        writer.writeline(f'{first_param.identifier.value}.options());')
      else:
        writer.writeline(f'std::move({return_outputs}[0]),')
        writer.writeline(f'{first_param.identifier.value}.options());')
      writer.pop_indent()
      return

    if not in_place_param:
      raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' +
        'has alias info on its return type but no associated parameter')

    # In-place op: copy the result back into the aliased input, return it.
    writer.writeline(f'copy(invoker, {return_outputs}[0], ort_input_{in_place_param.identifier.value});')
    writer.writeline(f'return {in_place_param.identifier.value};')
Example #10
0
 def _write_file_prelude(self, writer: writer.SourceWriter):
     """Emit the auto-generated banner, #include lines, and the
     torch_ort::eager namespace openers with common using-declarations."""
     writer.writeline('// AUTO-GENERATED CODE! - DO NOT EDIT!')
     # Record the exact command line that produced this file.
     writer.writeline(f'// $ python {" ".join(sys.argv)}')
     writer.writeline()
     writer.writeline('#include <torch/extension.h>')
     writer.writeline()
     writer.writeline(
         '#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>')
     writer.writeline()
     # Project-local headers.
     for header in ('ort_tensor.h', 'ort_aten.h', 'ort_log.h'):
         writer.writeline(f'#include "{header}"')
     writer.writeline()
     writer.push_namespace('torch_ort')
     writer.push_namespace('eager')
     writer.writeline()
     writer.writeline('using namespace at;')
     writer.writeline('using NodeAttributes = onnxruntime::NodeAttributes;')
Example #11
0
    def run(self, cpp_parser: parser.CPPParser, writer: writer.SourceWriter):
        """Drive generation: file prelude, each mapped function (grouped by
        namespace), registrations, and the postlude. Raises if any mapped
        op was never matched to a declaration."""
        self._write_file_prelude(writer)

        generated_funcs = []
        open_ns = None

        for mapped_func in self._parse_mapped_function_decls(cpp_parser):
            # Mark this op as handled; leftovers are reported at the end.
            del self._mapped_ops[mapped_func.torch_func.identifier.value]
            generated_funcs.append(mapped_func)

            # Fallthrough ops are registered later but get no body here.
            if mapped_func.make_fallthrough:
                continue

            # Close the previously opened namespace and open this op's,
            # as needed.
            ns = mapped_func.torch_op_namespace
            if open_ns and open_ns != ns:
                writer.pop_namespace()
                open_ns = None
            if ns != open_ns:
                writer.writeline()
                writer.push_namespace(ns)
                open_ns = ns

            writer.writeline()
            writer.writeline(f'// {mapped_func.torch_func.torch_schema}')

            self._write_function_signature(writer, mapped_func.cpp_func)
            if mapped_func.signature_only:
                writer.writeline(';')
            else:
                writer.writeline(' {')
                writer.push_indent()
                self._write_function_body(writer, mapped_func)
                writer.pop_indent()
                writer.writeline('}')

        if open_ns:
            writer.pop_namespace()

        self._write_function_registrations(writer, generated_funcs)
        self._write_file_postlude(writer)

        if len(self._mapped_ops) > 0:
            raise Exception('Torch operation(s) could not be parsed for mapping: ' +
                            ', '.join([f"'{o}'" for o in self._mapped_ops.keys()]))
Example #12
0
  def _write_function_body(
    self,
    writer: writer.SourceWriter,
    mapped_func: MappedOpFunction):
    onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func

    assert(len(cpp_func.parameters) > 0)

    # Debug Logging
    log_params = ', '.join([p.member.identifier.value for p \
      in cpp_func.parameters if p.member.identifier])
    writer.writeline(f'ORT_LOG_FN({log_params});')
    writer.writeline()

    if mapped_func.make_torch_fallback:
      return self._write_cpu_fall_back(writer, mapped_func)

    return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
    # --- Body continuation (method signature is above this chunk): emits the
    # --- generated C++ implementation that evaluates the mapped ONNX op chain
    # --- through the ORT invoker and returns the result(s) to Torch.

    # Only a *writable* return alias marks an in-place op; discard read-only
    # alias info so the regular (allocating) return path below is taken.
    if return_alias_info and not return_alias_info.is_writable:
      return_alias_info = None
    # Parameter aliased by the writable return, resolved in the input loop
    # below; stays None for non-in-place ops.
    in_place_param: ast.ParameterDecl = None

    # Eval the outer ONNX op to produce a topologically ordered list of ops
    ctx = ONNXOpEvalContext()
    onnx_op.eval(ctx)
    ctx.prepare_outputs()

    # Fetch the ORT invoker from an at::Tensor.device()
    # FIXME: find the first at::Tensor param anywhere in the signature
    # instead of simply the first parameter?
    first_param = cpp_func.parameters[0].member
    # Check if the first parameter is tensorlist and if yes it's size should be > 0
    if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
      writer.write('assert(')
      writer.write(first_param.identifier.value)
      writer.writeline('.size()>0);')

    # generate the type check
    # Emit a runtime guard that falls back to CPU when any ORT input's type is
    # unsupported by the kernel.  Skipped entirely when generating custom ops.
    need_type_check = False
    if not self._custom_ops:
      for onnx_op_index, onnx_op in enumerate(ctx.ops):
        for op_input in onnx_op.inputs:
          if not isinstance(op_input, Outputs):
            need_type_check = True
            break
    if need_type_check:
      writer.write('if (')
      i = 0
      for onnx_op_index, onnx_op in enumerate(ctx.ops):
        for idx, op_input in enumerate(onnx_op.inputs):
          if isinstance(op_input, Outputs):
            # Intermediate outputs are produced by ORT itself; no check needed.
            continue
          # Conditions after the first are joined with ' || '; the indent is
          # pushed once on the first condition and popped after the fallback.
          writer.writeline(' || ' if i > 0 else '')
          if i == 0:
            writer.push_indent()
          cpp_param = cpp_func.get_parameter(op_input)
          supported_types = ','.join([type for type in onnx_op.input_types[idx]])
          writer.write('!IsSupportedType(%s, {%s})' % (cpp_param.identifier.value, supported_types))
          i += 1
      writer.writeline(') {')
      self._write_cpu_fall_back(writer, mapped_func)
      writer.pop_indent()
      writer.writeline('}')      

    # The invoker is fetched from the first parameter's device, so that
    # parameter must be Tensor-like (its type name must contain 'Tensor').
    if not isinstance(
      first_param.parameter_type.desugar(),
      ast.ConcreteType) or 'Tensor' not in first_param.parameter_type.desugar().identifier_tokens[0].value:
      raise FunctionGenerationError(
        cpp_func,
        'First parameter must be an at::Tensor')

    writer.write('auto& invoker = GetORTInvoker(')
    writer.write(first_param.identifier.value)
    # For a TensorList, take the device of its first element (non-emptiness
    # was asserted in the generated code above).
    if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
      writer.write('[0]')
    writer.writeline('.device());')
    writer.writeline()

    # FIXME: warn if we have not consumed all torch parameters (either as
    # an ORT input or ORT attribute).

    # Perform kernel fission on the ATen op to yield a chain of ORT Invokes
    # e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y))
    
    # whether need type promotion
    # For ops listed in type_promotion_ops, compute a promoted scalar type at
    # runtime from the tensor/scalar inputs and cast mismatched inputs to it.
    need_type_promotion = False
    if mapped_func.mapped_op_name in self.type_promotion_ops:
      types_from_tensor = []
      types_from_scalar = []
      for onnx_op_index, onnx_op in enumerate(ctx.ops):
        for op_input in onnx_op.inputs:
          if isinstance(op_input, Outputs):
            continue
        # NOTE(review): this block is indented at the outer-loop level, so it
        # runs once per onnx_op with `op_input` bound to the inner loop's
        # *last* input (and the `continue` above has no effect on it).  It
        # looks like it was meant to live inside the inner loop so that every
        # input contributes to promotion -- confirm intent.
        cpp_param = cpp_func.get_parameter(op_input)
        if cpp_param:
          if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Tensor':
            types_from_tensor.append(f'{op_input}.scalar_type()')
          elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar':
            types_from_scalar.append(f'{op_input}.type()')
      if len(types_from_tensor) > 0 or len(types_from_scalar) > 0 :
        need_type_promotion = True
        writer.writeline('auto promoted_type = PromoteScalarTypesWithCategory({%s}, {%s});'
                         % (','.join(types_from_tensor), ','.join(types_from_scalar)))
        writer.writeline()

    for onnx_op_index, onnx_op in enumerate(ctx.ops):
      # Torch -> ORT inputs
      for op_input in onnx_op.inputs:
        if isinstance(op_input, Outputs):
          continue
        # See if this input is aliased as an in-place tensor
        cpp_param = cpp_func.get_parameter(op_input)
        if return_alias_info and cpp_param:
          for torch_p in cpp_param.torch_param:
            if self._get_alias_info(torch_p) == return_alias_info:
              in_place_param = cpp_param

        writer.write(f'auto ort_input_{op_input} = ')
        writer.writeline(f'create_ort_value(invoker, {op_input});')
        if need_type_promotion:
          # at::Scalar exposes .type(); at::Tensor exposes .scalar_type().
          # NOTE(review): assumes cpp_param is not None on this path — confirm
          # get_parameter cannot return None for promoted ops.
          type_func_str = 'type()' if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar' else 'scalar_type()'
          writer.write(f'if ({op_input}.{type_func_str} != *promoted_type)')
          writer.writeline('{')
          writer.push_indent()
          writer.writeline(f'ort_input_{op_input} = CastToType(invoker, ort_input_{op_input}, *promoted_type);')
          writer.pop_indent()
          writer.writeline('}')

      # Torch kwargs -> ORT attributes
      # Only attributes that carry a concrete value are emitted.
      attrs = { k:v for k, v in onnx_op.attributes.items() if v and v.value }
      if len(attrs) > 0:
        attrs_arg = 'attrs'
        writer.writeline()
        writer.writeline(f'NodeAttributes {attrs_arg}({len(attrs)});')

        for attr_name, attr in attrs.items():
          writer.write(f'{attrs_arg}["{attr_name}"] = ')
          writer.writeline('create_ort_attribute(')
          writer.push_indent()
          writer.write(f'"{attr_name}", {attr.value}')
          # ScalarType-typed attributes pass an explicit type argument; string
          # attributes use the create_ort_attribute default; any other type is
          # not yet mappable and aborts generation.
          if attr.type.startswith('at::ScalarType::'):
            writer.write(f', {attr.type}')
          elif attr.type != AttrType.STRING:
            raise FunctionGenerationError(
              cpp_func,
              f'Unsure how how to map ONNX op "{onnx_op.name}" attribute ' + 
              f'"{attr_name}" of type "{attr.type}" to a call to ' +
              'create_ort_attribute. Please teach generator.py.')
          writer.writeline(');')
          writer.pop_indent()
        attrs_arg = f'&{attrs_arg}'
      else:
        attrs_arg = 'nullptr'

      # Outputs vector
      writer.writeline()
      writer.write(f'std::vector<OrtValue> {onnx_op.outputs}')
      writer.writeline(f'({onnx_op.outputs.count});')

      # In-place op: pre-bind the single output slot to the first input's
      # OrtValue so the kernel writes into the aliased buffer.
      if in_place_param:
        assert(onnx_op.outputs.count == 1)
        # TODO: This assumes that the first output corresponds to the first input.
        # This may not work for more complicated ops.
        writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[0]};')

      # Perform the invocation
      # Only the first Invoke declares `status`; subsequent ones reassign it.
      writer.writeline()
      if onnx_op_index == 0:
        writer.write('auto ')
      writer.writeline(f'status = invoker.Invoke("{onnx_op.name}", {{')
      writer.push_indent()
      for op_input in onnx_op.inputs:
        if isinstance(op_input, Outputs):
          # Chained input: consume the single output of a previous op.
          if op_input.count != 1:
            raise FunctionGenerationError(
              cpp_func,
              'multiple outputs not supported')
          op_input = f'{op_input}[0]'
        else:
          op_input = f'ort_input_{op_input}'
        writer.writeline(f'std::move({op_input}),')
      writer.pop_indent()
      writer.write(f'}}, {onnx_op.outputs}, {attrs_arg}')
      if onnx_op.domain:
        writer.write(f', {onnx_op.domain}')
      writer.writeline(');')
      writer.writeline()

      # Assert invocation
      writer.writeline('if (!status.IsOK())')
      writer.push_indent()
      writer.writeline('throw std::runtime_error(')
      writer.push_indent()
      writer.writeline('"ORT return failure status:" + status.ErrorMessage());')
      writer.pop_indent()
      writer.pop_indent()
      writer.writeline()

      # We'll potentially return back to Torch from this op
      return_outputs = onnx_op.outputs

    # TODO: Pick the right "out" Torch parameter; do not assume the first one
    # TODO: Handle mutliple results
    # TODO: Assert return type

    if not return_alias_info:     
      # tensor options
      # Build TensorOptions from the first parameter (first element for a
      # TensorList), overriding dtype when type promotion occurred.
      writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}')
      if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList':
        writer.write('[0]')
      writer.write('.options()')
      if need_type_promotion:
        writer.write('.dtype(*promoted_type)')
      writer.writeline(';')

      # A std::vector return wraps the whole outputs vector; otherwise move
      # the single output back into an at::Tensor.
      writer.writeline('return aten_tensor_from_ort(')
      writer.push_indent()
      if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector':
        writer.writeline(f'{return_outputs},')
        writer.writeline('tensor_options);')
      else:
        writer.writeline(f'std::move({return_outputs}[0]),')
        writer.writeline('tensor_options);')
      writer.pop_indent()
      return

    # In-place path: a writable return alias must have matched a parameter
    # in the input loop above, else the schema is inconsistent.
    if not in_place_param:
      raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' +
        'has alias info on its return type but no associated parameter')

    writer.writeline(f'return {in_place_param.identifier.value};')