def _inject_new_class(module: Module) -> None: r"""Sets up a module to be parametrized. This works by substituting the class of the module by a class that extends it to be able to inject a property Args: module (nn.Module): module into which to inject the property """ cls = module.__class__ def getstate(self): raise RuntimeError( "Serialization of parametrized modules is only " "supported through state_dict(). See:\n" "https://pytorch.org/tutorials/beginner/saving_loading_models.html" "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" ) param_cls = type( f"Parametrized{cls.__name__}", (cls,), { "__getstate__": getstate, }, ) module.__class__ = param_cls
def _inject_new_class(module: Module) -> None: r"""Sets up a module to be parametrized. This works by substituting the class of the module by a class that extends it to be able to inject a property Args: module (nn.Module): module into which to inject the property """ cls = module.__class__ def default_deepcopy(self, memo): # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. obj = memo.get(id(self), None) if obj is not None: return obj replica = self.__new__(self.__class__) memo[id(self)] = replica replica.__dict__ = deepcopy(self.__dict__, memo) # Also save all slots if they exist. slots_to_save = copyreg._slotnames( self.__class__) # type: ignore[attr-defined] for slot in slots_to_save: if hasattr(self, slot): setattr(replica, slot, deepcopy(getattr(self, slot), memo)) return replica def getstate(self): raise RuntimeError( "Serialization of parametrized modules is only " "supported through state_dict(). See:\n" "https://pytorch.org/tutorials/beginner/saving_loading_models.html" "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" ) dct = {"__getstate__": getstate} # We don't allow serialization of parametrized modules but should still allow deepcopying. # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. if not hasattr(cls, "__deepcopy__"): dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment] param_cls = type( f"Parametrized{cls.__name__}", (cls, ), dct, ) module.__class__ = param_cls
def remove_parametrizations(
    module: Module, tensor_name: str, leave_parametrized: bool = True
) -> Module:
    r"""Removes the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
      This is only possible when the parametrization depends on just one tensor.

    Args:
        module (nn.Module): module from which to remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(
            f"Module {module} does not have a parametrization on {tensor_name}"
        )

    # Fetch the original tensor
    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
    parametrizations = module.parametrizations[tensor_name]
    if parametrizations.is_tensor:
        original = parametrizations.original
        if leave_parametrized:
            with torch.no_grad():
                t = getattr(module, tensor_name)
            # We know they have the same dtype because we have checked this when registering the
            # parametrizations. As such, we can use set_
            # We do this so that the parameter does not change its id()
            # This way the user does not need to update the optimizer
            with torch.no_grad():
                original.set_(t)
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
            # original tensors required grad
            t = getattr(module, tensor_name)
            # We'll have to trust the user to add it to the optimizer
            original = Parameter(t) if t.requires_grad else t
        else:
            raise ValueError(
                "Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                "that is parametrized in terms of a sequence of tensors."
            )

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]

    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(module, tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module
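
# A minimal, hypothetical sketch (not part of the library source) exercising
# both removal modes documented above; the ``Symmetric`` parametrization and
# the ``_demo_remove_parametrizations`` helper are illustrative assumptions.
def _demo_remove_parametrizations() -> None:
    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Symmetric(nn.Module):
        def forward(self, X):
            return X.triu() + X.triu(1).transpose(-1, -2)

    layer = nn.Linear(3, 3)
    parametrize.register_parametrization(layer, "weight", Symmetric())
    original = layer.parametrizations.weight.original

    # leave_parametrized=True: the weight keeps its current (symmetric) value,
    # written in place with set_() so the Parameter object, and hence its id(),
    # is preserved and any optimizer holding it remains valid.
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
    assert layer.weight is original
    assert torch.equal(layer.weight, layer.weight.t())

    # leave_parametrized=False: the raw, unconstrained original tensor is
    # restored instead of the parametrization's output.
    parametrize.register_parametrization(layer, "weight", Symmetric())
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
    assert not parametrize.is_parametrized(layer)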