def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Map a schema argument type to the named C++ type used to pass it.

    Args:
        t: Schema type of the argument.
        mutable: Whether the argument is mutable. Only consulted for
            ``Tensor`` and ``Tensor?``; otherwise just forwarded on recursion.
        binds: Name attached to the resulting ``NamedCType``.

    Returns:
        A ``NamedCType`` pairing ``binds`` with the selected C++ type.

    Raises:
        AssertionError: If ``t`` is a base type other than Tensor/Scalar that
            ``valuetype_type`` did not translate, or if ``t`` is not a base,
            optional, or list type.
    """
    # If it's a value type, do the value type translation
    value_ctype = valuetype_type(t, binds=binds)
    if value_ctype is not None:
        return value_ctype

    def wants_mut_ref() -> bool:
        # Deferred on purpose: the `local` setting is only consulted on the
        # tensor paths, and only when `mutable` is set.
        return mutable and not local.use_const_ref_for_mutable_tensors()

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if wants_mut_ref():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        optional_elem = str(t.elem)
        if optional_elem == 'Tensor':
            if wants_mut_ref():
                # TODO: fix this discrepancy
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
        if optional_elem == 'Scalar':
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        # Generic optional: translate the element and wrap it.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        list_elem = str(t.elem)
        if list_elem == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        if list_elem == 'Tensor':
            return NamedCType(binds, BaseCType(tensorListT))
        if list_elem == 'Scalar':
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        if list_elem == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        if list_elem == 'Tensor?':
            return NamedCType(
                binds,
                ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
        # Generic list: translate the element and wrap it in an ArrayRef.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Map a schema argument type to its named C++ type, rejecting some forms.

    Optional tensors, optional scalars, and lists of tensors are rejected
    with an ``AssertionError``.

    Args:
        t: Schema type of the argument.
        mutable: Not inspected here; forwarded unchanged on recursion.
        binds: Name attached to the resulting ``NamedCType``.

    Returns:
        A ``NamedCType`` pairing ``binds`` with the selected C++ type.

    Raises:
        AssertionError: For the unsupported forms listed above, for a base
            type other than Tensor/Scalar that ``cpp.valuetype_type`` did not
            translate, or for a type that is not base, optional, or list.
    """
    # If it's a value type, do the value type translation
    value_ctype = cpp.valuetype_type(t, binds=binds)
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "optional tensor not supported by structured yet; to implement this "
                "add OptionalTensor c.f. https://github.com/pytorch/pytorch/issues/51456"
            )
        if t.elem == BaseType(BaseTy.Scalar):
            raise AssertionError(
                "optional scalar not supported by structured yet"
            )
        # Generic optional: translate the element and wrap it.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            raise AssertionError(
                "list of tensor not supported by structured yet; to implement this "
                "resolve torch::List issue, see "
                "https://fb.workplace.com/groups/894363187646754/permalink/1149276442155426"
            )
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        list_elem = str(t.elem)
        if list_elem == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        if list_elem == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        # Generic list: translate the element and wrap it in an ArrayRef.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Map a schema argument type to the named C++ type used to pass it.

    Optional Tensor, optional Scalar, optional ``int`` list, and Tensor list
    each map to a dedicated base C++ type; other optionals and lists are
    translated element-wise and wrapped.

    Args:
        t: Schema type of the argument.
        mutable: Not inspected here; forwarded unchanged on recursion.
        binds: Name attached to the resulting ``NamedCType``.

    Returns:
        A ``NamedCType`` pairing ``binds`` with the selected C++ type.

    Raises:
        AssertionError: For a base type other than Tensor/Scalar that
            ``cpp.valuetype_type`` did not translate, or for a type that is
            not base, optional, or list.
    """
    # If it's a value type, do the value type translation
    value_ctype = cpp.valuetype_type(t, binds=binds)
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        if isinstance(t.elem, ListType) and str(t.elem.elem) == 'int':
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Generic optional: translate the element and wrap it.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(iTensorListRefT))
        # TODO: delete these special cases; see tools.codegen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        list_elem = str(t.elem)
        if list_elem == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        if list_elem == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        # Generic list: translate the element and wrap it in an ArrayRef.
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
def save_var(var: SavedAttribute, is_output: bool) -> None:
    """Append the generated code fragments for one saved attribute.

    Depending on the C++ type of ``var``, this appends a member declaration
    to ``saved_variables`` and, where applicable, release statements, unpack
    statements, asserts, Python getter definitions, and getset structs to
    their respective lists. Those lists and ``info`` are not defined in this
    function; they come from the surrounding scope.

    Args:
        var: Saved attribute; its ``nctype`` supplies the name and C++ type.
        is_output: Whether the attribute is an output. It selects the
            argument passed to ``unpack`` and causes a Scalar to be stored
            as a ``SavedVariable``.
    """
    name = var.nctype.name
    ctype = var.nctype.type
    emit_getsetdef = True
    emit_raw_getsetdef = False

    def save_tensor_vector(unpack_fn: str) -> None:
        # Shared emission for both vector-of-SavedVariable cases; only the
        # name of the C++ unpack function differs between them.
        saved_variables.append(f'std::vector<SavedVariable> {name}_;')
        saved_variables.append(f'bool {name}_released_ = false;')
        # Just clear() is sufficient, we don't need to loop and clear each variable.
        # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
        release_variables.append(f'{name}_.clear();')
        release_variables.append(f'{name}_released_ = true;')
        unpack.append(f'auto {name} = {unpack_fn}({name}_);')
        asserts.append(
            f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
        for template, body in (
            (GETTER_DEFINITION_VEC_SAVEDVAR, GETTER_BODY_VEC_SAVEDVAR),
            (GETTER_DEFINITION_RAW_VEC_SAVEDVAR, GETTER_BODY_RAW_VEC_SAVEDVAR),
        ):
            getter_definitions.append(
                template.substitute(op=info.op, name=name, body=body))

    stored_as_saved_variable = (
        ctype == BaseCType(tensorT)
        or ctype == OptionalCType(BaseCType(tensorT))
        or ctype == MutRefCType(OptionalCType(BaseCType(tensorT)))
        or (ctype == BaseCType(scalarT) and is_output)
    )

    if stored_as_saved_variable:
        saved_variables.append(f'SavedVariable {name}_;')
        release_variables.append(f'{name}_.reset_data();')
        if is_output:
            ptr = 'shared_from_this()'
        else:
            ptr = ''
        unpack.append(f'auto {name} = {name}_.unpack({ptr});')
        for template, body in (
            (GETTER_DEFINITION_SAVEDVAR, GETTER_BODY_SAVEDVAR),
            (GETTER_DEFINITION_RAW_SAVEDVAR, GETTER_BODY_RAW_SAVEDVAR),
        ):
            getter_definitions.append(
                template.substitute(op=info.op, name=name, body=body))
        emit_raw_getsetdef = True
    elif ctype == BaseCType(tensorListT):
        save_tensor_vector('unpack_list')
        emit_raw_getsetdef = True
    elif ctype == ListCType(OptionalCType(BaseCType(tensorT))):
        save_tensor_vector('unpack_opt_list')
        emit_raw_getsetdef = True
    elif ctype == BaseCType(intArrayRefT):
        saved_variables.append(f'std::vector<int64_t> {name};')
        getter_definitions.append(GETTER_DEFINITION.substitute(
            op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
    elif ctype == OptionalCType(BaseCType(intArrayRefT)):
        saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
        getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
            op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
    elif ctype == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
        saved_variables.append(f'c10::OptionalArray<double> {name};')
        getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
            op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
    elif ctype == BaseCType(intT):
        saved_variables.append(f'{ctype.cpp_type()} {name} = 0;')
        getter_definitions.append(GETTER_DEFINITION.substitute(
            op=info.op, name=name, body=GETTER_BODY_INT64_T))
    elif ctype == BaseCType(stringT):
        saved_variables.append(f'std::string {name};')
        getter_definitions.append(GETTER_DEFINITION.substitute(
            op=info.op, name=name, body=GETTER_BODY_STRING))
    elif ctype == OptionalCType(BaseCType(stringT)):
        saved_variables.append(f'c10::optional<std::string> {name};')
        getter_definitions.append(GETTER_DEFINITION_OPT.substitute(
            op=info.op, name=name, body=GETTER_BODY_STRING))
    else:
        # Fallback: declare the member with the type's own C++ spelling.
        saved_variables.append(f'{ctype.cpp_type()} {name};')
        if ctype in MISC_GETTER_DEFS:
            misc_template, misc_body = MISC_GETTER_DEFS[ctype]
            getter_definitions.append(misc_template.substitute(
                op=info.op, name=name, body=misc_body))
        else:
            # Types we don't expose python bindings to yet:
            # TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
            # std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
            emit_getsetdef = False

    if emit_getsetdef:
        py_getsetdef_structs.append(
            PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
    if emit_raw_getsetdef:
        py_getsetdef_structs.append(
            PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name))