Example #1
def method_registration(f: NativeFunction) -> Optional[str]:
    if cpp.name(f.func) in MANUAL_TRACER:
        return None

    return WRAPPER_REGISTRATION.substitute(
        name=f.func.name,
        type_wrapper_name=type_wrapper_name(f),
        class_type='TraceType',
    )
Example #2
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
    operator_selector: SelectiveBuilder,
) -> None:
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.
    """
    fns = list(
        sorted(filter(
            operator_selector.is_native_function_selected_for_training,
            parse_native_yaml(native_yaml_path)),
               key=lambda f: cpp.name(f.func)))
    fns_with_infos = match_differentiability_info(fns, differentiability_infos)

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.h',
                            'VariableType.h')

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [
        [] for _ in range(num_shards)
    ]

    # functions are assigned arbitrarily but stably to a file based on hash
    for fn in fns_with_infos:
        x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
        shards[x].append(fn)

    for i, shard in enumerate(shards):
        gen_variable_type_shard(fm, shard, 'VariableType.cpp',
                                f'VariableType_{i}.cpp')

    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.cpp',
                            'VariableTypeEverything.cpp')
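
The sharding step above assigns each function to one of num_shards output files via a stable checksum of its C++ name (sum of character ordinals modulo the shard count), so a given operator always lands in the same generated file across runs. Below is a minimal, self-contained sketch of that assignment using plain strings in place of NativeFunction objects; the operator names are made up for illustration.

from typing import List

def assign_to_shards(names: List[str], num_shards: int = 5) -> List[List[str]]:
    # Mirrors the hashing scheme used above: the same name always maps to
    # the same shard index, independent of the order the names arrive in.
    shards: List[List[str]] = [[] for _ in range(num_shards)]
    for name in sorted(names):
        idx = sum(ord(c) for c in name) % num_shards
        shards[idx].append(name)
    return shards

# Hypothetical operator names, for illustration only.
print(assign_to_shards(['add', 'mul', 'conv2d', 'relu_', 'as_strided']))
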
Example #3
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_native_function(f, method=True, fallback_binding=f.manual_cpp_binding)

        if self.target is Target.DECLARATION:
            result = f"{sig_group.signature.decl()} const;\n"
            if sig_group.faithful_signature is not None:
                result += f"{sig_group.faithful_signature.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True)
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

            static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    {static_dispatch_block}
}}
"""

        result = generate_defn(faithful=False)
        if sig_group.faithful_signature is not None:
            result += generate_defn(faithful=True)

        return result
Example #4
    def callImpl(self, f: NativeFunction) -> str:
        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

        if self.target is Target.DECLARATION:
            sig_str = sig_group.signature.decl(is_redispatching_fn=self.is_redispatching_fn)
            result = f"TORCH_API {sig_str};\n"
            if sig_group.faithful_signature is not None:
                sig_str = sig_group.faithful_signature.decl(is_redispatching_fn=self.is_redispatching_fn)
                result += f"TORCH_API {sig_str};\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful and sig_group.faithful_signature is not None:
                sig = sig_group.faithful_signature
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
            if self.is_redispatching_fn:
                dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
                dispatcher_call = 'redispatch'
            else:
                dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)
                dispatcher_call = 'call'

            static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
            if static_dispatch_block is None:
                return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
            else:
                return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    {static_dispatch_block}
}}
"""
        result = generate_defn(sig_group.faithful_signature is None)
        if sig_group.faithful_signature is not None:
            result += generate_defn(True)

        return result
Example #5
def gen_variable_type_shard(
    fm: FileManager,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_name: str,
    output_name: str,
) -> None:
    type_declarations: List[str] = []
    type_definitions: List[str] = []
    wrapper_registrations: List[str] = []

    for fn in fns_with_infos:
        f = fn.func
        name = cpp.name(f.func)
        formals = gen_formals(f)

        type_declarations.append(
            METHOD_DECLARATION.substitute(
                return_type=cpp.returns_type(f.func.returns),
                type_wrapper_name=type_wrapper_name(f),
                formals=formals,
            ))

        if name not in MANUAL_AUTOGRAD and dispatch_strategy(
                fn) == 'use_derived':
            type_definitions.append(
                METHOD_DEFINITION.substitute(
                    return_type=cpp.returns_type(f.func.returns),
                    type_wrapper_name=type_wrapper_name(f),
                    type_definition_body=emit_body(fn),
                    formals=formals,
                ))
            wrapper_registrations.append(gen_wrapper_registration(f))

        # See Note [Manual Backend kernels]
        assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
        # If you want to register a kernel to Autograd, you must make the op abstract.
        # In other words, this op must have a dispatch section in native_functions.yaml.
        if name in MANUAL_AUTOGRAD_AND_TRACER or (fn.info
                                                  and fn.info.has_derivatives):
            msg = (
                f'There\'s a formula for {name} (or its functional variant) in derivatives.yaml. '
                f'It\'s required to add a dispatch section for it with the supported backends listed explicitly, '
                f'e.g. CPU/CUDA or DefaultBackend, in native_functions.yaml. Please see '
                f'https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword '
                f'for instructions on choosing the right dispatch keyword.')
            assert f.is_abstract, msg

    fm.write_with_template(
        output_name, template_name, lambda: {
            'generated_comment': '@' +
            f'generated from {fm.template_dir}/{template_name}',
            'type_derived_method_declarations': type_declarations,
            'type_derived_method_definitions': type_definitions,
            'wrapper_registrations': wrapper_registrations,
        })
Example #6
def emit_view_lambda(f: NativeFunction,
                     unpacked_bindings: List[Binding]) -> str:
    """ Generate an additional lambda function to recover views in backward when as_strided is not supported.
    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details."""
    input_base = 'input_base'
    replay_view_func = ''
    updated_unpacked_args: List[str] = []
    known_view_arg_simple_types: List[CType] = [
        BaseCType(intT),
        OptionalCType(BaseCType(intT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT)
    ]
    for unpacked_binding in unpacked_bindings:
        arg, arg_type = unpacked_binding.name, unpacked_binding.nctype.type
        if arg == 'self_':
            updated_unpacked_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ', '.join(
                [str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f'You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to the known types: '
                f'{known_types_str}. Please update the list, or materialize the argument so that it can be closed '
                'over by value; also add a test in pytorch/xla/test/test_operations.py where this code '
                'is exercised.')

        if arg_type == BaseCType(intArrayRefT):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + '_vec'
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg,
                                                           vec=arg_vec)
            updated_unpacked_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(intT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + '_val'
            replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg,
                                                           val=arg_value,
                                                           default='0')
            updated_unpacked_args.append(arg_value)
        else:
            updated_unpacked_args.append(arg)

    replay_view_call = emit_view_call(f, input_base, updated_unpacked_args)
    replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_base=input_base, replay_view_call=replay_view_call)

    is_view_with_metadata_change = 'true' if cpp.name(
        f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else 'false'

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func)
Example #7
def cpp_dispatch_target(f: NativeFunction) -> str:
    name = cpp.name(f.func)
    if Variant.method in f.variants:
        return f'self.{name}'
    if Variant.function in f.variants:
        if has_tensor_options(f) or f.func.name.name.base.endswith('_like'):
            namespace = 'torch'
        else:
            namespace = 'at'
        return f'{namespace}::{name}'
    raise RuntimeError(f'could not dispatch, neither function nor method: {f.func}')
Example #8
def should_generate_py_binding(f: NativeFunction) -> bool:
    name = cpp.name(f.func)
    for skip_regex in SKIP_PYTHON_BINDINGS:
        if skip_regex.match(name):
            return False

    signature = str(f.func)
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    return True
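
should_generate_py_binding filters operators in two stages: the C++ name is matched against a list of regular expressions (SKIP_PYTHON_BINDINGS), and the full schema string is compared against an exact-match list (SKIP_PYTHON_BINDINGS_SIGNATURES). The sketch below reproduces that two-stage filter with hypothetical skip entries; the real lists live in PyTorch's autograd codegen sources.

import re
from typing import List

# Hypothetical skip lists, for illustration only.
SKIP_PYTHON_BINDINGS: List[re.Pattern] = [
    re.compile(p) for p in [r'.*_backward', r'_local_scalar_dense']
]
SKIP_PYTHON_BINDINGS_SIGNATURES: List[str] = [
    'div.Scalar(Tensor self, Scalar other) -> Tensor',
]

def should_generate(name: str, schema: str) -> bool:
    # Skip if any regex matches the name, or the schema is listed verbatim.
    if any(pattern.match(name) for pattern in SKIP_PYTHON_BINDINGS):
        return False
    if schema in SKIP_PYTHON_BINDINGS_SIGNATURES:
        return False
    return True

print(should_generate('add', 'add.Tensor(Tensor self, Tensor other) -> Tensor'))  # True
print(should_generate('max_pool2d_backward', '...'))                              # False
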
Example #9
def gen_trace_type(out: str, native_yaml_path: str,
                   template_path: str) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunction]] = [[] for _ in range(num_shards)]

    # functions are assigned arbitrarily but stably to a file based on hash
    native_functions = list(
        sorted(parse_native_yaml(native_yaml_path),
               key=lambda f: cpp.name(f.func)))
    for f in native_functions:
        x = sum(ord(c) for c in cpp.name(f.func)) % num_shards
        shards[x].append(f)

    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    for i, shard in enumerate(shards):
        gen_trace_type_shard(fm, shard, '_%d' % i)
    gen_trace_type_shard(fm, native_functions, 'Everything')
Example #10
def gen_trace_type(out: str, native_functions: List[NativeFunction],
                   template_path: str) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out,
                     template_dir=template_path,
                     dry_run=False)
    fm.write_sharded('TraceType.cpp', [
        fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER
    ],
                     key_fn=lambda fn: cpp.name(fn.func),
                     base_env={
                         'generated_comment':
                         f'@generated from {template_path}/TraceType.cpp',
                     },
                     env_callable=gen_trace_type_func,
                     num_shards=5,
                     sharded_keys={
                         'trace_method_definitions',
                         'trace_wrapper_registrations'
                     })
Example #11
 def enforce_same_tensorimpl_and_storage(
         call: str, unpacked_bindings: List[Binding]) -> str:
     save_ptrs_stmts: List[str] = []
     enforce_same_ptrs_stmts: List[str] = []
     if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
         for unpacked_binding in unpacked_bindings:
             arg = unpacked_binding.name
             noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref(
             )
             if noref_cpp_type == BaseCType(tensorListT):
                 save_ptrs_stmts += [
                     SAVE_TENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                 ]
                 enforce_same_ptrs_stmts += [
                     ENFORCE_SAME_TENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     ENFORCE_SAME_TENSORLIST_IMPL.substitute(
                         tensorlist_name=arg)
                 ]
             elif noref_cpp_type == ListCType(
                     OptionalCType(BaseCType(tensorT))):
                 save_ptrs_stmts += [
                     SAVE_OPTIONALTENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     SAVE_OPTIONALTENSORLIST_IMPL.substitute(
                         tensorlist_name=arg)
                 ]
                 enforce_same_ptrs_stmts += [
                     ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
                         tensorlist_name=arg)
                 ]
             elif noref_cpp_type == BaseCType(tensorT):
                 save_ptrs_stmts += [
                     SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                     SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                 ]
                 enforce_same_ptrs_stmts += [
                     ENFORCE_SAME_TENSOR_STORAGE.substitute(
                         tensor_name=arg),
                     ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                 ]
     assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (
         not save_ptrs_stmts and not enforce_same_ptrs_stmts)
     if save_ptrs_stmts and enforce_same_ptrs_stmts:
         call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
             call + \
             RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
     return call
Example #12
def generate_out_variant_call(g: NativeFunctionsGroup) -> str:
    schema = g.out.func
    assert schema.is_out_fn()
    arg_names = [out_arg.name for out_arg in schema.arguments.out]
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    cpp_func_name = cpp.name(schema)
    cpp_arg_names = ",".join(arg_names)
    return f'at::cpu::{cpp_func_name}({cpp_arg_names})'
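
The call string produced above lists the out tensors first and the remaining non-out arguments after them, joined with commas under the at::cpu:: namespace. A worked illustration for a hypothetical out overload such as add.out(Tensor self, Tensor other, *, Tensor(a!) out), assuming cpp.name resolves it to add_out:

# Illustration only: argument order is out tensors first, then self and the
# remaining non-out arguments, exactly as generate_out_variant_call builds it.
out_args = ['out']                # schema.arguments.out
non_out_args = ['self', 'other']  # SelfArgument followed by plain Arguments
cpp_func_name = 'add_out'         # assumed result of cpp.name(schema)
print(f"at::cpu::{cpp_func_name}({','.join(out_args + non_out_args)})")
# -> at::cpu::add_out(out,self,other)
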
Example #13
 def enforce_same_tensorimpl_and_storage(
         call: str, unpacked_bindings: List[Binding]) -> str:
     save_ptrs_stmts: List[str] = []
     enforce_same_ptrs_stmts: List[str] = []
     if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
         for unpacked_binding in unpacked_bindings:
             arg = unpacked_binding.name
             noref_cpp_type = unpacked_binding.ctype.cpp_type(
                 strip_ref=True)
             if noref_cpp_type == 'TensorList':
                 save_ptrs_stmts += [
                     SAVE_TENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                 ]
                 enforce_same_ptrs_stmts += [
                     ENFORCE_SAME_TENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     ENFORCE_SAME_TENSORLIST_IMPL.substitute(
                         tensorlist_name=arg)
                 ]
             elif noref_cpp_type == 'c10::List<c10::optional<Tensor>>':
                 save_ptrs_stmts += [
                     SAVE_OPTIONALTENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     SAVE_OPTIONALTENSORLIST_IMPL.substitute(
                         tensorlist_name=arg)
                 ]
                 enforce_same_ptrs_stmts += [
                     ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
                         tensorlist_name=arg),
                     ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
                         tensorlist_name=arg)
                 ]
             elif noref_cpp_type == 'Tensor':
                 save_ptrs_stmts += [
                     SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                     SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                 ]
                 enforce_same_ptrs_stmts += [
                     ENFORCE_SAME_TENSOR_STORAGE.substitute(
                         tensor_name=arg),
                     ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                 ]
     assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (
         not save_ptrs_stmts and not enforce_same_ptrs_stmts)
     if save_ptrs_stmts and enforce_same_ptrs_stmts:
         call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
             call + \
             RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
     return call
Example #14
def should_generate_py_binding(f: NativeFunction) -> bool:
    name = cpp.name(f.func)
    for pattern in SKIP_PYTHON_BINDINGS:
        if re.match('^' + pattern + '$', name):
            return False

    args = ', '.join(
        argument_type_str(arg.type) for arg in signature(f).arguments())
    sig = f'{name}({args})'
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == sig:
            return False

    return True
Example #15
def method_definition(f: NativeFunction) -> Optional[str]:
    if cpp.name(f.func) in MANUAL_TRACER:
        return None

    formals = ', '.join(
        f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
        for a in f.func.schema_order_arguments())

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns),
        type_wrapper_name=type_wrapper_name(f),
        formals=formals,
        type_definition_body=emit_trace_body(f),
    )
Example #16
    def go(f: NativeFunction) -> str:
        # header comments
        deprecated = '[deprecated] ' if ps.deprecated else ''
        schema_comment = f'// {deprecated}aten::{f.func}'

        # dispatch lambda signature
        name = cpp.name(f.func)
        lambda_formals = ', '.join(
            map(lambda a: f"{a.type_str} {a.name}",
                dispatch_lambda_args(ps, f)))
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ', '.join(cpp_dispatch_exprs(f, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
        inits = '\n'.join(lambda_arg_exprs.inits)
        lambda_args = ', '.join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        #       solution for enabling the 'requires_grad' argument for tensor methods
        #       new_full, new_empty, and new_zeros. A much better but more difficult to
        #       implement solution involves refactoring according to Ed's description here:
        #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (
            not has_tensor_options(f) or (ps.method and
                                          ('requires_grad' in parser_outputs)))
        set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \
            if need_set_requires_grad else ''

        if lambda_return == 'void':
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  pybind11::gil_scoped_release no_gil;
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""
        else:
            typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
            namedtuple_typeref = f'&{typename}, ' if typename is not None else ''
            return f"""\
Example #17
    def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
        """ Dispatch call via function in a namespace or method on Tensor."""
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        dispatcher_exprs = dispatcher_sig.exprs()

        # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
        # Ops also always have a function variant of the redispatch API.
        # See Note [Plumbing Keys Through The Dispatcher] for details.
        dispatch_key_set = 'ks & c10::after_autograd_keyset'
        call = CALL_REDISPATCH.substitute(
            api_name=cpp.name(
                f.func,
                faithful_name_for_out_overloads=True,
            ),
            unpacked_args=[dispatch_key_set] + list(unpacked_args))
        return call
Example #18
def method_registration(f: NativeFunction) -> Optional[str]:
    if cpp.name(f.func) in MANUAL_TRACER:
        return None

    if f.use_c10_dispatcher.dispatcher_uses_new_style():
        return WRAPPER_REGISTRATION.substitute(
            name=f.func.name,
            type_wrapper_name=type_wrapper_name(f),
            class_type='TraceType',
        )
    else:
        return UNBOXEDONLY_WRAPPER_REGISTRATION.substitute(
            name=f.func.name,
            type_wrapper_name=type_wrapper_name(f),
            class_type='TraceType',
        )
Example #19
def method_definition(f: NativeFunction) -> Optional[str]:
    if cpp.name(f.func) in MANUAL_TRACER:
        return None

    if f.use_c10_dispatcher.dispatcher_uses_new_style():
        formals = ', '.join(f'{cpp.argument_type(a)} {a.name}' for a in f.func.schema_order_arguments())
    else:
        sig_group = CppSignatureGroup.from_schema(f.func, method=False)
        formals = ', '.join(f'{a.type} {a.name}' for a in sig_group.signature.arguments())

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns),
        type_wrapper_name=type_wrapper_name(f),
        formals=formals,
        type_definition_body=emit_trace_body(f),
    )
Example #20
def emit_namedtuple_typedefs(
    overloads: Sequence[PythonSignatureNativeFunctionPair]
) -> Tuple[List[str], Dict[str, str]]:
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them
    """
    flddefnames: Dict[str, str] = {
    }  # map from unique field name lists to field def name
    flddefs: List[str] = []  # field def declarations
    typenames: Dict[str, str] = {
    }  # map from unique name + field name lists to typedef name
    typedefs: List[str] = []  # typedef declarations and init code

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fn_key = '_'.join(fieldnames)
        fieldsname = flddefnames.get(fn_key)
        if fieldsname is None:
            fieldsname = f'NamedTuple_fields{"" if not flddefs else len(flddefs)}'
            flddefnames[fn_key] = fieldsname
            fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)
            flddefs.append(f"""\
static PyStructSequence_Field {fieldsname}[] = {{ {fields},  {{nullptr}} }};
""")

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)
        if typename is None:
            typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typenames[tn_key] = typename
            typedefs.append(f"""\
static PyTypeObject {typename};
static bool {typename}_initialized = false;
if (!{typename}_initialized) {{
  {typename}_initialized = true;
  static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, {fieldsname}, {len(fieldnames)} }};
  PyStructSequence_InitType(&{typename}, &desc);
  {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}}
""")

    return flddefs + typedefs, typenames
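
Both this helper and the one in Example #22 rely on the same deduplication idea: generated snippets are cached under a key built from the operator name and its return field names, so overloads with identical named-tuple layouts share a single definition. A stripped-down sketch of that caching pattern, with made-up inputs:

from typing import Dict, List, Tuple

def dedup_typedefs(overloads: List[Tuple[str, List[str]]]) -> List[str]:
    # Cache one typedef per unique (name, fieldnames) key; later overloads
    # with the same key reuse the earlier entry instead of emitting a new one.
    typenames: Dict[str, str] = {}
    typedefs: List[str] = []
    for name, fieldnames in overloads:
        if not fieldnames:
            continue  # plain returns get no named tuple
        key = f'{name}_{"_".join(fieldnames)}'
        if key not in typenames:
            typenames[key] = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typedefs.append(f'{typenames[key]} -> torch.return_types.{name}')
    return typedefs

# Two overloads of 'max' share one entry; 'add' returns no named tuple.
print(dedup_typedefs([('max', ['values', 'indices']),
                      ('max', ['values', 'indices']),
                      ('add', [])]))
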
Example #21
def method_definition(f: NativeFunction) -> str:
    assert cpp.name(f.func) not in MANUAL_TRACER

    formals = ', '.join(
        # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
        # See Note [Plumbing Keys Through The Dispatcher] for details.
        ['c10::DispatchKeySet ks'] + [
            f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}'
            for a in f.func.schema_order_arguments()
        ])

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=formals,
        type_definition_body=emit_trace_body(f),
    )
Example #22
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate block of function in `python_return_types.cpp` to initialize
    and return named tuple for a native function which returns named tuple
    and relevant entry for the map in same file.
    """
    typenames: Dict[str, str] = {
    }  # map from unique name + field name lists to typedef name
    definitions: List[str] = []  # function definitions that register the typedefs
    map_entries: List[str] = [
    ]  # C++ map entries of <function_name, function that creates its namedtuple>

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)

        if typename is None:
            typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
            typenames[tn_key] = typename
            definitions.append(f"""\
PyTypeObject* get_{name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields},  {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
""")
            map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
Example #23
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.function not in f.variants:
            return None

        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding)

        if self.target is Target.DECLARATION:
            result = f"TORCH_API {sig_group.signature.decl()};\n"
            if sig_group.faithful_signature is not None:
                result += f"TORCH_API {sig_group.faithful_signature.decl()};\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            if faithful and sig_group.faithful_signature is not None:
                sig = sig_group.faithful_signature
            else:
                sig = sig_group.signature

            dispatcher_exprs = translate(sig.arguments(),
                                         dispatcher_sig.arguments())
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

            return f"""
// aten::{f.func}
{sig.defn()} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""

        result = generate_defn(sig_group.faithful_signature is None)
        if sig_group.faithful_signature is not None:
            result += generate_defn(True)

        return result
Example #24
    def go(f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert len(f.func.arguments) > 0
        assert sum(a.name == 'self' for a in f.func.arguments) == 1

        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_schema(f.func, method=True)

        if target is Target.DECLARATION:
            result = f"{sig_group.signature.decl()} const;\n"
            if sig_group.faithful_signature is not None:
                result += f"{sig_group.faithful_signature.decl()} const;\n"
            return result

        assert target is Target.DEFINITION

        def generate_defn(sig: CppSignature) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            dispatcher_exprs = dispatcher.cpparguments_exprs(
                sig.argument_packs())
            dispatcher_exprs_str = ', '.join(
                map(lambda a: a.expr, dispatcher_exprs))

            return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""

        result = generate_defn(sig_group.signature)
        if sig_group.faithful_signature is not None:
            result += generate_defn(sig_group.faithful_signature)

        return result
Example #25
    def go(f: NativeFunction) -> Optional[str]:
        if f.manual_kernel_registration:
            return None
        if Variant.function not in f.variants:
            return None

        name = cpp.name(f.func)

        sig_group = CppSignatureGroup.from_schema(f.func, method=False)

        if target is Target.DECLARATION:
            result = f"CAFFE2_API {sig_group.signature.decl()};\n"
            if sig_group.faithful_signature is not None:
                result += f"CAFFE2_API {sig_group.faithful_signature.decl()};\n"
            return result

        assert target is Target.DEFINITION

        def generate_defn(sig: CppSignature) -> str:
            dispatcher_sig = DispatcherSignature.from_schema(f.func)

            dispatcher_exprs = dispatcher.cpparguments_exprs(
                sig.argument_packs())
            dispatcher_exprs_str = ', '.join(
                map(lambda a: a.expr, dispatcher_exprs))

            return f"""
// aten::{f.func}
{sig.defn()} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""

        result = generate_defn(sig_group.signature)
        if sig_group.faithful_signature is not None:
            if local.use_c10_dispatcher().dispatcher_uses_new_style():
                result += generate_defn(sig_group.faithful_signature)

        return result
Example #26
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str
) -> None:
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [[] for _ in range(num_shards)]

    # functions are assigned arbitrarily but stably to a file based on hash
    for fn in fns_with_infos:
        x = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
        shards[x].append(fn)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for i, shard in enumerate(shards):
        gen_inplace_or_view_type_shard(fm, shard, f'_{i}')
    gen_inplace_or_view_type_shard(fm, fns_with_infos, 'Everything')
Example #27
def gen_autograd(
    aten_path: str,
    native_functions_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    # Parse and load derivatives.yaml
    from .load_derivatives import load_derivatives
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)

    template_path = os.path.join(autograd_dir, 'templates')

    fns = list(
        sorted(filter(
            operator_selector.is_native_function_selected_for_training,
            parse_native_yaml(native_functions_path)),
               key=lambda f: cpp.name(f.func)))
    fns_with_diff_infos: List[
        NativeFunctionWithDifferentiabilityInfo] = match_differentiability_info(
            fns, differentiability_infos)

    # Generate VariableType.h/cpp
    from .gen_trace_type import gen_trace_type
    from .gen_variable_type import gen_variable_type
    if not disable_autograd:
        gen_variable_type(out, native_functions_path, fns_with_diff_infos,
                          template_path)

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_functions_path, template_path)

    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_lib
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    from .gen_variable_factories import gen_variable_factories
    gen_variable_factories(out, native_functions_path, template_path)
Example #28
def format_trace_op_name(f: NativeFunction) -> str:
    # TODO: byte-for-byte compatible with old codegen behavior - should clean up
    if f.func.kind() in (SchemaKind.functional, SchemaKind.out) or f.func.name.name.dunder_method:
        # special case for *_out functions: the in-place and out-of-place ops
        # are overloaded with the same name in the JIT
        trace_name = str(f.func.name.name)
        trace_name = RENAME_TRACE.get(trace_name, trace_name)
        return OP_NAME.substitute(trace_name=trace_name)

    # otherwise, this is an in-place op and we need to emit both in- and
    # out-of-place versions
    outplace_trace_name = f.func.name.name.base
    inplace_trace_name = cpp.name(f.func)
    outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name)
    inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name)

    return SELECT.substitute(
        cond='tracer_state->force_outplace',
        true=OP_NAME.substitute(trace_name=outplace_trace_name),
        false=OP_NAME.substitute(trace_name=inplace_trace_name),
    )
Example #29
def compute_native_function_declaration(f: NativeFunction) -> List[str]:
    if f.dispatch is None:
        ns = [cpp.name(f.func)]
    else:
        ns = list(f.dispatch.values())

    rs = []
    # Sometimes a function name shows up multiple times; only generate
    # it once!
    seen = set()
    for n in ns:
        if n in seen:
            continue
        if "legacy::" in n:
            continue
        seen.add(n)
        returns_type = legacy_dispatcher.returns_type(f.func.returns)
        args = legacy_dispatcher.arguments(f.func)
        rs.append(f"CAFFE2_API {returns_type} {n}({', '.join(map(lambda a: a.str_with_default(), args))});")

    return rs
Example #30
def format_prerecord_trace(f: NativeFunction) -> str:
    if not should_trace(f):
        return ''

    # TODO: clean up old codegen behavior
    is_inplace = f.func.kind() in (SchemaKind.inplace, SchemaKind.out) and not f.func.name.name.dunder_method
    add_args = RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, '') if is_inplace else ''
    additional_inputs = SELECT.substitute(
        cond='tracer_state->force_outplace',
        true=add_args,
        false='',
    ) if add_args else ''

    return PRE_RECORD_TRACE.substitute(
        set_op_name=format_trace_op_name(f),
        add_trace_inputs=format_trace_inputs(f) + additional_inputs,
        inplace_guard=INPLACE_GUARD.substitute(
            name=cpp.name(f.func),
            mutable_input=f.func.arguments.out[0].name if f.func.arguments.out else 'self',
        ) if is_inplace else '',
    )