def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """Pick the C++ type used for ``t`` inside a ufunc.

    Args:
        t: Schema type to translate.
        binds: Argument name the resulting ``NamedCType`` is bound to.
        compute_t: C type that both ``Scalar`` and ``Tensor`` arguments are
            mapped to.

    Returns:
        Whatever ``cpp.valuetype_type`` produces when it is not ``None``;
        otherwise ``NamedCType(binds, compute_t)`` for ``Scalar`` and
        ``Tensor`` base types.

    Raises:
        AssertionError: If ``t`` is none of the above.
    """
    value_type = cpp.valuetype_type(t, binds=binds)
    if value_type is not None:
        return value_type

    # Scalars and tensors are treated identically here: both become compute_t.
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)

    raise AssertionError(f"unrecognized type {repr(t)}")
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
    """Pick the C++ type used for ``t`` in a dispatch stub signature.

    Args:
        t: Schema type to translate.
        binds: Argument name the resulting ``NamedCType`` is bound to.

    Returns:
        The result of ``cpp.valuetype_type`` when it is not ``None``; a
        const-reference ``scalarT`` binding for ``Scalar``; ``None`` for
        ``Tensor``.

    Raises:
        AssertionError: If ``t`` is none of the above.
    """
    value_type = cpp.valuetype_type(t, binds=binds)
    if value_type is not None:
        return value_type

    if t == BaseType(BaseTy.Scalar):
        scalar_cref = ConstRefCType(BaseCType(scalarT))
        return NamedCType(binds, scalar_cref)

    # Tensor arguments yield no named C type from this function.
    if t == BaseType(BaseTy.Tensor):
        return None

    raise AssertionError(f"unrecognized type {repr(t)}")
def returntype_type(t: Type) -> str:
    """Pick the return type spelling for schema type ``t``.

    Args:
        t: Schema return type to translate.

    Returns:
        The result of ``cpp.valuetype_type(t)`` when it is not ``None``;
        otherwise ``'TensorMeta'`` for a ``BaseType`` named ``Tensor``.

    Raises:
        NotImplementedError: If ``t`` is a ``ListType`` (and not a
            ``BaseType``).
        AssertionError: For every other type, including a ``BaseType`` that
            is not ``Tensor``.
    """
    # NOTE(review): sibling translators in this file pass binds=... to
    # cpp.valuetype_type; confirm that calling it without binds is accepted.
    # NOTE(review): the value is returned as-is although the annotation says
    # str -- confirm what cpp.valuetype_type actually yields.
    value_type = cpp.valuetype_type(t)
    if value_type is not None:
        return value_type

    is_base = isinstance(t, BaseType)
    if is_base and t.name == BaseTy.Tensor:
        return 'TensorMeta'
    if not is_base and isinstance(t, ListType):
        raise NotImplementedError("list returns not supported yet")

    raise AssertionError(f"unrecognized return type {t}")
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    """Pick the C++ type used for ``t`` as a ufunctor constructor argument.

    Args:
        t: Schema type to translate.
        binds: Argument name the resulting ``NamedCType`` is bound to.
        scalar_t: C++ scalar type passed through ``opmath_type`` to obtain
            the type used for ``Scalar`` and ``Tensor`` arguments.

    Returns:
        The result of ``cpp.valuetype_type`` when it is not ``None``;
        otherwise ``NamedCType(binds, BaseCType(opmath_type(scalar_t)))`` for
        ``Scalar`` and ``Tensor`` base types.

    Raises:
        AssertionError: If ``t`` is none of the above.
    """
    value_type = cpp.valuetype_type(t, binds=binds)
    if value_type is not None:
        return value_type

    # Scalars and tensors share the same opmath-based representation.
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        opmath_ctype = BaseCType(opmath_type(scalar_t))
        return NamedCType(binds, opmath_ctype)

    raise AssertionError(f"unrecognized type {repr(t)}")
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate an argument's schema type into a ``NamedCType``.

    NOTE(review): a later definition with the same name appears in this file;
    if both live in one module the later one shadows this one -- confirm.

    Args:
        t: Schema type to translate.
        mutable: Only forwarded to the recursive call for element types.
        binds: Argument name the resulting ``NamedCType`` is bound to.

    Returns:
        The result of ``cpp.valuetype_type`` when it is not ``None``.
        Otherwise: a const reference to ``tensorT`` / ``scalarT`` for
        ``Tensor`` / ``Scalar`` base types; an ``OptionalCType`` around the
        translated element for optionals; ``intArrayRefT`` / ``dimnameListT``
        for lists of ``int`` / ``Dimname``; an ``ArrayRefCType`` around the
        translated element for any other list.

    Raises:
        AssertionError: For any other base type, for optional ``Tensor`` or
            optional ``Scalar``, for a list of ``Tensor``, and for any type
            that is not a base, optional, or list type.
    """
    # Value types are handled by the shared value-type translation.
    value_type = cpp.valuetype_type(t, binds=binds)
    if value_type is not None:
        return value_type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        inner = t.elem
        if inner == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
        if inner == BaseType(BaseTy.Scalar):
            raise AssertionError(
                "optional scalar not supported by structured yet"
            )
        wrapped = argumenttype_type(inner, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(wrapped.type))

    if isinstance(t, ListType):
        item = t.elem
        if item == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        item_name = str(item)
        if item_name == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        if item_name == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        wrapped = argumenttype_type(item, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(wrapped.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate an argument's schema type into a ``NamedCType``.

    Args:
        t: Schema type to translate.
        mutable: Only forwarded to the recursive call for element types.
        binds: Argument name the resulting ``NamedCType`` is bound to.

    Returns:
        The result of ``cpp.valuetype_type`` when it is not ``None``.
        Otherwise: a const reference to ``tensorT`` / ``scalarT`` for
        ``Tensor`` / ``Scalar`` base types; ``optionalTensorRefT``,
        ``optionalScalarRefT`` or ``optionalIntArrayRefT`` for optional
        ``Tensor``, ``Scalar`` or list-of-``int``; an ``OptionalCType`` around
        the translated element for any other optional; ``iTensorListRefT``,
        ``intArrayRefT`` or ``dimnameListT`` for lists of ``Tensor``, ``int``
        or ``Dimname``; an ``ArrayRefCType`` around the translated element
        for any other list.

    Raises:
        AssertionError: For any other base type, and for any type that is
            not a base, optional, or list type.
    """
    # Value types are handled by the shared value-type translation.
    value_type = cpp.valuetype_type(t, binds=binds)
    if value_type is not None:
        return value_type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        inner = t.elem
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if inner == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        if isinstance(inner, ListType) and str(inner.elem) == 'int':
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        wrapped = argumenttype_type(inner, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(wrapped.type))

    if isinstance(t, ListType):
        item = t.elem
        if item == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(iTensorListRefT))
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        item_name = str(item)
        if item_name == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        if item_name == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        wrapped = argumenttype_type(item, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(wrapped.type))

    raise AssertionError(f"unrecognized type {repr(t)}")