def gen_out_wrapper(g: ExternalBackendFunctionsGroup) -> Optional[str]:
            dispatcher_sig = DispatcherSignature.from_schema(
                g.out.native_function.func)
            name = dispatcher_sig.name()

            dispatcher_order_args = dispatcher.jit_arguments(
                g.out.native_function.func)
            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            print_args_str = ''.join(
                [f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

            func_name = f'AtenXlaTypeDefault::{name}'
            functional_result_name = f'{name}_tmp'
            return_names = cpp.return_names(g.out.native_function)
            if len(return_names) > 1:
                updates = '\n  '.join(
                    f'bridge::XlaUpdateTensors({{{ret_name}}}, {{std::get<{i}>({functional_result_name})}}, {{0}});'
                    for i, ret_name in enumerate(return_names))
                returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_names)})'
            else:
                ret_name = return_names[0]
                updates = f'bridge::XlaUpdateTensors({{{ret_name}}}, {{{functional_result_name}}}, {{0}});'
                returns = ret_name

            functional_sig = DispatcherSignature.from_schema(
                g.functional.native_function.func)

            return f"""\
예제 #2
0
    def gen_out_inplace_wrapper(self, f: NativeFunction, g: Optional[NativeFunctionsGroup]) -> Optional[str]:
        if g is None:
            return None
        k = f.func.kind()
        if k is SchemaKind.inplace:
            copy_op = 'at::_copy_from'
        elif k is SchemaKind.out:
            copy_op = 'at::_copy_from_and_resize'
        else:
            raise AssertionError("gen_out_inplace_wrapper called on a functional op")

        sig = self.wrapper_kernel_sig(f)
        name = sig.name()

        # See Note [External Backends Follow Dispatcher convention]
        jit_args = dispatcher.jit_arguments(f.func)
        tensors = [a for a in jit_args if isinstance(a, Argument) and a.type == BaseType(BaseTy.Tensor)]
        print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensors])

        func_res = f'{name}_tmp'
        return_names = cpp.return_names(f)
        if len(return_names) > 1:
            updates = '\n  '.join(
                f'{copy_op}(std::get<{i}>({func_res}), {ret_name});'
                for i, ret_name in enumerate(return_names))
            returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
        else:
            ret_name = return_names[0]
            updates = f'{copy_op}({func_res}, {ret_name});'
            returns = ret_name

        functional_sig = self.wrapper_kernel_sig(g.functional)

        return f"""\
        def gen_unstructured_external(
                f: ExternalBackendFunction) -> Optional[str]:
            if not requires_backend_wrapper(f):
                return None

            def get_device_param(args: List[Argument]) -> str:
                """Return the name of the argument whose device the CPU fallback's outputs are placed on.

                Candidates are searched in precedence order; the first matching argument wins:
                  1. a ``Tensor`` or ``Tensor?`` argument that is not written to (``not a.is_write``);
                  2. any argument whose type reports ``is_tensor_like()``;
                  3. a ``Device`` or ``Device?`` argument.

                Args:
                    args: the operator's arguments (callers pass them in dispatcher/jit order).

                Returns:
                    The ``name`` of the selected argument.

                Raises:
                    AssertionError: if no argument falls into any of the three categories.
                """
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                tensor_t = BaseType(BaseTy.Tensor)
                opt_tensor_t = OptionalType(BaseType(BaseTy.Tensor))
                device_t = BaseType(BaseTy.Device)
                opt_device_t = OptionalType(BaseType(BaseTy.Device))

                # Take the first match with next(..., None) rather than building a list and
                # testing it with any(), which checks the truthiness of the Argument objects
                # instead of whether a match exists.
                const_tensor_or_self = next(
                    (a for a in args
                     if (a.type == tensor_t or a.type == opt_tensor_t) and not a.is_write),
                    None)
                if const_tensor_or_self is not None:
                    return const_tensor_or_self.name

                tensor_like = next((a for a in args if a.type.is_tensor_like()), None)
                if tensor_like is not None:
                    return tensor_like.name

                device_like = next(
                    (a for a in args if a.type == device_t or a.type == opt_device_t),
                    None)
                if device_like is not None:
                    return device_like.name

                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            dispatcher_sig = DispatcherSignature.from_schema(
                f.native_function.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                if f.metadata is not None:
                    # xla has their own kernel: register it
                    namespace = 'AtenXlaType'
                else:
                    # xla doesn't have a kernel: register the cpu fallback (or codegen'd out kernel).
                    namespace = 'AtenXlaTypeDefault'
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&{namespace}::{name})"
                return f'  m.impl("{f.native_function.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Instead of generating a CPU fallback, the xla codegen generates out wrappers for a few hardcoded operators.
            # TODO: we should generate out wrappers for ALL valid out kernels; not just ones in xla's hardcoded list
            if f.native_function.func.kind() is SchemaKind.out and str(f.native_function.func.name.name) in _FN_OUT \
                    and isinstance(g, ExternalBackendFunctionsGroup):
                return gen_out_wrapper(g)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(
                f.native_function.func)

            # Map each argument to it's intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = bridge::XlaCreateTensorList({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = bridge::XlaCreateOptTensorList(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = bridge::XlaCreateTensorList(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.native_function.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]

            at_call_name = CppSignatureGroup.from_native_function(
                f.native_function,
                method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefuly written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.native_function.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                update_tensors = '\n  bridge::XlaUpdateTensors(xlatens_tensors, xlatens, xlatens_update_indices);'

            returns = ''
            if f.native_function.func.returns:
                ret_names = cpp.return_names(f.native_function,
                                             fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.native_function.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.native_function.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.native_function.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\
예제 #4
0
        def gen_unstructured_external(f: NativeFunction) -> Optional[str]:
            if not requires_backend_wrapper(f, self.backend_index):
                return None

            def get_device_param(args: List[Argument]) -> str:
                """Return the name of the argument whose device the CPU fallback's outputs are placed on.

                Candidates are searched in precedence order; the first matching argument wins:
                  1. a ``Tensor`` or ``Tensor?`` argument that is not written to (``not a.is_write``);
                  2. any argument whose type reports ``is_tensor_like()``;
                  3. a ``Device`` or ``Device?`` argument.

                Args:
                    args: the operator's arguments (callers pass them in dispatcher/jit order).

                Returns:
                    The ``name`` of the selected argument.

                Raises:
                    AssertionError: if no argument falls into any of the three categories.
                """
                # TODO: the XLA codegen has specific precedence rules when determining which tensor argument
                # to use as the device argument.
                # We should update this to be consistent with how we choose device guards.
                tensor_t = BaseType(BaseTy.Tensor)
                opt_tensor_t = OptionalType(BaseType(BaseTy.Tensor))
                device_t = BaseType(BaseTy.Device)
                opt_device_t = OptionalType(BaseType(BaseTy.Device))

                # Take the first match with next(..., None) rather than building a list and
                # testing it with any(), which checks the truthiness of the Argument objects
                # instead of whether a match exists.
                const_tensor_or_self = next(
                    (a for a in args
                     if (a.type == tensor_t or a.type == opt_tensor_t) and not a.is_write),
                    None)
                if const_tensor_or_self is not None:
                    return const_tensor_or_self.name

                tensor_like = next((a for a in args if a.type.is_tensor_like()), None)
                if tensor_like is not None:
                    return tensor_like.name

                device_like = next(
                    (a for a in args if a.type == device_t or a.type == opt_device_t),
                    None)
                if device_like is not None:
                    return device_like.name

                raise AssertionError(
                    "Need a tensor-like or device argument in order to determine the output device"
                )

            # XLA appears to have used the dispatcher convention to write their kernel signatures,
            # probably because they based their signatures off of our RegistrationDeclarations.h
            # See Note [External Backends Follow Dispatcher API]
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            name = dispatcher_sig.name()
            args = dispatcher_sig.arguments()

            if self.target is Target.NAMESPACED_DECLARATION:
                return f"  static {dispatcher_sig.decl()};"

            elif self.target is Target.REGISTRATION:
                # This codegen is only responsible for registering CPU fallback kernels
                # We also skip registrations if there is a functional backend kernel,
                # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py).
                if self.backend_index.get_kernel(f) is not None or \
                        (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                    return ''
                payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})"
                return f'  m.impl("{f.func.name}", {payload});\n'

            if self.target is not Target.NAMESPACED_DEFINITION:
                assert_never(self.target)

            # Everything below here is where we generate the CPU fallback.
            dispatcher_order_args = dispatcher.jit_arguments(f.func)

            # Map each argument to it's intermediate variable name in the fallback
            # We have to do it separately for TensorList/Optional<Tensor>/Tensor
            tensorlist_args: Dict[Argument, str] = {
                a: f'l_{a.name}'
                for a in dispatcher_order_args if isinstance(a.type, ListType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            }

            opt_tensors = [
                a for a in dispatcher_order_args
                if isinstance(a.type, OptionalType)
                and a.type.elem == BaseType(BaseTy.Tensor)
            ]
            opt_tensor_args: Dict[Argument, str] = {
                a: f'xlatens_opt[{i}]'
                for i, a in enumerate(opt_tensors)
            }

            tensors = [
                a for a in dispatcher_order_args
                if a.type == BaseType(BaseTy.Tensor)
            ]
            tensor_args: Dict[Argument, str] = {
                a: f'xlatens[{i}]'
                for i, a in enumerate(tensors)
            }
            annotated_tensor_indices: List[int] = [
                i for i, a in enumerate(tensors)
                if a.annotation is not None and a.annotation.is_write
            ]

            print_args_str = ''.join([
                f' << " {a.name}=" << {a.name}.toString()'
                for a in tensor_args.keys()
            ])

            tensorlist_intermediates_str = ''
            if len(tensorlist_args) > 0:
                tensorlist_intermediates_str = '\n'.join([
                    f'  auto {updated_name} = to_cpu({arg.name});'
                    for arg, updated_name in tensorlist_args.items()
                ])

            opt_tensor_intermediates_str = ''
            if len(opt_tensor_args) > 0:
                arg_str = ", ".join([a.name for a in opt_tensor_args.keys()])
                opt_tensor_intermediates_str = f'\n  std::vector<c10::optional<at::Tensor>> xlatens_opt_tensors = {{{arg_str}}};'
                opt_tensor_intermediates_str += '\n  auto xlatens_opt = to_cpu(xlatens_opt_tensors);'

            intermediates = ''
            if tensorlist_intermediates_str != '':
                intermediates += tensorlist_intermediates_str + '\n'
            intermediates += f"  std::vector<at::Tensor> xlatens_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};"
            intermediates += "\n  auto xlatens = to_cpu(xlatens_tensors);"
            if opt_tensor_intermediates_str != '':
                intermediates += opt_tensor_intermediates_str

            is_method = Variant.function not in f.variants
            func_name = f'AtenXlaTypeDefault::{name}'

            # Gather all of the updated variable names to call into the CPU operator.
            # Just use the original binding names for inputs where we didn't create explicit intermediate variables.
            updated_bindings: List[str] = [
                tensorlist_args.get(
                    a, opt_tensor_args.get(a, tensor_args.get(a, a.name)))
                for a in dispatcher_order_args
            ]

            at_call_name = CppSignatureGroup.from_native_function(
                f, method=is_method).most_faithful_signature().name()

            # Notice that we don't need to perform a translate: we're technically going from the dispatcher API
            # to the faithful C++ API, which are carefuly written to be exactly the same.
            cpu_result_name = 'x_result'
            if is_method:
                at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});'
            else:
                at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});'
            avoid_warning = ''
            if f.func.returns:
                at_call = f'auto&& {cpu_result_name} = {at_call}'
                avoid_warning = f'\n  static_cast<void>({cpu_result_name}); // Avoid warnings in case not used'

            collect_mutated_tensors = ''
            update_tensors = ''
            if len(annotated_tensor_indices) > 0:
                indices_str = ", ".join(
                    [str(i) for i in annotated_tensor_indices])
                collect_mutated_tensors = f'\n  std::vector<size_t> xlatens_update_indices = {{{indices_str}}};'
                # TODO: uncomment the resize line below. Taken out temporarily for testing
                update_tensors = '''
  for (int i : xlatens_update_indices) {
    // if (xlatens_tensors[i].sizes() != xlatens[i].sizes()) xlatens_tensors[i].resize_(xlatens[i].sizes());
    at::_copy_from_and_resize(xlatens[i], xlatens_tensors[i]);
  }
'''

            returns = ''
            if f.func.returns:
                ret_names = cpp.return_names(f, fallback_name=cpu_result_name)
                if len(ret_names) == 1:
                    returns = xla_tensor_creation_api(
                        ret_names[0],
                        f.func.returns[0],
                        get_device_param(dispatcher_order_args),
                        cpu_result_name=cpu_result_name)
                else:
                    return_args = [
                        xla_tensor_creation_api(
                            ret_names[i],
                            f.func.returns[i],
                            get_device_param(dispatcher_order_args),
                            cpu_result_name=f'std::get<{i}>({cpu_result_name})'
                        ) for i in range(len(f.func.returns))
                    ]
                    returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})'
            return_str = ''
            if returns != '':
                return_str = f'\n  return {returns};'

            return f"""\