Example #1
0
def cpparguments_exprs(func: FunctionSchema, *, method: bool,
                       api_is_faithful: bool) -> Sequence[DispatcherExpr]:
    """Build dispatcher expressions for the C++ API arguments of ``func``.

    The arguments are grouped by ``cpp.group_arguments``. Whether that
    grouping uses the faithful convention depends on the current
    ``local.use_c10_dispatcher().dispatcher_uses_new_style()`` setting.
    Each group is then turned into an argument pack with
    ``cpp.argument_faithful`` when ``api_is_faithful`` is true, or with
    ``cpp.argument`` otherwise. The packs are handed to
    ``_cpparguments_exprs``.

    Args:
        func: schema of the operator whose arguments are translated.
        method: forwarded to ``cpp.group_arguments``.
        api_is_faithful: selects which C++ argument converter is applied.
    """
    faithful_dispatcher = (
        local.use_c10_dispatcher().dispatcher_uses_new_style())
    grouped = cpp.group_arguments(func, method=method,
                                  faithful=faithful_dispatcher)

    # Choose the converter once instead of duplicating the tuple
    # construction in two branches.
    to_pack = cpp.argument_faithful if api_is_faithful else cpp.argument
    packs = tuple(map(to_pack, grouped))
    return _cpparguments_exprs(packs)
Example #2
0
def unpack_args(f: NativeFunction) -> Tuple[List[str], List[Binding]]:
    """Emit unpacking statements for the tensor arguments of ``f``.

    Where the bindings come from:

    * When ``f.use_c10_dispatcher.dispatcher_uses_new_style()`` is true, they
      come from ``cpp.argument`` applied to every schema-order argument.
    * Otherwise they come from the non-method signature of
      ``CppSignatureGroup.from_native_function``.

    Returns:
        A pair ``(body, unpacked_bindings)``:

        * ``body`` holds one ``UNPACK_TENSOR`` substitution for each binding
          whose type is tensor-like and not nullable.
        * ``unpacked_bindings`` has one entry per input binding, in the same
          order. A binding that was unpacked is replaced by a copy whose
          name gains a trailing ``_`` (same ctype, argument and default).
          Every other binding is passed through unchanged.

    Raises:
        RuntimeError: if a binding wraps ``TensorOptionsArguments``.
        AssertionError: (via ``assert``) if a binding wraps ``SelfArgument``.
    """
    body: List[str] = []
    unpacked_bindings: List[Binding] = []

    if f.use_c10_dispatcher.dispatcher_uses_new_style():
        bindings = [r for a in f.func.schema_order_arguments()
                    for r in cpp.argument(a,
                                          method=False,
                                          cpp_no_default_args=set(),
                                          faithful=False,
                                          has_tensor_options=False)]
    else:
        sig_group = CppSignatureGroup.from_native_function(f, method=False)
        bindings = list(sig_group.signature.arguments())

    for i, binding in enumerate(bindings):
        assert not isinstance(binding.argument, SelfArgument)
        if isinstance(binding.argument, TensorOptionsArguments):
            raise RuntimeError("VariableKernel shouldn't take TensorOptions")

        arg_type = binding.argument.type
        # Non-tensor and nullable arguments are forwarded without unpacking.
        if arg_type.is_nullable() or not arg_type.is_tensor_like():
            unpacked_bindings.append(binding)
            continue

        # Past the guard every argument is tensor-like and NOT nullable.
        # The old nullable-dependent branches therefore never fired: the
        # '_opt' suffix and the `not is_nullable` term of `ref`. They are
        # folded into constants here. `ref` is '&' for everything except
        # tensor lists.
        is_tensor_list = is_tensor_list_type(arg_type)
        body.append(UNPACK_TENSOR.substitute(
            arg_name=binding.name,
            arg_pos=i,
            suffix='',
            ref='' if is_tensor_list else '&',
        ))
        unpacked_bindings.append(Binding(
            name=binding.name + '_',
            ctype=binding.ctype,
            argument=binding.argument,
            default=binding.default,
        ))

    return body, unpacked_bindings
Example #3
0
def cpparguments_exprs(func: FunctionSchema, *, method: bool,
                       api_is_faithful: bool) -> Sequence[DispatcherExpr]:
    """Build dispatcher expressions for the C++ API arguments of ``func``.

    Argument order follows
    ``local.use_c10_dispatcher().dispatcher_uses_new_style()``:

    * When it is true, the non-out arguments come first, then the out
      arguments.
    * Otherwise the out arguments come first.

    Each argument is converted with ``cpp.argument_faithful`` (when
    ``api_is_faithful``) or ``cpp.argument``, passing ``method`` through.
    The resulting packs are handed to ``_cpparguments_exprs``.
    """
    if local.use_c10_dispatcher().dispatcher_uses_new_style():
        ordered_groups = (func.arguments.non_out, func.arguments.out)
    else:
        ordered_groups = (func.arguments.out, func.arguments.non_out)

    arguments: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        arg for group in ordered_groups for arg in group
    ]

    convert = cpp.argument_faithful if api_is_faithful else cpp.argument
    argument_packs = tuple(convert(arg, method=method) for arg in arguments)
    return _cpparguments_exprs(argument_packs)
Example #4
0
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Build the ordered field mapping for ``f``'s Declarations.yaml entry.

    The result is an ``OrderedDict`` whose key order is the field order of
    the entry. Besides naming and schema metadata copied from ``f``, it
    records:

    * ``arguments``: one ``compute_cpp_argument_yaml`` entry per argument of
      the non-method C++ signature (``CppSignatureGroup.from_schema``).
    * ``schema_order_arguments``: one ``compute_argument_yaml`` entry per
      argument of ``f.func.schema_order_arguments()``.
    * ``schema_order_cpp_signature``: a ``"<return type> (<arg types>)"``
      string built from ``cpp.returns_type`` and ``cpp.argument(...).type``.
    * ``is_factory_method``: true when the C++ signature takes a
      ``TensorOptionsArguments`` and ``f`` has no method variant.
    """
    returns, name_to_field_name = compute_returns_yaml(f)

    # Keyword arguments shared by both argument encoders. The two name sets
    # let each encoded argument be tagged as kwarg-only and/or out.
    yaml_arg_kwargs = dict(
        kwarg_only_set={a.name for a in f.func.kwarg_only_arguments},
        out_arg_set={a.name for a in f.func.out_arguments},
        name_to_field_name=name_to_field_name,
    )

    cpp_signature = CppSignatureGroup.from_schema(f.func, method=False).signature
    cpp_args = cpp_signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(cpp_arg, schema_order=False,
                                  **yaml_arg_kwargs)
        for cpp_arg in cpp_args
    ]

    jit_args = list(f.func.schema_order_arguments())
    schema_order_arguments = [
        compute_argument_yaml(jit_arg, schema_order=True, **yaml_arg_kwargs)
        for jit_arg in jit_args
    ]

    arg_types = ', '.join([cpp.argument(jit_arg).type for jit_arg in jit_args])
    return_type = cpp.returns_type(f.func.returns)
    schema_order_cpp_signature = f"{return_type} ({arg_types})"

    takes_tensor_options = any(
        isinstance(cpp_arg.argument, TensorOptionsArguments)
        for cpp_arg in cpp_args)
    is_factory_method = (takes_tensor_options
                         and Variant.method not in f.variants)

    # The list order below is the key order of the emitted entry.
    fields = [
        ('name', cpp.name(f.func)),
        ('operator_name', str(f.func.name.name)),
        ('overload_name', str(f.func.name.overload_name)),
        ('use_c10_dispatcher', f.use_c10_dispatcher.name),
        ('manual_kernel_registration', f.manual_kernel_registration),
        ('category_override',
         '' if f.category_override is None else f.category_override),
        ('matches_jit_signature', True),
        ('schema_string', f'aten::{f.func}'),
        ('arguments', arguments),
        ('schema_order_cpp_signature', schema_order_cpp_signature),
        ('schema_order_arguments', schema_order_arguments),
        ('method_of', compute_method_of_yaml(f.variants)),
        ('mode', 'native'),
        ('python_module',
         f.python_module if f.python_module is not None else ''),
        ('returns', returns),
        ('inplace', f.func.name.name.inplace),
        ('is_factory_method', is_factory_method),
        # Note [Abstract ATen methods]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # An abstract ATen method is one whose dispatch differs between
        # types.  These are implemented in derived types (with a
        # standard (throwing) definition in Type).  A concrete ATen
        # method is one which has the same dispatch for all types;
        # we just implement it in the base Type.  This is exposed
        # in Declarations.yaml via a field named 'abstract'.
        #
        # Although this is what we have historically exposed, it is
        # actually not all that useful for end users, who are also interested
        # whether or not there is an explicit entry in derivatives.yaml
        # for the entry or not (as this affects whether or not the operation is
        # overrideable or not.)  Once this all gets cleaned up, this
        # property will be obsolete.
        ('abstract', f.dispatch is not None),
        ('device_guard', f.device_guard),
        ('with_gil', False),
        ('deprecated', False),
        ('has_math_kernel', f.dispatch is not None and 'Math' in f.dispatch),
    ]
    return OrderedDict(fields)
Example #5
0
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Build the ordered field mapping describing native function ``f``.

    The result is an ``OrderedDict`` whose key order is the field order of
    the emitted entry (presumably one Declarations.yaml record). It records:

    * ``arguments``: one ``compute_cpp_argument_yaml`` entry per argument of
      the non-method C++ signature, built from
      ``CppSignatureGroup.from_native_function(..., fallback_binding=False)``.
    * ``schema_order_arguments``: one ``compute_argument_yaml`` entry per
      argument of ``f.func.schema_order_arguments()``.
    * ``schema_order_cpp_signature``: a ``"<return type> (<arg types>)"``
      string. The argument types are the ``.type`` of every binding that
      ``cpp.argument`` yields for each schema-order argument.
    * ``is_factory_method``: true when the C++ signature takes a
      ``TensorOptionsArguments`` and ``f`` has no method variant.
    """
    returns, name_to_field_name = compute_returns_yaml(f)

    # Keyword arguments shared by both argument encoders. The two name sets
    # let each encoded argument be tagged as kwarg-only and/or out.
    yaml_arg_kwargs = dict(
        kwarg_only_set={a.name for a in f.func.arguments.flat_kwarg_only},
        out_arg_set={a.name for a in f.func.arguments.out},
        name_to_field_name=name_to_field_name,
    )

    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False)
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(cpp_arg, schema_order=False,
                                  **yaml_arg_kwargs)
        for cpp_arg in cpp_args
    ]

    jit_args = list(f.func.schema_order_arguments())
    schema_order_arguments = [
        compute_argument_yaml(jit_arg, schema_order=True, **yaml_arg_kwargs)
        for jit_arg in jit_args
    ]

    # NB: the `method` flag passed to cpp.argument doesn't matter here.
    arg_types = []
    for jit_arg in jit_args:
        arg_bindings = cpp.argument(jit_arg,
                                    method=False,
                                    cpp_no_default_args=set(),
                                    faithful=False,
                                    has_tensor_options=False)
        arg_types.extend(binding.type for binding in arg_bindings)

    return_type = cpp.returns_type(f.func.returns).cpp_type()
    schema_order_cpp_signature = f"{return_type} ({', '.join(arg_types)})"

    takes_tensor_options = any(
        isinstance(cpp_arg.argument, TensorOptionsArguments)
        for cpp_arg in cpp_args)
    is_factory_method = (takes_tensor_options
                         and Variant.method not in f.variants)

    # The list order below is the key order of the emitted entry.
    fields = [
        ('name', cpp.name(f.func)),
        ('operator_name', str(f.func.name.name)),
        ('overload_name', str(f.func.name.overload_name)),
        ('manual_kernel_registration', f.manual_kernel_registration),
        ('category_override',
         '' if f.category_override is None else f.category_override),
        ('schema_string', f'aten::{f.func}'),
        ('arguments', arguments),
        ('schema_order_cpp_signature', schema_order_cpp_signature),
        ('schema_order_arguments', schema_order_arguments),
        ('method_of', compute_method_of_yaml(f.variants)),
        ('mode', 'native'),
        ('python_module',
         f.python_module if f.python_module is not None else ''),
        ('returns', returns),
        ('inplace', f.func.name.name.inplace),
        ('is_factory_method', is_factory_method),
        ('abstract', f.is_abstract),
        ('device_guard', f.device_guard),
        ('with_gil', False),
        ('deprecated', False),
        ('has_math_kernel',
         DispatchKey.CompositeImplicitAutograd in f.dispatch),
    ]
    return OrderedDict(fields)