Beispiel #1
0
def argumenttype_ivalue_convert(
        t: Type,
        arg_name: str,
        *,
        mutable: bool = False) -> Tuple[str, CType, List[str], List[str]]:
    """Generate code converting the value named ``arg_name`` to type ``t``.

    Args:
        t: schema type of the argument; must be a BaseType, OptionalType or
            ListType.
        arg_name: name of the variable holding the incoming value; also the
            prefix of the generated output variable name.
        mutable: forwarded to ``cpp.argumenttype_type`` when computing the
            C++ type.

    Returns:
        ``(out_name, ctype, code, decl)``: the generated output variable name,
        the C++ type from ``cpp.argumenttype_type(...).type``, and the
        ``code`` / ``decl`` pair returned by the matching ``_gen_code_*``
        helper.

    Raises:
        Exception: if ``t`` is none of the three supported type kinds.
    """
    ctype = cpp.argumenttype_type(t=t, mutable=mutable, binds=arg_name).type

    if isinstance(t, BaseType):
        base_out = f"{arg_name}_base"
        base_code, base_decl = _gen_code_base_type(
            arg_name=arg_name, out_name=base_out, ctype=ctype)
        return base_out, ctype, base_code, base_decl

    if isinstance(t, OptionalType):
        opt_out = f"{arg_name}_opt_out"
        opt_code, opt_decl = _gen_code_optional_type(
            arg_name=arg_name, out_name=opt_out, t=t, ctype=ctype)
        return opt_out, ctype, opt_code, opt_decl

    if isinstance(t, ListType):
        list_out = f"{arg_name}_list_out"
        list_code, list_decl = _gen_code_list_type(
            arg_name=arg_name, out_name=list_out, t=t, ctype=ctype)
        return list_out, ctype, list_code, list_decl

    raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
Beispiel #2
0
def dynamic_type(t: Type) -> str:
    """Return the type string used to describe ``t`` dynamically.

    All OptionalType wrappers are stripped first, so ``Tensor?`` is reported
    the same way as ``Tensor``. A plain ``Tensor`` yields ``'Tensor'``; any
    other type is delegated to ``cpp.argumenttype_type`` with
    ``mutable=False``.
    """
    # Strip nested OptionalType layers one at a time.
    inner = t
    while isinstance(inner, OptionalType):
        inner = inner.elem
    # Compare the spelled type instead of using is_tensor_like(), because
    # that predicate would also accept Tensor[].
    if str(inner) != 'Tensor':
        return cpp.argumenttype_type(inner, mutable=False)
    return 'Tensor'
Beispiel #3
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Map a schema argument type to its C++ type, bound to ``binds``.

    ``Tensor?`` becomes a mutable or const reference to a plain ``Tensor``
    (no optional wrapper). ``Tensor?[]`` becomes the literal type string
    ``const c10::List<c10::optional<Tensor>> &``. Every other type is
    delegated to ``cpp.argumenttype_type``.
    """
    spelled = str(t)
    if spelled == 'Tensor?':
        tensor = BaseCType('Tensor', binds)
        return MutRefCType(tensor) if mutable else ConstRefCType(tensor)
    if spelled == 'Tensor?[]':
        return BaseCType('const c10::List<c10::optional<Tensor>> &', binds)
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
Beispiel #4
0
def argumenttype_type(t: Type, *, mutable: bool) -> str:
    """Return the C++ type string for an argument of schema type ``t``.

    ``Tensor?`` is spelled ``'Tensor &'`` when mutable and
    ``'const Tensor &'`` otherwise, and ``Tensor?[]`` is spelled
    ``'TensorList'``. Every other type is delegated to
    ``cpp.argumenttype_type``.
    """
    spelled = str(t)
    if spelled == 'Tensor?':
        return 'Tensor &' if mutable else 'const Tensor &'
    if spelled == 'Tensor?[]':
        return 'TensorList'
    return cpp.argumenttype_type(t, mutable=mutable)
Beispiel #5
0
def argumenttype_type(t: Type, *, mutable: bool) -> str:
    """Return the dispatcher C++ type string for schema type ``t``.

    When ``local.use_c10_dispatcher()`` is ``UseC10Dispatcher.full`` the C++
    API mapping (``cpp.argumenttype_type``) is used. Otherwise the legacy
    dispatcher mapping is used.
    """
    uses_full_c10 = local.use_c10_dispatcher() is UseC10Dispatcher.full
    if not uses_full_c10:
        # Deliberate sharing with the legacy dispatcher. If you are changing
        # this path, ask whether the change belongs in legacy_dispatcher
        # rather than here.
        return legacy_dispatcher.argumenttype_type(t, mutable=mutable)
    # Only a coincidental match with the C++ API ("faux amis"). If more
    # special cases are ever needed here, or it makes sense to invert the
    # dependency so cpp.argument_type calls this, or to inline this function
    # entirely, please do it.
    return cpp.argumenttype_type(t, mutable=mutable)
Beispiel #6
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Return the dispatcher C++ type for schema type ``t``, bound to ``binds``.

    If ``local.use_c10_dispatcher().dispatcher_uses_new_style()`` is true, the
    C++ API mapping (``cpp.argumenttype_type``) is used. Otherwise the native
    functions mapping (``native.argumenttype_type``) is used.
    """
    new_style = local.use_c10_dispatcher().dispatcher_uses_new_style()
    if not new_style:
        # Deliberate sharing with the native functions protocol. If you are
        # changing this path, ask whether the change belongs in native rather
        # than here.
        return native.argumenttype_type(t, mutable=mutable, binds=binds)
    # Only a coincidental match with the C++ API ("faux amis"). If more
    # special cases are ever needed here, or it makes sense to invert the
    # dependency so cpp.argument_type calls this, or to inline this function
    # entirely, please do it.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
Beispiel #7
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Map a schema argument type to its C++ type, bound to ``binds``.

    Special cases:
      * ``Tensor?``   -> mutable/const reference to an optional ``Tensor``
      * ``Tensor?[]`` -> const reference to ``c10::List<c10::optional<Tensor>>``
      * ``Scalar``    -> const reference to ``Scalar``
      * ``Scalar?``   -> const reference to an optional ``Scalar``
    Every other type is delegated to ``cpp.argumenttype_type``.
    """
    spelled = str(t)
    if spelled == 'Tensor?':
        optional_tensor: OptionalCType = OptionalCType(BaseCType('Tensor', binds))
        if mutable:
            return MutRefCType(optional_tensor)
        return ConstRefCType(optional_tensor)
    if spelled == 'Tensor?[]':
        return ConstRefCType(BaseCType("c10::List<c10::optional<Tensor>>", binds))
    if spelled in ('Scalar', 'Scalar?'):
        scalar: CType = BaseCType('Scalar', binds)
        if spelled == 'Scalar?':
            scalar = OptionalCType(scalar)
        # Both Scalar forms are always taken by const reference.
        return ConstRefCType(scalar)
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
Beispiel #8
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Map a schema argument type to its C++ type, bound to ``binds``.

    ``Tensor?`` becomes a mutable/const reference to ``Tensor``. The Tensor is
    wrapped in an optional unless the current dispatcher mode is
    ``UseC10Dispatcher.hacky_wrapper_for_legacy_signatures``, in which case a
    plain ``Tensor`` is used. ``Tensor?[]`` becomes a const reference to
    ``c10::List<c10::optional<Tensor>>``. Everything else is delegated to
    ``cpp.argumenttype_type``.
    """
    spelled = str(t)
    if spelled == 'Tensor?[]':
        return ConstRefCType(
            BaseCType("c10::List<c10::optional<Tensor>>", binds))
    if spelled != 'Tensor?':
        return cpp.argumenttype_type(t, mutable=mutable, binds=binds)

    inner: CType = BaseCType('Tensor', binds)
    legacy_mode = (local.use_c10_dispatcher()
                   is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures)
    if not legacy_mode:
        inner = OptionalCType(inner)
    reference = MutRefCType if mutable else ConstRefCType
    return reference(inner)
Beispiel #9
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Map a schema argument type to a C++ type named ``binds``.

    Special cases:
      * ``Tensor?``   -> reference to an optional tensor. It is a mutable
        reference only when ``mutable`` is true and
        ``local.use_const_ref_for_mutable_tensors()`` is false; otherwise it
        is a const reference.
      * ``Tensor?[]`` -> const reference to a list of optional tensors
      * ``Scalar``    -> const reference to a scalar
      * ``Scalar?``   -> const reference to an optional scalar
    Every other type is delegated to ``cpp.argumenttype_type``.
    """
    spelled = str(t)
    if spelled == 'Tensor?':
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        wants_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        if wants_mut_ref:
            ctype = MutRefCType(optional_tensor)
        else:
            ctype = ConstRefCType(optional_tensor)
    elif spelled == 'Tensor?[]':
        ctype = ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
    elif spelled == 'Scalar':
        ctype = ConstRefCType(BaseCType(scalarT))
    elif spelled == 'Scalar?':
        ctype = ConstRefCType(OptionalCType(BaseCType(scalarT)))
    else:
        return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
    return NamedCType(binds, ctype)
Beispiel #10
0
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
    """Return the C++ type for schema type ``t``, bound to ``binds``.

    This adds no special cases of its own and forwards straight to
    ``cpp.argumenttype_type``.
    """
    # The match with the C++ API is coincidental ("faux amis") rather than a
    # deliberate sharing of protocol. If special cases are ever needed here,
    # or it makes sense to invert the dependency so cpp.argument_type calls
    # this, or to inline this function entirely, please do it.
    return cpp.argumenttype_type(t, binds=binds, mutable=mutable)