Ejemplo n.º 1
0
def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """
    Pick the signature/function pairs that should receive a binding and
    return them grouped by ``group_overloads``.

    A pair is kept when all of the following hold for ``pair.function``:

    * ``should_generate_py_binding`` accepts it,
    * its ``python_module`` is falsy,
    * its ``variants`` contain ``Variant.method`` (when ``method`` is true)
      or ``Variant.function`` (when ``method`` is false, the default).

    Per the original description, the result is used to generate either
    functions in the "torch" module or methods on Tensor.
    """
    # The function and method flavours differ only in the variant required.
    wanted_variant = Variant.method if method else Variant.function

    selected = []
    for pair in python_funcs:
        native = pair.function
        if not should_generate_py_binding(native):
            continue
        if native.python_module:
            # Functions assigned to a python module are not bound here.
            continue
        if wanted_variant not in native.variants:
            continue
        selected.append(pair)

    return group_overloads(selected)
Ejemplo n.º 2
0
def method_impl(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Render the implementation text for one operator binding.

    The overloads are grouped with ``group_overloads``; each group contributes
    one quoted signature string and one dispatch body. The pieces are then
    substituted into one of the ``PY_VARIABLE_METHOD_*`` templates:

    * ``PY_VARIABLE_METHOD_NOARGS`` when ``is_noarg(overloads)`` is true,
    * ``PY_VARIABLE_METHOD_VARARGS_SINGLETON`` when there is exactly one
      overload group,
    * ``PY_VARIABLE_METHOD_VARARGS`` otherwise.

    Args:
        name: operator name, forwarded to ``get_upt_name`` and the template.
        module: not used by this function body.
        overloads: signature/function pairs sharing ``name``.
        method: keyword-only; when true the generated header unpacks ``self``
            from the argument array and ``self_`` is rendered as ``*args``
            instead of ``nullptr``.

    Returns:
        The result of ``template.substitute(...)`` for the chosen template.
    """
    uptname = get_upt_name(name, method)
    noarg = is_noarg(overloads)

    # Header lines: error-handling opener, plus `self` unpacking for methods.
    header_lines: List[str] = ['HANDLE_TH_ERRORS']
    if method:
        header_lines.extend([
            'Tensor& self = unpackTensor(*args++);',
            '--n_args;',
        ])

    # Footer lines: the none-return is only emitted when arguments are taken.
    footer_lines: List[str] = []
    if not noarg:
        footer_lines.append('return mp_const_none;')
    footer_lines.append('END_HANDLE_TH_ERRORS')

    groups: Sequence[PythonSignatureGroup] = group_overloads(overloads)
    single_group = len(groups) == 1

    signature_lines: List[str] = []
    dispatch_lines: List[str] = []
    for index, group in enumerate(groups):
        sig_text = str(group.signature.signature_str())
        signature_lines.append(f'{cpp_string(sig_text)},')

        body = emit_dispatch_case(group)
        if single_group:
            # A lone overload is emitted directly, without a case wrapper.
            dispatch_lines.append(body)
        else:
            dispatch_lines.append(
                PY_VARIABLE_CASE.substitute(overload_index=index, body=body))

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    else:
        template = (PY_VARIABLE_METHOD_VARARGS_SINGLETON
                    if single_group else PY_VARIABLE_METHOD_VARARGS)

    # Largest argument count over the ungrouped overloads.
    max_args = max(pair.signature.arguments_count() for pair in overloads)
    self_expr = '*args' if method else 'nullptr'

    return template.substitute(
        name=name,
        uptname=uptname,
        method_header=header_lines,
        max_args=max_args,
        signatures=signature_lines,
        dispatch=dispatch_lines,
        method_footer=footer_lines,
        self_=self_expr,
    )