コード例 #1
0
def gen_structured(g: NativeFunctionsGroup) -> List[str]:
    """Emit the C++ native declarations for a structured function group.

    Two passes are made, each de-duplicating kernel names on its own:

    1. For every dispatch entry of ``g.out`` whose key IS a structured
       dispatch key, emit a ``structured_<kernel>`` struct deriving from
       ``at::meta::<meta name of g>`` and declaring ``impl`` with the
       structured impl arguments.
    2. For every function of the group, emit a ``TORCH_API`` prototype for
       each dispatch entry whose key is NOT a structured dispatch key.

    Returns the declarations in emission order (all structs first).
    """
    decls: List[str] = []

    # Pass 1: structured kernels -- only the out= variant has dispatch here.
    meta_name = meta.name(g)
    impl_args = structured.impl_arguments(g)
    struct_kernels: Set[Any] = set()
    for key, kernel in g.out.dispatch.items():
        if kernel in struct_kernels or not is_structured_dispatch_key(key):
            continue
        struct_kernels.add(kernel)
        impl_params = ', '.join(arg.decl() for arg in impl_args)
        decls.append(
            f"struct TORCH_API structured_{kernel} : public at::meta::{meta_name} {{\n"
            f"void impl({impl_params});\n"
            "};\n"
        )

    # Pass 2: plain (non-structured) kernels of every function in the group.
    plain_kernels: Set[Any] = set()
    for fn in g.functions():
        ret = native.returns_type(fn.func.returns)
        fn_args = native.arguments(fn.func)
        for key, kernel in fn.dispatch.items():
            if kernel in plain_kernels or is_structured_dispatch_key(key):
                continue
            plain_kernels.add(kernel)
            params = ', '.join(arg.decl() for arg in fn_args)
            decls.append(f"TORCH_API {ret} {kernel}({params});")

    return decls
コード例 #2
0
def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str:
    """Return the C++ prototype of the meta function for structured group ``g``.

    The prototype is built from ``g.signature()`` while
    ``native_function_manager(g.out)`` is active and has the shape
    ``CAFFE2_API <return type> <meta name>(<args>);``, where each argument is
    rendered with ``str``.
    """
    with native_function_manager(g.out):
        schema_sig = g.signature()
        fn_name = meta.name(schema_sig)
        ret_type = meta.returns_type(schema_sig.returns)
        rendered_args = [str(arg) for arg in meta.arguments(schema_sig)]
        joined = ', '.join(rendered_args)
        return f"CAFFE2_API {ret_type} {fn_name}({joined});"
コード例 #3
0
def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str:
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ', '.join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        return f"""\
コード例 #4
0
def compute_native_function_declaration(
        g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
    """Return the ``TORCH_API`` C++ declarations for the native kernels of ``g``.

    * For a single ``NativeFunction``: one function prototype per distinct
      kernel name in its dispatch table, skipping names containing
      ``"legacy::"``.
    * For a ``StructuredNativeFunctions`` group: one ``structured_<kernel>``
      struct (deriving from ``at::meta::<meta name>`` and declaring ``impl``)
      per structured dispatch entry of ``g.out``, followed by one prototype
      per non-structured dispatch entry of every function in the group.

    Kernel names are de-duplicated within each category, because the same
    name can be registered under several dispatch keys.
    """
    decls: List[str] = []

    if not isinstance(g, StructuredNativeFunctions):
        # A lone NativeFunction: declare each kernel name exactly once.
        plain_seen: Set[Any] = set()
        for kernel in list(g.dispatch.values()):
            if kernel in plain_seen or "legacy::" in kernel:
                continue
            plain_seen.add(kernel)
            ret = native.returns_type(g.func.returns)
            params = ', '.join(arg.decl() for arg in native.arguments(g.func))
            decls.append(f"TORCH_API {ret} {kernel}({params});")
        return decls

    # Structured group: only the out= variant carries structured dispatch.
    meta_name = meta.name(g)
    struct_seen: Set[Any] = set()
    impl_args = structured.impl_arguments(g)
    for key, kernel in g.out.dispatch.items():
        if kernel in struct_seen or not is_structured_dispatch_key(key):
            continue
        struct_seen.add(kernel)
        impl_params = ', '.join(arg.decl() for arg in impl_args)
        decls.append(
            f"struct TORCH_API structured_{kernel} : public at::meta::{meta_name} {{\n"
            f"    void impl({impl_params});\n"
            "};\n"
        )

    # Remaining kernels registered under non-structured dispatch keys.
    fn_seen: Set[Any] = set()
    for fn in g.functions():
        ret = native.returns_type(fn.func.returns)
        fn_args = native.arguments(fn.func)
        for key, kernel in fn.dispatch.items():
            if kernel in fn_seen or is_structured_dispatch_key(key):
                continue
            fn_seen.add(kernel)
            params = ', '.join(arg.decl() for arg in fn_args)
            decls.append(f"TORCH_API {ret} {kernel}({params});")

    return decls
コード例 #5
0
ファイル: native_functions.py プロジェクト: sbrodehl/pytorch
def gen_structured(g: NativeFunctionsGroup,
                   backend_index: BackendIndex) -> List[str]:
    """Declare the ``structured_<kernel>`` struct for ``g`` on one backend.

    Returns an empty list when ``backend_index`` has no kernel for ``g``.
    Otherwise returns a one-element list holding a C++ struct that derives
    from ``at::meta::<meta name of g>`` and declares ``impl`` with the
    structured impl arguments. The ``TORCH_API`` export prefix is omitted
    when ``backend_index.external`` is truthy.
    """
    meta_name = meta.name(g)
    impl_args = structured.impl_arguments(g)
    kernel_meta = backend_index.get_kernel(g)
    if kernel_meta is None:
        return []

    if backend_index.external:
        export = ''
    else:
        export = 'TORCH_API '

    impl_params = ', '.join(arg.decl() for arg in impl_args)
    struct_decl = (
        f"struct {export}structured_{kernel_meta.kernel} : public at::meta::{meta_name} {{\n"
        f"void impl({impl_params});\n"
        "};\n"
    )
    return [struct_decl]
コード例 #6
0
        def gen_one(f: NativeFunction) -> Optional[str]:
            assert self.target is not Target.DECLARATION

            # TODO: put this into StructuredNativeFunctions itself
            functional_func = g.out.func.signature()
            functional_sig = DispatcherSignature.from_schema(functional_func)
            meta_name = meta.name(functional_func)

            # This is a little abusive; this assumes that the functionalization
            # transformation ALWAYS refers to valid arguments in the original
            # signature
            functional_exprs = ', '.join(e.expr
                                         for e in functional_sig.exprs())

            op_name = f"aten::{f.func.name}"
            if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(
                    op_name):
                return None

            k = f.func.kind()
            sig = NativeSignature.from_schema(f.func)

            if self.target is Target.DEFINITION:
                out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"

                # TODO: work a little harder to generate fresh names for 'result'
                # TODO: less praying that I picked the right argument name for 'self'

                if k is SchemaKind.functional:
                    out_expr = "result"
                    prologue = "auto result = tensor_from_meta(meta_result);"
                elif k is SchemaKind.inplace:
                    out_expr = "self"
                    prologue = "// TODO: consistency check assert"
                elif k is SchemaKind.out:
                    # TODO: generalize this for multi-out
                    assert len(f.func.out_arguments
                               ) == 1, "multi-out structured not supported yet"
                    # TODO: properly get the expression as it was brought into
                    # scope by sig
                    out_expr = f.func.out_arguments[0].name
                    prologue = f"""
// TODO: add a consistency check for meta_result
{out_expr}.resize_(meta_result.sizes);
"""

                device_guard = ""

                if is_generic_dispatch_key(
                        self.dispatch_key) or is_cuda_dispatch_key(
                            self.dispatch_key):
                    # TODO: avoid copypasting the computation of self_args,
                    # candidate_args and device_of
                    self_args = (a for a in f.func.arguments
                                 if a.name == "self")
                    candidate_args = itertools.chain(self_args,
                                                     f.func.out_arguments,
                                                     f.func.arguments)
                    device_of = next(
                        (f'{a.name}'
                         for a in candidate_args if a.type.is_tensor_like()),
                        None)

                    device_guard = ''
                    if f.device_guard and device_of is not None:
                        # TODO: Use OptionalCUDAGuard when possible
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
                    # TODO: figure out what to do about structured kernels and
                    # factory functions

                # For an overview of what this template code looks like, see
                # https://github.com/pytorch/rfcs/pull/9
                return f"""\
{sig.defn()} {{
    {device_guard}
    auto meta_result = meta::{meta_name}({functional_exprs});
    {prologue}
    {out_impl_name}({out_expr}, {functional_exprs});
    return {out_expr};
}}
"""

            elif self.target is Target.REGISTRATION:
                if local.use_c10_dispatcher() is UseC10Dispatcher.full:
                    payload = f'TORCH_FN({sig.name()})'
                else:
                    payload = f'torch::CppFunction::makeUnboxedOnly({sig.name()})'
                return f'm.impl("{f.func.name}", {payload});'
            else:
                assert_never(self.target)