예제 #1
0
파일: gen.py 프로젝트: wzjahucm/pytorch
    def go(f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert len(f.func.arguments) > 0
        assert sum(a.name == 'self' for a in f.func.arguments) == 1

        name = cpp.name(f.func)
        cpp_returns_type = cpp.returns_type(f.func.returns)
        cpp_args = cpp.arguments(f.func, method=True)
        cpp_args_exclude_this = [a for a in cpp_args if not isinstance(a.argument, ThisArgument)]
        cpp_args_exclude_this_str = ', '.join(str(a) for a in cpp_args_exclude_this)

        if target is Target.DECLARATION:
            return f"{cpp_returns_type} {name}({cpp_args_exclude_this_str}) const;"

        assert target is Target.DEFINITION

        dispatcher_exprs = dispatcher.cpparguments_exprs(cpp_args)
        cpp_args_exclude_this_str_no_default = ', '.join(a.str_no_default() for a in cpp_args_exclude_this)
        dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
        dispatcher_types_str = ', '.join(map(lambda a: a.type, dispatcher_exprs))
        dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs))

        return f"""
예제 #2
0
    def go(f: NativeFunction) -> Optional[str]:
        if f.manual_kernel_registration:
            return None
        if Variant.function not in f.variants:
            return None

        name = cpp.name(f.func)

        cpp_returns_type = cpp.returns_type(f.func.returns)
        cpp_args = cpp.arguments(f.func)
        cpp_args_str = ', '.join(map(str, cpp_args))

        if target is Target.DECLARATION:
            return f"CAFFE2_API {cpp_returns_type} {name}({cpp_args_str});"

        assert target is Target.DEFINITION

        dispatcher_exprs = dispatcher.cpparguments_exprs(cpp_args)
        cpp_args_str_no_default = ', '.join(
            map(lambda a: a.str_no_default(), cpp_args))
        dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
        dispatcher_types_str = ', '.join(
            map(lambda a: a.type, dispatcher_exprs))
        dispatcher_exprs_str = ', '.join(
            map(lambda a: a.expr, dispatcher_exprs))

        return f"""
예제 #3
0
파일: gen.py 프로젝트: wzjahucm/pytorch
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Assemble the ordered mapping that describes ``f`` in Declarations.yaml.

    The result is an ``OrderedDict`` whose key order is significant and is
    kept exactly as listed below.  It combines the C++-order arguments, the
    "schema order" arguments, the returns, and assorted flags read off ``f``.
    """
    returns, name_to_field_name = compute_returns_yaml(f)

    # Name lookups used by the argument helpers to tell whether a given
    # argument is kwarg-only and/or an out argument.
    kwarg_only_set = {arg.name for arg in f.func.kwarg_only_arguments}
    out_arg_set = {arg.name for arg in f.func.out_arguments}

    # Options shared by every per-argument yaml computation below.
    shared_kwargs = dict(
        kwarg_only_set=kwarg_only_set,
        out_arg_set=out_arg_set,
        name_to_field_name=name_to_field_name,
    )

    cpp_args = cpp.arguments(f.func)
    arguments = []
    for cpp_arg in cpp_args:
        arguments.append(
            compute_cpp_argument_yaml(cpp_arg, schema_order=False, **shared_kwargs))

    # See Note [Byte-for-byte compatibility]
    # NB: despite the name this is NOT actually schema order (positional,
    # then out, then kwarg-only).  This is almost certainly a BUG, but it is
    # reproduced as-is.
    schema_order_jit_arguments = [
        *f.func.arguments,
        *f.func.out_arguments,
        *f.func.kwarg_only_arguments,
    ]

    schema_order_arguments = []
    cpp_schema_order_types = []
    for jit_arg in schema_order_jit_arguments:
        schema_order_arguments.append(
            compute_argument_yaml(jit_arg, schema_order=True, **shared_kwargs))
    for jit_arg in schema_order_jit_arguments:
        cpp_schema_order_types.append(cpp.argument(jit_arg).type)

    cpp_returns = cpp.returns_type(f.func.returns)
    joined_types = ', '.join(cpp_schema_order_types)
    schema_order_cpp_signature = f"{cpp_returns} ({joined_types})"

    # A factory method takes TensorOptions and is not exposed as a method.
    takes_tensor_options = any(
        isinstance(cpp_arg.argument, TensorOptionsArguments) for cpp_arg in cpp_args)
    is_factory_method = takes_tensor_options and Variant.method not in f.variants

    # Filled in by assignment; insertion order is the emitted field order.
    decl = OrderedDict()
    decl['name'] = cpp.name(f.func)
    decl['operator_name'] = str(f.func.name.name)
    decl['overload_name'] = str(f.func.name.overload_name)
    decl['use_c10_dispatcher'] = f.use_c10_dispatcher.name
    decl['manual_kernel_registration'] = f.manual_kernel_registration
    decl['category_override'] = '' if f.category_override is None else f.category_override
    decl['matches_jit_signature'] = True
    decl['schema_string'] = f'aten::{f.func}'
    decl['arguments'] = arguments
    decl['schema_order_cpp_signature'] = schema_order_cpp_signature
    decl['schema_order_arguments'] = schema_order_arguments
    decl['method_of'] = compute_method_of_yaml(f.variants)
    decl['mode'] = 'native'
    decl['python_module'] = '' if f.python_module is None else f.python_module
    decl['returns'] = returns
    decl['inplace'] = f.func.name.name.inplace
    decl['is_factory_method'] = is_factory_method
    # Note [Abstract ATen methods]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # An abstract ATen method is one whose dispatch differs between
    # types; such methods are implemented in derived types (with a
    # standard, throwing definition in Type).  A concrete ATen method
    # has the same dispatch for all types and is implemented directly
    # in the base Type.  Declarations.yaml exposes this distinction
    # through the 'abstract' field.
    #
    # This is what has historically been exposed, but it is not all
    # that useful for end users, who also care whether there is an
    # explicit entry in derivatives.yaml (as that affects whether the
    # operation is overrideable).  Once this all gets cleaned up, this
    # property will be obsolete.
    decl['abstract'] = f.dispatch is not None
    decl['device_guard'] = f.device_guard
    decl['with_gil'] = False
    decl['deprecated'] = False
    decl['has_math_kernel'] = f.dispatch is not None and 'Math' in f.dispatch
    return decl