def enforce_same_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
    save_ptrs_stmts: List[str] = []
    enforce_same_ptrs_stmts: List[str] = []
    if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                save_ptrs_stmts += [SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                            ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                            ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
            elif noref_cpp_type == BaseCType(tensorT):
                save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
                enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
                                            ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)]
    assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (not save_ptrs_stmts and not enforce_same_ptrs_stmts)
    if save_ptrs_stmts and enforce_same_ptrs_stmts:
        call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
            call + \
            RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
    return call

def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
      (1) changing at::Tensors into lazy::Values
      (2) wrapping everything in a BaseCType
      (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, the representation of
    tensors in Lazy IR). There is special handling for Optional[Tensor], List[Tensor], etc.,
    hence 'tensor-like'.

    This is incomplete: there are assertions in places where it is expected that more types
    will need to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(valueT)
        elif typ.name == BaseTy.Scalar:
            # at::scalar has special handling,
            # and is wrapped in a lazy::Value just like at::tensor
            return BaseCType(valueT)
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
    elif isinstance(typ, ListType):
        if str(typ.elem) == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(valueT)))
        elif str(typ.elem) == 'Tensor':
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")

def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Its output is currently used in several places, and so far all of them have been able
    to share the same conversions, but that may not be optimal or even possible in the
    finished system.

    Type conversion for lazy currently consists of
      (1) changing Tensor-like things into Value-like things
      (2) wrapping everything in a BaseCType
      (3) making reference types into value types (e.g. vector instead of IntArrayRef)

    (1) converts Tensors to Values, since Values are how Lazy IR represents tensors. There
    is special handling for Optional[Tensor], List[Tensor], etc., hence 'tensor-like'.

    This is incomplete: there are assertions in places where it is expected that more types
    will need to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(valueT)
        elif typ.name == BaseTy.Scalar:
            # at::scalar has special handling,
            # and is wrapped in an IR value just like at::tensor
            return BaseCType(valueT)
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
    elif isinstance(typ, ListType):
        if str(typ.elem) == 'Tensor?':
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(valueT)))
        else:
            return VectorCType(process_ir_type(typ.elem))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")

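
# Illustrative sketch (standalone, hypothetical names, not part of the codegen): the same
# recursive translation pattern process_ir_type follows, stripped down so it can run on its
# own. Base/Opt/Lst and to_lazy are placeholders for the real torchgen type classes; only the
# dispatch-and-recurse structure is meant to match the function above.
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Base:       # stand-in for BaseType
    name: str


@dataclass(frozen=True)
class Opt:        # stand-in for OptionalType
    elem: "Ty"


@dataclass(frozen=True)
class Lst:        # stand-in for ListType
    elem: "Ty"


Ty = Union[Base, Opt, Lst]


def to_lazy(typ: Ty) -> str:
    # Tensors become IR Values; optionals and lists recurse into their element type.
    if isinstance(typ, Base):
        return {"Tensor": "Value", "int": "int64_t", "bool": "bool"}.get(typ.name, typ.name)
    if isinstance(typ, Opt):
        return f"c10::optional<{to_lazy(typ.elem)}>"
    if isinstance(typ, Lst):
        return f"std::vector<{to_lazy(typ.elem)}>"
    raise AssertionError(f"unrecognized type {typ!r}")


assert to_lazy(Opt(Base("Tensor"))) == "c10::optional<Value>"
assert to_lazy(Lst(Base("int"))) == "std::vector<int64_t>"
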
def emit_fw_derivatives() -> List[str]:
    content: List[str] = []
    fw_grad_setters: List[str] = []
    for derivative in fw_derivatives:
        res = derivative.var_name
        if f.func.name.name.inplace:
            # TODO update this when inplace namings are unified
            res = "self"

        assert derivative.required_inputs_fw_grad is not None

        unpacked_arguments = ""
        for inp in differentiable_inputs:
            if inp.name in derivative.required_inputs_fw_grad:
                unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp=inp.name)
            if inp.name in (derivative.required_inputs_primal or []):
                unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp=inp.name)
        if derivative.required_original_self_value:
            unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(inp="original_self")
            unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(inp="original_self")
        elif inplace and derivative.is_reusing_outplace_formula:
            # The gradient wasn't already cloned, do it if grad mode is enabled
            unpacked_arguments += "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"

        if inplace:
            is_inplace_str = "true"
        else:
            is_inplace_str = "false"

        if isinstance(derivative.var_type, BaseType) and derivative.var_type.is_tensor_like():
            # Is there a way to get from BaseType to BaseCType
            opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
            fw_grad_setter = FW_DERIVATIVE_SETTER_TENSOR.substitute(out_arg=res, is_inplace=is_inplace_str)
        elif isinstance(derivative.var_type, ListType) and derivative.var_type.is_tensor_like():
            opt_res_grad_type = OptionalCType(ListCType(BaseCType(tensorT))).cpp_type()
            fw_grad_setter = FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(out_arg=res, is_inplace=is_inplace_str)
        else:
            raise RuntimeError("Unsupported output type for forward derivative")

        fw_grad_opt_definition = f"{opt_res_grad_type} {res}_new_fw_grad_opt = c10::nullopt;"

        # View ops create fw_grad that already is a view of the base's fw_grad so just use that
        content.append(FW_DERIVATIVE_TEMPLATE.substitute(
            fw_grad_opt_definition=fw_grad_opt_definition,
            requires_fw_grad=get_any_has_forward_grad_name(derivative.var_name),
            formula=derivative.formula,
            out_arg=res,
            unpacked_arguments=unpacked_arguments))
        fw_grad_setters.append(fw_grad_setter)

    # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367
    content.append('\n'.join(fw_grad_setters))
    return content

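
# Illustrative sketch (standalone, hypothetical names): the ordering used above, where every
# forward-grad formula is emitted first and all of the setters are appended in a single block
# at the end, to avoid the interleaving problem tracked in
# https://github.com/pytorch/pytorch/issues/67367. `formulas` and `setters` stand in for the
# generated C++ snippets.
def emit_deferred(formulas: List[str], setters: List[str]) -> List[str]:
    content = list(formulas)             # compute each new forward grad first
    content.append('\n'.join(setters))   # then install all of them at the end
    return content
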
def save_variables(
    saved_variables: Sequence[SavedAttribute],
    is_output: bool,
    guard_for: Callable[[SavedAttribute], Optional[str]] = lambda name: None,
) -> Sequence[str]:
    # assign the saved variables to the generated grad_fn
    stmts: List[str] = []
    for arg in saved_variables:
        name = arg.nctype.name.name if isinstance(arg.nctype.name, SpecialArgName) else arg.nctype.name
        type = arg.nctype.type
        expr = arg.expr
        stmts_prepend = None
        if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
                type == MutRefCType(OptionalCType(BaseCType(tensorT))) or (is_output and type == BaseCType(scalarT)):
            var = name
            name += '_'
            if var == 'self' and inplace:
                stmts_prepend = 'if (!original_self.has_value()) original_self = self.clone()'
                var = 'original_self.value()'
                assert not is_output
            if inplace and is_output:
                var = 'self'
                is_inplace_view = f'{var}.is_view()'
                expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})'
            else:
                expr = f'SavedVariable({var}, {str(is_output).lower()})'
        elif type == BaseCType(tensorListT) or type == ListCType(OptionalCType(BaseCType(tensorT))):
            expr = f'make_saved_variable_list({name})'
            name += '_'
        elif type == BaseCType(intArrayRefT):
            expr = expr + ".vec()"
        elif type == BaseCType(stringT):
            expr = f'std::string({expr})'
        elif type == OptionalCType(BaseCType(stringT)):
            expr = f'{expr}.has_value() ? c10::optional<std::string>(std::string({expr}.value())) : c10::nullopt'
        guard = guard_for(arg)
        if guard is None:
            if stmts_prepend:
                stmts.append(f'{stmts_prepend};')
            stmts.append(f'grad_fn->{name} = {expr};')
        else:
            stmts.append(f'if ({guard}) {{')
            if stmts_prepend:
                stmts.append(f'  {stmts_prepend};')
            stmts.append(f'  grad_fn->{name} = {expr};')
            stmts.append('}')
    return stmts

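
# Illustrative sketch (reconstructed from the branches above, not copied from generated
# files): the kind of assignment statements this returns for a saved tensor input named
# "self" (is_output=False, not inplace) and a saved int[] argument named "sizes".
example_save_stmts = [
    'grad_fn->self_ = SavedVariable(self, false);',  # tensor branch
    'grad_fn->sizes = sizes.vec();',                 # IntArrayRef branch: own the data
]
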
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    if str(t) == 'Tensor?':
        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
        if mutable and not local.use_const_ref_for_mutable_tensors():
            return NamedCType(binds, MutRefCType(tensor_type))
        else:
            return NamedCType(binds, ConstRefCType(tensor_type))
    elif str(t) == 'Tensor?[]':
        return NamedCType(binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
    elif str(t) == 'Scalar':
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    elif str(t) == 'Scalar?':
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)

def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(t, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == 'Tensor':
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))  # TODO: fix this discrepancy
            else:
                return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
        elif str(t.elem) == 'Scalar':
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == 'int':
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == 'Tensor':
            return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == 'Scalar':
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == 'Dimname':
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == 'Tensor?':
            return NamedCType(binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")

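
# Illustrative sketch (summarizing the branches above; the C++ spellings on the right are
# assumed to be the conventional renderings of these wrapper types, shown only for
# readability and not produced by this function directly):
example_arg_type_map = {
    'Tensor':   'const at::Tensor &',                 # ConstRefCType(BaseCType(tensorT)), immutable case
    'Tensor?':  'const c10::optional<at::Tensor> &',  # ConstRefCType(OptionalCType(BaseCType(tensorT)))
    'int[]':    'at::IntArrayRef',                    # BaseCType(intArrayRefT)
    'Tensor[]': 'at::TensorList',                     # BaseCType(tensorListT)
}
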
def check_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
    # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
    stmts_before_call: List[str] = []
    stmts_after_call: List[str] = []

    if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        return call

    # Check properties of inputs (enforce (1))
    for unpacked_binding in unpacked_bindings:
        arg = unpacked_binding.name
        noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
        if noref_cpp_type == BaseCType(tensorListT):
            stmts_before_call += [
                SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
        elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
            stmts_before_call += [
                SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
            stmts_after_call += [
                ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
            ]
        elif noref_cpp_type == BaseCType(tensorT):
            stmts_before_call += [
                SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
            ]
            stmts_after_call += [
                ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg, out_tensor_name=arg),
                ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
            ]

    assert (stmts_before_call and stmts_after_call) or (not stmts_before_call and not stmts_after_call)

    # Check properties of outputs (enforce (2), (3))
    if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
        base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
        aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
        if aliased_arg_name is not None:
            aliased_arg_name = unpacked_name(aliased_arg_name)
        for i, (ret, ret_name) in enumerate(zip(f.func.returns, cpp.return_names(f))):
            noref_cpp_type = cpp.return_type(ret).remove_const_ref()
            if noref_cpp_type == BaseCType(tensorT):
                if aliased_arg_name is not None:
                    assert i == 0, f"Expect non-CompositeImplicitAutograd view function {base_name} to return single output"
                    stmts_after_call += [
                        ENFORCE_SAME_TENSOR_STORAGE.substitute(
                            tensor_name=aliased_arg_name, out_tensor_name=ret_name)
                    ]
                else:
                    if type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f))
                        ]
                if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                    stmts_after_call += [
                        ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
                            tensor_name=ret_name, fn_name=type_wrapper_name(f))
                    ]
            # Currently we don't have any functions that return the following types, but
            # we should update the checks once we do
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")
            elif noref_cpp_type == BaseCType(tensorListT):
                raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")

    if stmts_before_call and stmts_after_call:
        call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + \
            call + \
            RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
    return call

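
# Illustrative sketch (standalone, hypothetical names; assumes RUN_ONLY_IN_DEBUG_MODE expands
# its statements inside a debug-only guard): the bracketing pattern used above, where the
# pointer snapshots run before the call and the assertions run after it, and both are
# skipped entirely in release builds.
def wrap_call_with_debug_checks(call: str, before: List[str], after: List[str]) -> str:
    def debug_guard(statements: List[str]) -> str:
        return '#ifndef NDEBUG\n' + '\n'.join(statements) + '\n#endif\n'
    # Mirror the `if stmts_before_call and stmts_after_call` condition above:
    # only wrap when there is actually something to check.
    if before and after:
        return debug_guard(before) + call + debug_guard(after)
    return call
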
def save_var(var: SavedAttribute, is_output: bool) -> None:
    name = var.nctype.name
    type = var.nctype.type
    should_append_getsetdef = True
    should_append_raw_getsetdef = False

    if type == BaseCType(tensorT) or type == OptionalCType(BaseCType(tensorT)) or \
            type == MutRefCType(OptionalCType(BaseCType(tensorT))) or \
            (type == BaseCType(scalarT) and is_output):
        saved_variables.append(f'SavedVariable {name}_;')
        release_variables.append(f'{name}_.reset_data();')
        ptr = 'shared_from_this()' if is_output else ''
        unpack.append(f'auto {name} = {name}_.unpack({ptr});')
        getter_definitions.append(GETTER_DEFINITION_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_SAVEDVAR))
        getter_definitions.append(GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR))
        should_append_raw_getsetdef = True
    elif type == BaseCType(tensorListT):
        saved_variables.append(f'std::vector<SavedVariable> {name}_;')
        saved_variables.append(f'bool {name}_released_ = false;')
        # Just clear() is sufficient, we don't need to loop and clear each variable.
        # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
        release_variables.append(f'{name}_.clear();')
        release_variables.append(f'{name}_released_ = true;')
        unpack.append(f'auto {name} = unpack_list({name}_);')
        asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
        getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
        getter_definitions.append(GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
        should_append_raw_getsetdef = True
    elif type == ListCType(OptionalCType(BaseCType(tensorT))):
        saved_variables.append(f'std::vector<SavedVariable> {name}_;')
        saved_variables.append(f'bool {name}_released_ = false;')
        # Just clear() is sufficient, we don't need to loop and clear each variable.
        # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
        release_variables.append(f'{name}_.clear();')
        release_variables.append(f'{name}_released_ = true;')
        unpack.append(f'auto {name} = unpack_opt_list({name}_);')
        asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
        getter_definitions.append(GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR))
        getter_definitions.append(GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
            op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR))
        should_append_raw_getsetdef = True
    elif type == BaseCType(intArrayRefT):
        saved_variables.append(f'std::vector<int64_t> {name};')
        getter_definitions.append(GETTER_DEFINITION.substitute(
            op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
    elif type == OptionalCType(BaseCType(intArrayRefT)):
        saved_variables.append(f'c10::OptionalArray<int64_t> {name};')
        getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
            op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG))
    elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
        saved_variables.append(f'c10::OptionalArray<double> {name};')
        getter_definitions.append(GETTER_DEFINITION_OPT_ARRAYREF.substitute(
            op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE))
    elif type == BaseCType(intT):
        saved_variables.append(f'{type.cpp_type()} {name} = 0;')
        getter_definitions.append(GETTER_DEFINITION.substitute(
            op=info.op, name=name, body=GETTER_BODY_INT64_T))
    elif type == BaseCType(stringT):
        saved_variables.append(f'std::string {name};')
        getter_definitions.append(GETTER_DEFINITION.substitute(
            op=info.op, name=name, body=GETTER_BODY_STRING))
    elif type == OptionalCType(BaseCType(stringT)):
        saved_variables.append(f'c10::optional<std::string> {name};')
        getter_definitions.append(GETTER_DEFINITION_OPT.substitute(
            op=info.op, name=name, body=GETTER_BODY_STRING))
    else:
        saved_variables.append(f'{type.cpp_type()} {name};')
        if type in MISC_GETTER_DEFS:
            getter_def, body = MISC_GETTER_DEFS[type]
            getter_definitions.append(getter_def.substitute(op=info.op, name=name, body=body))
        else:
            # Types we don't expose python bindings to yet:
            #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
            #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
            should_append_getsetdef = False

    if should_append_getsetdef:
        py_getsetdef_structs.append(PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
    if should_append_raw_getsetdef:
        py_getsetdef_structs.append(PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name))
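
# Illustrative sketch (reconstructed from the SavedVariable branch above, for a saved input
# named "self"; these are the member/statement strings the branch appends, not output copied
# from an actual build):
example_generated = {
    'saved_variables':   'SavedVariable self_;',
    'release_variables': 'self_.reset_data();',
    'unpack':            'auto self = self_.unpack();',  # inputs pass no shared_from_this()
}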