# Example 1
def is_supported(
        g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    """Decide whether codegen can handle the op group ``g``.

    Accepts both out-variant groups and view groups; every failed
    precondition logs a reason and short-circuits with ``False``.
    """
    is_view = isinstance(g, NativeFunctionsViewGroup)
    func = g.view.func if is_view else g.out.func
    base_op_name = g.view.root_name if is_view else g.out.func.name.name.base

    if config.is_hand_written(g):
        logger.info(f"HAND WRITTEN: {base_op_name}")
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info(f"BLOCKED: {base_op_name}")
        return False

    for arg in func.schema_order_arguments():
        if not ivalue_type_conversion_method(arg.type):
            # No known IValue conversion for this argument type yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
            return False

    if is_view:
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        if cpp.returns_type(func.returns,
                            symint=False).cpp_type() != "at::Tensor":
            # Returns a non-Tensor value.
            logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
            return False
        return True

    # For out variant ops, we need to check the arguments of its functional func.
    for arg in g.functional.func.schema_order_arguments():
        if not ivalue_type_conversion_method(arg.type):
            # No known IValue conversion for this argument type yet.
            logger.info(
                f"NOT SUPPORTED TYPE CONVERTING: {str(g.functional.func)}")
            return False

    if not g.structured:
        # An unstructured op must at least provide an out-variant whose last
        # parameter is the output tensor.
        has_proper_out_variant = (
            hasattr(g, "out")
            and str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            and str(func.name).endswith(".out"))
        if not has_proper_out_variant:
            return False

    # TODO: stop type testing by converting to C++
    if cpp.returns_type(func.returns,
                        symint=False).cpp_type() != "at::Tensor &":
        logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        logger.info(f"INPUTS ALIAS: {base_op_name}")
        return False
    return True
# Example 2
def is_supported(g: NativeFunctionsGroup) -> bool:
    """Return True when the out-variant group ``g`` can be code-generated."""
    out_func = g.out.func
    base_op_name = out_func.name.name.base
    if base_op_name in BLOCKED_OPS:
        return False
    if config.is_hand_written(g):
        return False
    if not g.structured:
        # An unstructured op must at least provide an out-variant whose last
        # parameter is the output tensor.
        has_proper_out_variant = (
            hasattr(g, "out")
            and str(out_func).endswith("Tensor(a!) out) -> Tensor(a!)")
            and str(out_func.name).endswith(".out"))
        if not has_proper_out_variant:
            return False
    if has_alias(out_func.arguments.non_out):
        # This op may create an alias of inputs.
        return False
    if len(out_func.arguments.out) > 1:
        # More than 1 output values.
        return False
    if cpp.returns_type(out_func.returns).cpp_type() != "at::Tensor &":
        # Returns a non-Tensor value.
        return False
    # Every argument must have a known IValue conversion.
    return all(
        ivalue_type_conversion_method(arg.type)
        for arg in out_func.schema_order_arguments())
# Example 3
def inplace_or_view_method_definition(
    fn: NativeFunctionWithDifferentiabilityInfo, ) -> Optional[str]:
    """Emit the method definition for an inplace/view op, or None if no
    definition is needed for this function."""
    f = fn.func
    # A definition is needed when the op is a view, or when it both mutates
    # its inputs and returns something. Functions that modify their inputs
    # but don't return them can't be given autograd support.
    # See https://github.com/pytorch/pytorch/issues/53796
    needs_definition = get_view_info(f) is not None or (
        modifies_arguments(f) and len(f.func.returns) != 0)
    if not needs_definition:
        return None
    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=gen_formals(f),
        type_definition_body=emit_inplace_or_view_body(fn),
    )
# Example 4
def method_definition(f: NativeFunction) -> str:
    """Generate the C++ tracing-kernel definition for ``f``."""
    assert cpp.name(f.func) not in MANUAL_TRACER

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    arg_decls = ["c10::DispatchKeySet ks"]
    for a in f.func.schema_order_arguments():
        arg_type = cpp.argument_type(a, binds="__placeholder__").cpp_type()
        arg_decls.append(f'{arg_type} {a.name}')
    formals = ", ".join(arg_decls)

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=formals,
        type_definition_body=emit_trace_body(f),
    )
# Example 5
def is_supported(g: NativeFunctionsGroup) -> bool:
    """Return True when codegen supports the structured op group ``g``."""
    # Only structured, machine-generated ops are eligible.
    if not g.structured:
        return False
    if config.is_hand_written(g):
        return False
    func = g.out.func
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        return False
    if len(func.arguments.out) > 1:
        # More than 1 output values.
        return False
    if cpp.returns_type(func.returns).cpp_type() != "at::Tensor &":
        # Returns a non-Tensor value.
        return False
    # Every argument must have a known IValue conversion.
    return all(
        ivalue_type_conversion_method(arg.type)
        for arg in func.schema_order_arguments())
# Example 6
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Delegate return-type computation to the cpp backend unchanged."""
    result = cpp.returns_type(rs, symint=symint)
    return result
# Example 7
def returns_type(rs: Sequence[Return]) -> CType:
    """Compute the C++ return type for ``rs``.

    At present this is identical to the cpp backend's answer — but it
    could diverge, so callers go through this wrapper.
    """
    return cpp.returns_type(rs)
# Example 8
def returns_type(rs: Sequence[Return]) -> CType:
    """Forward to the cpp backend's return-type computation."""
    result = cpp.returns_type(rs)
    return result