Example #1
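# Maps one schema argument (Argument, SelfArgument, or TensorOptionsArguments) to its
# C++ Binding(s).  Under the legacy dispatcher modes TensorOptions is kept gathered as a
# single `const TensorOptions &` parameter; under UseC10Dispatcher.full it is scattered
# into optional dtype / layout / device / pin_memory bindings.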
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    if isinstance(a, Argument):
        return [Binding(
            ctype=argument_type(a, binds=a.name),
            name=a.name,
            default=cpp.default_expr(a.default, a.type) if a.default is not None else None,
            argument=a,
        )]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        if local.use_c10_dispatcher() in [UseC10Dispatcher.hacky_wrapper_for_legacy_signatures,
                                          UseC10Dispatcher.with_codegenerated_unboxing_wrapper]:
            # TODO: expunge this logic entirely
            default = None
            if all(x.default == "None" for x in a.all()):
                default = '{}'
            elif a.dtype.default == "long":
                default = 'at::kLong'  # TODO: this is wrong
            return [Binding(
                ctype=ConstRefCType(BaseCType('TensorOptions', 'options')),
                name='options',
                default=default,
                argument=a,
            )]
        else:
            assert local.use_c10_dispatcher() == UseC10Dispatcher.full
            return [
                Binding(
                    ctype=OptionalCType(BaseCType('ScalarType', 'dtype')),
                    name='dtype',
                    default='{}',
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('Layout', 'layout')),
                    name='layout',
                    default='{}',
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('Device', 'device')),
                    name='device',
                    default='{}',
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('bool', 'pin_memory')),
                    name='pin_memory',
                    default='{}',
                    argument=a,
                )]
    else:
        assert_never(a)
Example #2
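# Translates a schema Type into the structured CType used for C++ argument declarations:
# value types are handled by valuetype_type(), Tensor becomes a mutable or const
# reference, optional Tensor depends on the dispatcher style, and list types map to
# ArrayRef-like containers.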
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    # If it's a value type, do the value type translation
    r = valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                return MutRefCType(BaseCType('Tensor', binds))
            else:
                return ConstRefCType(BaseCType('Tensor', binds))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable:
                return MutRefCType(BaseCType(
                    'Tensor', binds))  # TODO: fix this discrepancy
            else:
                if local.use_c10_dispatcher().dispatcher_uses_new_style():
                    return ConstRefCType(
                        OptionalCType(BaseCType('Tensor', binds)))
                else:
                    return ConstRefCType(BaseCType('Tensor', binds))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return OptionalCType(elem)
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        # NB: CType throws away ArrayRef structure because it is not currently
        # relevant in translation.  When it becomes relevant, need to add back
        if str(t.elem) == 'int':
            return BaseCType("IntArrayRef", binds)
        elif str(t.elem) == 'Tensor':
            return BaseCType("TensorList", binds)
        elif str(t.elem) == 'Dimname':
            return BaseCType("DimnameList", binds)
        elif str(t.elem) == 'Tensor?':
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                return BaseCType("const c10::List<c10::optional<Tensor>> &",
                                 binds)
            else:
                return BaseCType("TensorList", binds)
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        # TODO: explicitly qualify namespace here
        return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #3
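# Builds the dispatcher argument sequence: computed directly from the schema under
# UseC10Dispatcher.full, otherwise re-wrapping the legacy_dispatcher arguments.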
def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]:
    if local.use_c10_dispatcher() is UseC10Dispatcher.full:
        return list(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)))
    else:
        return [
            DispatcherArgument(type=la.type, name=la.name, argument=la.argument)
            for la in legacy_dispatcher.arguments(func)
        ]
Example #4
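# Flattens the schema arguments into Bindings; note that the ordering flips (out
# arguments last vs. first) depending on whether UseC10Dispatcher.full is in effect.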
def arguments(func: FunctionSchema) -> List[Binding]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    if local.use_c10_dispatcher() is UseC10Dispatcher.full:
        args.extend(func.arguments.non_out)
        args.extend(func.arguments.out)
    else:
        args.extend(func.arguments.out)
        args.extend(func.arguments.non_out)
    return [r for arg in args for r in argument(arg)]
Example #5
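# BackendSelect codegen for operators that take TensorOptions: the generated wrapper
# computes the dispatch key from the options (combined with any tensor arguments),
# derives the autograd key, and redispatches via callWithDispatchKey.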
    def go(f: NativeFunction) -> Optional[str]:
        if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'):
            return None

        name = legacy_dispatcher.name(f.func)
        legacy_dispatcher_returns_type = legacy_dispatcher.returns_type(f.func.returns)
        legacy_dispatcher_args = legacy_dispatcher.arguments(f.func)

        if not any(isinstance(a.argument, TensorOptionsArguments) for a in legacy_dispatcher_args):
            return None

        legacy_dispatcher_tensor_args = [
            a for a in legacy_dispatcher_args
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]

        dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
        dispatcher_args = dispatcher.arguments(f.func)
        dispatcher_exprs = dispatcher.legacydispatcherarguments_exprs(legacy_dispatcher_args)

        if target is Target.DEFINITION:
            # See Note [Byte-for-byte compatibility]
            # I don't think there's actually a good reason to generate
            # these two cases differently
            if legacy_dispatcher_tensor_args:
                tensor_args = ', '.join(a.name for a in legacy_dispatcher_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = DispatchKeySet(options.computeDispatchKey()) | c10::detail::multi_dispatch_key_set({tensor_args});
  DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  DispatchKey _dk = c10::impl::dispatchTypeId(_dk_set, _dk_mask);"""
            else:
                compute_dk = "DispatchKey _dk = options.computeDispatchKey();"
            return f"""\
// aten::{f.func}
{legacy_dispatcher_returns_type} {name}({', '.join(a.str_with_default() for a in legacy_dispatcher_args)}) {{
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
    .typed<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>();
  {compute_dk}
  DispatchKey _autograd_dk = c10::getAutogradKeyFromBackend(_dk);
  // This trick allows calling Autograd backend kernel first and then backend kernel,
  // without adding another AutogradBackendSelect dispatch key.
  DispatchKey _current_dk = at::impl::variable_excluded_from_dispatch() ? _dk : _autograd_dk;
  return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif target is Target.REGISTRATION:
            if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                return f"""m.impl("aten::{f.func.name}",
          c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_returns_type} ({', '.join(a.type for a in dispatcher_args)})>(
            TORCH_FN({name})));"""
            else:
                return f"""m.impl_UNBOXED("aten::{f.func.name}", {name});"""
        elif target is Target.DECLARATION:
            raise AssertionError()
        else:
            assert_never(target)
Example #6
def arguments(func: FunctionSchema) -> Tuple[NativeArgument, ...]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    if local.use_c10_dispatcher() is UseC10Dispatcher.full:
        args.extend(func.arguments.non_out)
        args.extend(func.arguments.out)
    else:
        args.extend(func.arguments.out)
        args.extend(func.arguments.non_out)
    return tuple(i for arg in args for i in argument(arg))
Example #7
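# String-returning variant of the C++ argument type translation: it spells the C++
# type directly as text instead of building a structured CType.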
def argumenttype_type(t: Type, *, mutable: bool) -> str:
    # If it's a value type, do the value type translation
    r = valuetype_type(t)
    if r is not None:
        return r

    if str(t) == 'Tensor' and mutable and local.hack_const_mutable_self():
        return 'const Tensor &'

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                return 'Tensor &'
            else:
                return 'const Tensor &'
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable:
                return 'Tensor &'  # TODO: fix this discrepancy
            else:
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    return 'const c10::optional<Tensor>&'
                else:
                    return 'const Tensor &'
        elem = argumenttype_type(t.elem, mutable=mutable)
        return f"c10::optional<{elem}>"
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == 'int':
            return "IntArrayRef"
        elif str(t.elem) == 'Tensor':
            return "TensorList"
        elif str(t.elem) == 'Dimname':
            return "DimnameList"
        # TODO: do something reasonable about lists of optional tensors
        elif local.use_c10_dispatcher() is not UseC10Dispatcher.full and str(t.elem) == 'Tensor?':
            return "TensorList"
        elem = argumenttype_type(t.elem, mutable=mutable)
        # TODO: explicitly qualify namespace here
        return f"ArrayRef<{elem}>"
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
Example #8
def arguments(func: FunctionSchema) -> List[Binding]:
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        return [
            r
            for a in itertools.chain(func.arguments.positional,
                                     func.arguments.kwarg_only,
                                     func.arguments.out)
            for r in argument(a)
        ]
    else:
        return native.arguments(func)
Example #9
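# Renders a schema Type as the type string used when building PythonArgParser
# signatures; simple_type=True drops list sizes and, for the legacy dispatcher,
# the '?' on an optional Tensor.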
def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return 'Tensor'
        elif t.name == BaseTy.int:
            return 'int64_t'
        elif t.name == BaseTy.float:
            return 'double'
        elif t.name == BaseTy.str:
            return 'std::string'
        elif t.name in [
                BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar, BaseTy.ScalarType,
                BaseTy.Generator, BaseTy.Storage, BaseTy.Layout, BaseTy.Device,
                BaseTy.MemoryFormat, BaseTy.Dimname, BaseTy.Stream,
                BaseTy.ConstQuantizerPtr
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name

    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if (not simple_type
                    or local.use_c10_dispatcher().dispatcher_uses_new_style()):
                # Is it desired to keep '?' for simple_type with new style dispatcher?
                return 'Tensor?'
            else:
                return 'Tensor'
        elem = argument_type_str(t.elem, simple_type=simple_type)
        if elem == 'Layout':
            # TODO: fix this special case in PythonArgParser?
            return 'Layout'
        else:
            return f'{elem}?'

    elif isinstance(t, ListType):
        size = t.size if not simple_type else None
        if str(t.elem) == 'bool':
            assert t.size is not None
            return f'std::array<bool,{t.size}>'
        elif str(t.elem) == 'int':
            return f'IntArrayRef[{size}]' if size is not None else 'IntArrayRef'
        elif str(t.elem) == 'Tensor':
            return f'TensorList[{size}]' if size is not None else 'TensorList'
        elif str(t.elem) == 'Tensor?':
            if simple_type:
                return 'TensorList'
            else:
                # TODO: clone the old codegen behavior but does it make sense?
                return 'TensorList?'
        elif str(t.elem) == 'Dimname':
            return f'DimnameList[{size}]' if size is not None else 'DimnameList'
        elem = argument_type_str(t.elem, simple_type=simple_type)
        return f'ArrayRef<{elem}>'

    raise RuntimeError(f'unrecognized type {repr(t)}')
Example #10
def exprs(args: Sequence[DispatcherArgument]) -> Sequence[DispatcherExpr]:
    if local.use_c10_dispatcher() is UseC10Dispatcher.full:
        process_tensoroptions = ProcessTensoroptions.SCATTER
    else:
        process_tensoroptions = ProcessTensoroptions.PASS_THROUGH
    return cpparguments_exprs(
        [CppArgument(type=a.type, name=a.name, default=None, argument=a.argument)
         for a in args],
        process_tensoroptions=process_tensoroptions)
Example #11
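# Lowers one C++ argument into dispatcher expressions.  A TensorOptions argument is
# either scattered into dtype/layout/device/pin_memory expressions, gathered back into
# a single TensorOptions expression, or passed through unchanged, depending on
# process_tensoroptions; under UseC10Dispatcher.full, memory_format is additionally
# cross-checked against the TensorOptions.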
def cppargument_exprs(
    a: CppArgument,
    *,
    tensor_options: Optional[CppArgument],
    process_tensoroptions: ProcessTensoroptions = ProcessTensoroptions.PASS_THROUGH
) -> Sequence[DispatcherExpr]:
    if isinstance(a.argument, TensorOptionsArguments):
        if process_tensoroptions == ProcessTensoroptions.SCATTER:
            ta = a.argument
            return [
                DispatcherExpr(
                    type=argument_type(ta.dtype),
                    expr=f'optTypeMetaToScalarType({a.name}.dtype_opt())'),
                DispatcherExpr(type=argument_type(ta.layout),
                               expr=f'{a.name}.layout_opt()'),
                DispatcherExpr(type=argument_type(ta.device),
                               expr=f'{a.name}.device_opt()'),
                DispatcherExpr(
                    type=argument_type(ta.pin_memory),
                    expr=f'{a.name}.pinned_memory_opt()'),  # weird discrep
            ]
        elif process_tensoroptions == ProcessTensoroptions.GATHER:
            return [
                DispatcherExpr(
                    type='const TensorOptions &',
                    expr=
                    "TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)"
                )
            ]
        else:
            assert process_tensoroptions == ProcessTensoroptions.PASS_THROUGH
            return [DispatcherExpr(type='const TensorOptions &', expr=a.name)]
    elif isinstance(a.argument, ThisArgument):
        return [
            DispatcherExpr(type=argument_type(a.argument.argument),
                           expr=a.name)
        ]
    elif isinstance(a.argument, Argument):
        if (a.name == 'memory_format' and tensor_options is not None
                and local.use_c10_dispatcher() is UseC10Dispatcher.full):
            return [
                DispatcherExpr(
                    type=argument_type(a.argument),
                    expr=
                    f'c10::impl::check_tensor_options_and_extract_memory_format({tensor_options.name}, {a.name})'
                )
            ]
        else:
            return [
                DispatcherExpr(type=argument_type(a.argument), expr=a.name)
            ]
    else:
        assert_never(a.argument)
Example #12
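# Dispatcher type translation is pure delegation: new-style dispatchers reuse the C++
# rules, everything else reuses the native rules.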
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        # This is a faux amis.  If it makes sense in the future to add
        # more special cases here, or invert things so cpp.argument_type
        # calls this, or just completely inline the function, please do
        # it.
        return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
    else:
        # This is real sharing.  If you're modifying this path, ask
        # yourself why you are changing the native functions protocol
        # here and not in native.
        return native.argumenttype_type(t, mutable=mutable, binds=binds)
Example #13
def argumenttype_type(t: Type, *, mutable: bool) -> str:
    if local.use_c10_dispatcher() is UseC10Dispatcher.full:
        # This is a faux amis.  If it makes sense in the future to add
        # more special cases here, or invert things so cpp.argument_type
        # calls this, or just completely inline the function, please do
        # it.
        return cpp.argumenttype_type(t, mutable=mutable)
    else:
        # This is real sharing.  If you're modifying this path, ask
        # yourself why you are changing the legacy dispatcher protocol
        # here and not in legacy_dispatcher.
        return legacy_dispatcher.argumenttype_type(t, mutable=mutable)
Example #14
def arguments(func: FunctionSchema) -> Tuple[DispatcherArgument, ...]:
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        return tuple(
            map(
                argument,
                itertools.chain(func.out_arguments, func.arguments,
                                func.kwarg_only_arguments)))
    else:
        return tuple(
            DispatcherArgument(
                type=la.type, name=la.name, argument=la.argument)
            for la in native.arguments(func))
Example #15
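# Native-convention type translation: optional Tensor stays a plain `Tensor` reference
# under the hacky wrapper and becomes `optional<Tensor>` otherwise, optional tensor
# lists map to `c10::List<c10::optional<Tensor>>`, and all other types fall through
# to the C++ rules.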
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    if str(t) == 'Tensor?':
        tensor_type: CType = BaseCType('Tensor', binds)
        if (local.use_c10_dispatcher()
                is not UseC10Dispatcher.hacky_wrapper_for_legacy_signatures):
            tensor_type = OptionalCType(tensor_type)
        if mutable:
            return MutRefCType(tensor_type)
        else:
            return ConstRefCType(tensor_type)
    elif str(t) == 'Tensor?[]':
        return ConstRefCType(
            BaseCType("c10::List<c10::optional<Tensor>>", binds))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
Example #16
def argument(a: Argument) -> DispatcherArgument:
    if local.use_c10_dispatcher() is UseC10Dispatcher.full:
        return DispatcherArgument(
            type=argument_type(a),
            name=a.name,
            argument=a,
        )
    else:
        la = legacy_dispatcher.argument(a)
        return DispatcherArgument(
            type=la.type,
            name=la.name,
            argument=la.argument,
        )
Example #17
def argument(a: Argument) -> DispatcherArgument:
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        return DispatcherArgument(
            type=argument_type(a),
            name=a.name,
            argument=a,
        )
    else:
        la = native.argument(a)
        return DispatcherArgument(
            type=la.type,
            name=la.name,
            argument=la.argument,
        )
Example #18
def cpparguments_exprs(func: FunctionSchema, *, method: bool,
                       api_is_faithful: bool) -> Sequence[DispatcherExpr]:
    dispatcher_calling_convention_is_faithful = (
        local.use_c10_dispatcher().dispatcher_uses_new_style())
    arguments = cpp.group_arguments(
        func,
        method=method,
        faithful=dispatcher_calling_convention_is_faithful)

    if api_is_faithful:
        argument_packs = tuple(cpp.argument_faithful(a) for a in arguments)
    else:
        argument_packs = tuple(cpp.argument(a) for a in arguments)

    return _cpparguments_exprs(argument_packs)
Example #19
def argument(a: Argument) -> DispatcherArgument:
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        return DispatcherArgument(
            type=argument_type(a),
            name=a.name,
            argument=a,
        )
    else:
        la = native.argument(a)
        assert len(la) == 1, \
            "Operators using the legacy signature in the dispatcher don't scatter TensorOptions."
        return DispatcherArgument(
            type=la[0].type,
            name=la[0].name,
            argument=la[0].argument,
        )
Example #20
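# Generates the namespaced at:: function for operators with a function variant: either
# a declaration (plus a faithful overload, if any), or a definition that looks the
# operator up in the dispatcher and calls it.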
    def go(f: NativeFunction) -> Optional[str]:
        if f.manual_kernel_registration:
            return None
        if Variant.function not in f.variants:
            return None

        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_schema(f.func, method=False)

        if target is Target.DECLARATION:
            result = f"CAFFE2_API {sig_group.signature.decl()};\n"
            if sig_group.faithful_signature is not None:
                result += f"CAFFE2_API {sig_group.faithful_signature.decl()};\n"
            return result

        assert target is Target.DEFINITION

        def generate_defn(sig: CppSignature) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            dispatcher_exprs = dispatcher.cpparguments_exprs(
                sig.argument_packs())
            dispatcher_exprs_str = ', '.join(
                map(lambda a: a.expr, dispatcher_exprs))

            return f"""
// aten::{f.func}
{sig.defn()} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""

        result = generate_defn(sig_group.signature)
        if sig_group.faithful_signature is not None:
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                result += generate_defn(sig_group.faithful_signature)

        return result
Example #21
def cpparguments_exprs(func: FunctionSchema, *, method: bool,
                       api_is_faithful: bool) -> Sequence[DispatcherExpr]:
    dispatcher_is_faithful = local.use_c10_dispatcher().dispatcher_uses_new_style()

    arguments: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    if dispatcher_is_faithful:
        arguments.extend(func.arguments.non_out)
        arguments.extend(func.arguments.out)
    else:
        arguments.extend(func.arguments.out)
        arguments.extend(func.arguments.non_out)

    if api_is_faithful:
        argument_packs = tuple(
            cpp.argument_faithful(a, method=method) for a in arguments)
    else:
        argument_packs = tuple(
            cpp.argument(a, method=method) for a in arguments)

    return _cpparguments_exprs(argument_packs)
Example #22
def argument(
        a: Union[Argument, TensorOptionsArguments,
                 SelfArgument]) -> List[Binding]:
    # We could forward to native.argument but it is a bit suspect because
    # the grouping may not be set correctly
    assert local.use_c10_dispatcher().dispatcher_uses_new_style()

    if isinstance(a, Argument):
        return [
            Binding(
                ctype=argument_type(a, binds=a.name),
                name=a.name,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        return (argument(a.dtype) + argument(a.layout) +
                argument(a.device) + argument(a.pin_memory))
    else:
        assert_never(a)
Example #23
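# Per-backend codegen for unstructured kernels: emits the wrapper declaration, the
# definition (with an appropriate DeviceGuard), and the dispatcher registration,
# picking TORCH_FN, the hacky wrapper, or an unboxed-only registration based on the
# operator's use_c10_dispatcher setting.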
    def func(f: NativeFunction) -> Optional[str]:
        if dispatch is not None:
            if f.dispatch is None or dispatch not in f.dispatch:
                return None
        else:
            if f.dispatch is not None and target is not Target.REGISTRATION:
                return None

        if op_registration_whitelist is not None and \
                f"aten::{f.func.name.name}" not in op_registration_whitelist and target is Target.REGISTRATION:
            return None

        name = native.name(f.func)
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        args_str = ', '.join(map(str, args))

        if target is Target.DECLARATION:
            return f"{returns_type} {name}({args_str});"
        elif target is Target.DEFINITION:
            if f.dispatch is None:
                cpp_name = cpp.name(f.func)
                impl_name = f"at::native::{cpp_name}"
            else:
                assert dispatch is not None
                impl_name = f"at::native::{f.dispatch[dispatch]}"

            args_exprs_str = ', '.join(map(lambda a: a.name, args))

            return_kw = "    return "

            cuda_guard = ""
            if dispatch is None or 'CUDA' in dispatch or 'Vulkan' == dispatch:
                self_args = (a for a in f.func.arguments if a.name == "self")

                # There is precedence for which argument we use to do
                # device guard.  This describes the precedence order.
                candidate_args = itertools.chain(self_args,
                                                 f.func.out_arguments,
                                                 f.func.arguments)

                # Only tensor like arguments are eligible
                device_of = next(
                    (f'{a.name}'
                     for a in candidate_args if a.type.is_tensor_like()), None)

                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)

                # TODO: There is probably a simpler version of this that
                # works just as well.
                if (f.device_guard and (dispatch is None or dispatch == 'Vulkan')
                        and has_tensor_options):
                    cuda_guard = """\
    const DeviceGuard device_guard(options.device());
"""
                elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options:
                    cuda_guard = """\
    globalContext().lazyInitCUDA();
    const DeviceGuard device_guard(options.device());
"""
                elif f.device_guard and device_of is not None:
                    cuda_guard = f"""\
    const OptionalDeviceGuard device_guard(device_of({device_of}));
"""
                else:
                    cuda_guard = """\
    // DeviceGuard omitted
"""

            return f"""\
{returns_type} {name}({args_str}) {{
{cuda_guard}{return_kw}{impl_name}({args_exprs_str});
}}
"""

        elif target is Target.REGISTRATION:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if dispatch is None or dispatch == 'Math' or dispatch == 'DefaultBackend':
                type_name = f'TypeDefault::{name}'
            else:
                type_name = f'{dispatch}Type::{name}'

            # def registration only happens in TypeDefault
            def_registration = ""
            if dispatch is None:
                def_registration = f'm.def({cpp_string(str(f.func))});\n'

            impl_registration = ""
            if not def_only and not f.manual_kernel_registration and (
                    dispatch is not None or f.dispatch is None):
                # Figure out which signature the function is
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    payload = f"TORCH_FN({type_name})"
                elif (local.use_c10_dispatcher()
                      is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures):
                    payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \
                        f"{dispatcher_sig.type()}>(TORCH_FN({type_name}))"

                else:
                    assert (local.use_c10_dispatcher()
                            is UseC10Dispatcher.with_codegenerated_unboxing_wrapper)
                    payload = f"torch::CppFunction::makeUnboxedOnly(&{type_name})"

                # Annotate it with dispatch information if necessary
                #
                # NB: In the ordinary, TypeDerived code generation work flow, specification
                # of the backend is handled by the enclosing block, so the torch::dispatch
                # invocation here is strictly unnecessary.  However, in the fbcode mobile
                # only workflow using per-op registration, these registrations will get dumped
                # in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend.  So
                # the torch::dispatch specification here is important!  See
                # Note [Redundancy in registration code is OK] for how we handle redundant info.
                if dispatch is not None:
                    payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n"

                impl_registration = f'm.impl("{f.func.name}",\n{payload});\n'

            return f"{def_registration}{impl_registration}"
        else:
            assert_never(target)
Example #24
    def __call__(self, f: NativeFunction) -> Optional[str]:
        # for mypy type refinement; would be fixed by TODO on target
        assert self.target is not Target.DECLARATION

        if self.dispatch_key not in f.dispatch:
            return None

        op_name = f"aten::{f.func.name}"
        if (self.target is Target.REGISTRATION
                and not self.selector.is_operator_selected(op_name)):
            return None

        name = native.name(f.func)
        returns_type = native.returns_type(f.func.returns)
        args = native.arguments(f.func)
        args_str = ', '.join(map(str, args))

        if self.target is Target.DEFINITION:
            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

            args_exprs_str = ', '.join(a.name for a in args)

            return_kw = "    return "

            cuda_guard = ""
            if (is_generic_dispatch_key(self.dispatch_key)
                    or is_cuda_dispatch_key(self.dispatch_key)):
                self_args = (a for a in f.func.arguments if a.name == "self")

                # There is precedence for which argument we use to do
                # device guard.  This describes the precedence order.
                candidate_args = itertools.chain(self_args,
                                                 f.func.out_arguments,
                                                 f.func.arguments)

                # Only tensor like arguments are eligible
                device_of = next(
                    (f'{a.name}'
                     for a in candidate_args if a.type.is_tensor_like()), None)

                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)

                if local.use_c10_dispatcher() == UseC10Dispatcher.full:
                    cuda_guard_from_tensor_options = """\
    const DeviceGuard device_guard(device_or_default(device));
"""
                else:
                    assert local.use_c10_dispatcher() in [
                        UseC10Dispatcher.with_codegenerated_unboxing_wrapper,
                        UseC10Dispatcher.hacky_wrapper_for_legacy_signatures
                    ]
                    cuda_guard_from_tensor_options = """\
    const DeviceGuard device_guard(options.device());
"""

                # TODO: There is probably a simpler version of this that
                # works just as well.
                if f.device_guard and is_generic_dispatch_key(
                        self.dispatch_key) and has_tensor_options:
                    cuda_guard = cuda_guard_from_tensor_options
                elif f.device_guard and is_cuda_dispatch_key(
                        self.dispatch_key) and has_tensor_options:
                    cuda_guard = f"""\
    globalContext().lazyInitCUDA();
    {cuda_guard_from_tensor_options}
"""
                elif f.device_guard and device_of is not None:
                    cuda_guard = f"""\
    const OptionalDeviceGuard device_guard(device_of({device_of}));
"""
                else:
                    cuda_guard = """\
    // DeviceGuard omitted
"""

            return f"""\
{returns_type} {name}({args_str}) {{
{cuda_guard}{return_kw}{impl_name}({args_exprs_str});
}}
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                dispatcher_sig = DispatcherSignature.from_schema(f.func)

                # Figure out which signature the function is
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    payload = f"TORCH_FN({name})"
                elif (local.use_c10_dispatcher()
                      is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures):
                    payload = "c10::impl::hacky_wrapper_for_legacy_signatures<" \
                        f"{dispatcher_sig.type()}>(TORCH_FN({name}))"

                else:
                    assert (local.use_c10_dispatcher()
                            is UseC10Dispatcher.with_codegenerated_unboxing_wrapper)
                    payload = f"torch::CppFunction::makeUnboxedOnly(&{name})"

                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
Example #25
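# Structured-kernel codegen: the generated wrapper constructs the structured op class,
# runs op.meta(), runs op.impl() (unless targeting the Meta key), and returns the
# output according to the schema kind (functional / inplace / out).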
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        assert not f.manual_kernel_registration

        if (self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)):
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.dispatch_key == DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            else:
                class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
                parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"

            if k is SchemaKind.functional:
                assert len(f.func.returns) == 1, "multi-return not supported yet"
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
                sig_body.append(
                    f"{class_name} op({f.func.arguments.out[0].name});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(e.expr for e in translate(
                context, structured.meta_arguments(self.g), method=False))
            sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            # TODO: handle multi-return
            context.append(
                Expr(
                    expr="op.outputs_[0]",
                    type=structured.out_arguments(self.g)[0].ctype,
                ))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(e.expr for e in translate(
                    context, structured.impl_arguments(self.g), method=False))
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            if k is SchemaKind.functional:
                assert len(f.func.returns) == 1, "multi-return not supported yet"
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
                ret_expr = f.func.arguments.out[0].name
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            assert local.use_c10_dispatcher() is UseC10Dispatcher.full
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
Example #26
    def gen_one(self, f: NativeFunction) -> Optional[str]:
        assert not f.manual_kernel_registration

        if (self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutograd implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if (self.dispatch_key == DispatchKey.CompositeExplicitAutograd
                and f.func.kind() is SchemaKind.out):
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        # Signature of the wrapper function we'll register to the dispatcher
        sig = NativeSignature(f.func, prefix="wrapper_")

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:

            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: List[Union[Binding, Expr]] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            elif self.dispatch_key is DispatchKey.CompositeExplicitAutograd:
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::{meta.name(self.g)}"
            else:
                class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
                parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"

            if k is SchemaKind.functional:
                assert len(f.func.returns) == 1, "multi-return not supported yet"
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
                sig_body.append(
                    f"{class_name} op({f.func.arguments.out[0].name});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ', '.join(e.expr for e in translate(
                context, structured.meta_arguments(self.g), method=False))
            sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            # TODO: handle multi-return
            assert ConstRefCType(BaseCType("Tensor", structured.out_arguments(self.g)[0].ctype.name)) == \
                structured.out_arguments(self.g)[0].ctype
            context.append(
                Expr(
                    expr="op.outputs_[0]",
                    # TODO: Stop hardcoding that the output type is a Tensor.  Note
                    # that for the codegen here this is fine because outputs_ is
                    # hardcoded to be tensor already
                    type=MutRefCType(
                        BaseCType(
                            "Tensor",
                            structured.out_arguments(self.g)[0].ctype.name)),
                ))

            # With the expanded context, do the impl call (if not a meta
            # function)
            if self.dispatch_key == DispatchKey.CompositeExplicitAutograd:
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out,
                    method=False,
                    fallback_binding=f.manual_cpp_binding)
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ', '.join(e.expr for e in translate(
                    context, out_sig.arguments(), method=False))
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.dispatch_key != DispatchKey.Meta:
                impl_exprs = ', '.join(e.expr for e in translate(
                    context, structured.impl_arguments(self.g), method=False))
                sig_body.append(f"op.impl({impl_exprs});")

            # Destructively return the final tensors
            if k is SchemaKind.functional:
                assert len(f.func.returns) == 1, "multi-return not supported yet"
                ret_expr = "std::move(op.outputs_[0])"  # small optimization
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                assert len(f.func.arguments.out) == 1, "multi-out structured not supported yet"
                ret_expr = f.func.arguments.out[0].name
            sig_body.append(f"return {ret_expr};")

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            assert local.use_c10_dispatcher() is UseC10Dispatcher.full
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
Example #27
    def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
        if self.dispatch_key not in f.dispatch:
            return None
        if f.manual_kernel_registration:
            return None

        if (self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)):
            return None

        sig = NativeSignature(f.func, prefix='wrapper_')

        name = sig.name()
        returns_type = sig.returns_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result
        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result
        elif self.target is Target.ANONYMOUS_DEFINITION:
            impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

            args_exprs_str = ', '.join(a.name for a in args)

            return_kw = "    return "

            cuda_guard = ""
            if (is_generic_dispatch_key(self.dispatch_key)
                    or is_cuda_dispatch_key(self.dispatch_key)):
                self_arg = ([f.func.arguments.self_arg.argument]
                            if f.func.arguments.self_arg is not None else [])

                # There is precedence for which argument we use to do
                # device guard.  This describes the precedence order.
                candidate_args = itertools.chain(
                    self_arg, f.func.arguments.out,
                    f.func.arguments.flat_positional)

                # Only tensor like arguments are eligible
                device_of = next(
                    (f'{a.name}'
                     for a in candidate_args if a.type.is_tensor_like()), None)

                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments)
                    for a in args)

                if local.use_c10_dispatcher() == UseC10Dispatcher.full:
                    cuda_guard_from_tensor_options = """\
    const DeviceGuard device_guard(device_or_default(device));
"""
                else:
                    assert (local.use_c10_dispatcher()
                            is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures)
                    cuda_guard_from_tensor_options = """\
    const DeviceGuard device_guard(options.device());
"""

                # TODO: There is probably a simpler version of this that
                # works just as well.
                if f.device_guard and is_generic_dispatch_key(
                        self.dispatch_key) and has_tensor_options:
                    cuda_guard = cuda_guard_from_tensor_options
                elif f.device_guard and is_cuda_dispatch_key(
                        self.dispatch_key) and has_tensor_options:
                    cuda_guard = f"""\
    globalContext().lazyInitCUDA();
    {cuda_guard_from_tensor_options}
"""
                elif f.device_guard and device_of is not None:
                    cuda_guard = f"""\
    const OptionalDeviceGuard device_guard(device_of({device_of}));
"""
                else:
                    cuda_guard = """\
    // DeviceGuard omitted
"""

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
{cuda_guard}{return_kw}{impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                dispatcher_sig = DispatcherSignature.from_schema(f.func)

                # Figure out which signature the function is
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    payload = f"TORCH_FN({name})"
                else:
                    assert (local.use_c10_dispatcher()
                            is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures)
                    payload = f"""
c10::impl::hacky_wrapper_for_legacy_signatures<
    {dispatcher_sig.type()},
    {len(f.func.arguments.out)}
>(TORCH_FN({name}))
"""

                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
Example #28
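# Chooses the PythonArgParser unpacking method for a schema type; defaults are only
# honored for a handful of types, and optional / list types get dedicated unpack
# methods.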
def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
    if has_default and str(t) not in ('ScalarType', 'Device', 'Layout?'):
        raise RuntimeError(
            f'type \'{t}\' does not support unpacking with default')

    if isinstance(t, BaseType):
        if t.name in [
                BaseTy.Tensor, BaseTy.Stream, BaseTy.Storage, BaseTy.Scalar,
                BaseTy.Dimname
        ]:
            # These unpack methods line up with their schema names
            return t.name.name.lower()
        elif t.name == BaseTy.ScalarType:
            return 'scalartypeWithDefault' if has_default else 'scalartype'
        elif t.name == BaseTy.Device:
            return 'deviceWithDefault' if has_default else 'device'
        elif t.name == BaseTy.int:
            return 'toInt64'
        elif t.name == BaseTy.bool:
            return 'toBool'
        elif t.name == BaseTy.float:
            return 'toDouble'
        elif t.name == BaseTy.str:
            return 'string'

    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                return 'optionalTensor'
            else:
                return 'tensor'

        elif isinstance(t.elem, BaseType):
            if t.elem.name in [
                    BaseTy.ScalarType, BaseTy.Scalar, BaseTy.int, BaseTy.bool,
                    BaseTy.float, BaseTy.str
            ]:
                # Regular cases: append 'Optional' to elem's unpacking method
                return arg_parser_unpack_method(t.elem, False) + 'Optional'
            elif t.elem.name == BaseTy.MemoryFormat:
                return 'memoryformatOptional'
            elif t.elem.name == BaseTy.Generator:
                return 'generator'
            elif t.elem.name == BaseTy.Layout:
                return 'layoutWithDefault' if has_default else 'layoutOptional'

        elif isinstance(t.elem, ListType):
            if str(t.elem.elem) == 'int':
                # accept definite size
                return 'intlistOptional'
            elif str(t.elem) == 'float[]':
                return 'doublelistOptional'
            elif str(t.elem) == 'Dimname[]':
                return 'toDimnameListOptional'

    elif isinstance(t, ListType):
        if str(t.elem) == 'Tensor' or str(t.elem) == 'Tensor?':
            # accept and use definite size
            if t.size is not None:
                return f'tensorlist_n<{t.size}>'
            else:
                return 'tensorlist'
        elif str(t.elem) == 'Dimname':
            # accept definite size
            return 'dimnamelist'
        elif str(t.elem) == 'int':
            # accept definite size
            return 'intlist'
        elif str(t) == 'float[]':
            return 'doublelist'

    raise RuntimeError(f'type \'{t}\' is not supported by PythonArgParser')
Example #29
    def go(f: NativeFunction) -> Optional[str]:
        if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'):
            return None

        name = native.name(f.func)
        native_sig = NativeSignature.from_schema(f.func)

        if not any(
                isinstance(a.argument, TensorOptionsArguments)
                for a in native_sig.arguments()):
            return None

        native_tensor_args = [
            a for a in native_sig.arguments()
            if isinstance(a.argument, Argument)
            and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: Union[NativeSignature, DispatcherSignature]
        if local.use_c10_dispatcher().dispatcher_uses_new_style():
            sig = dispatcher_sig
            dispatcher_exprs = dispatcher_sig.exprs()
            dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
        else:
            sig = native_sig
            dispatcher_exprs = native_sig.dispatcher_exprs()
            dispatch_key = "options.computeDispatchKey()"

        if target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls dispatchTypeId(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                tensor_args = ', '.join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
  DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  DispatchKey _dk = c10::impl::dispatchTypeId(_dk_set, _dk_mask);"""
            else:
                compute_dk = f"DispatchKey _dk = {dispatch_key};"
            return f"""\
// aten::{f.func}
{sig.defn(name)} {{
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
    .typed<{dispatcher_sig.type()}>();
  {compute_dk}
  DispatchKey _autograd_dk = c10::getAutogradKeyFromBackend(_dk);
  // This trick allows calling Autograd backend kernel first and then backend kernel,
  // without adding another AutogradBackendSelect dispatch key.
  DispatchKey _current_dk = at::impl::variable_excluded_from_dispatch() ? _dk : _autograd_dk;
  return op.callWithDispatchKey(_current_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif target is Target.REGISTRATION:
            if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
            elif (local.use_c10_dispatcher()
                  is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures):
                return f"""m.impl("aten::{f.func.name}",
          c10::impl::hacky_wrapper_for_legacy_signatures<{dispatcher_sig.type()}>(
            TORCH_FN({name})));"""
            else:
                assert (local.use_c10_dispatcher()
                        is UseC10Dispatcher.with_codegenerated_unboxing_wrapper)
                return f"""m.impl_UNBOXED("aten::{f.func.name}", {name});"""
        elif target is Target.DECLARATION:
            raise AssertionError()
        else:
            assert_never(target)
Example #30
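# Like Example #1, but with an is_out flag: defaults are suppressed for out= variants
# under UseC10Dispatcher.full, since there is no sensible default for them.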
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *,
             is_out: bool) -> List[Binding]:
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out or local.use_c10_dispatcher() is not UseC10Dispatcher.full
    if isinstance(a, Argument):
        default: Optional[str] = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type)
        return [
            Binding(
                ctype=argument_type(a, binds=a.name),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out)
    elif isinstance(a, TensorOptionsArguments):
        if local.use_c10_dispatcher() in [
                UseC10Dispatcher.hacky_wrapper_for_legacy_signatures,
                UseC10Dispatcher.with_codegenerated_unboxing_wrapper
        ]:
            # TODO: expunge this logic entirely
            default = None
            if should_default:
                if all(x.default == "None" for x in a.all()):
                    default = '{}'
                elif a.dtype.default == "long":
                    default = 'at::kLong'  # TODO: this is wrong
            return [
                Binding(
                    ctype=ConstRefCType(BaseCType('TensorOptions', 'options')),
                    name='options',
                    default=default,
                    argument=a,
                )
            ]
        else:
            assert local.use_c10_dispatcher() == UseC10Dispatcher.full
            default = None
            if should_default:
                default = '{}'
            # TODO: Not sure why the arguments assigned here are for
            # TensorOptionsArguments and not the constituent pieces.  It seems
            # to matter
            return [
                Binding(
                    ctype=OptionalCType(BaseCType('ScalarType', 'dtype')),
                    name='dtype',
                    default=default,
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('Layout', 'layout')),
                    name='layout',
                    default=default,
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('Device', 'device')),
                    name='device',
                    default=default,
                    argument=a,
                ),
                Binding(
                    ctype=OptionalCType(BaseCType('bool', 'pin_memory')),
                    name='pin_memory',
                    default=default,
                    argument=a,
                )
            ]
    else:
        assert_never(a)