def test_old_lightningmodule_path():
    """Instantiating ``LightningModule`` from its old import path must emit a deprecation warning."""
    from pytorch_lightning.core.lightning import LightningModule

    # ``match`` is interpreted as a regex by pytest; the text below is the expected warning message.
    expected_warning = (
        "pytorch_lightning.core.lightning.LightningModule has been deprecated in v1.7"
        " and will be removed in v1.9."
    )
    with pytest.deprecated_call(match=expected_warning):
        LightningModule()
def _is_module_name_pair(item) -> bool:
    """Return whether ``item`` is a 2-element sized container holding ``(nn.Module, str)``."""
    try:
        if len(item) != 2:
            return False
    except TypeError:
        # Unsized items (e.g. a bare ``nn.Module`` passed instead of a pair) are invalid,
        # not a crash: let the caller raise its descriptive MisconfigurationException.
        return False
    module, name = item
    return isinstance(module, nn.Module) and isinstance(name, str)


def sanitize_parameters_to_prune(
    pl_module: LightningModule, parameters_to_prune: _PARAM_LIST = (), parameter_names: Sequence[str] = ()
) -> _PARAM_LIST:
    """Validate ``parameters_to_prune``, or generate it from the model when it is empty/``None``.

    Args:
        pl_module: The model whose (non-container) submodules are inspected.
        parameters_to_prune: ``(module, parameter_name)`` pairs to validate. When empty or ``None``,
            the pairs are generated from every non-container submodule that has a non-``None``
            attribute named in ``parameter_names``.
        parameter_names: Attribute names to look for when generating pairs. Falls back to
            ``ModelPruning.PARAMETER_NAMES`` when empty.

    Returns:
        The generated list of pairs, or the provided ``parameters_to_prune`` unchanged when valid.

    Raises:
        MisconfigurationException:
            If ``parameters_to_prune`` is not a list/tuple of ``(nn.Module, str)`` pairs, or if some
            provided modules are not part of ``pl_module`` / lack the named attribute.
    """
    parameters = parameter_names or ModelPruning.PARAMETER_NAMES

    current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]

    if not parameters_to_prune:
        return [(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None]

    # An empty list/tuple was already handled above, so no separate length check is needed here.
    if not (
        isinstance(parameters_to_prune, (list, tuple))
        and all(_is_module_name_pair(p) for p in parameters_to_prune)
    ):
        raise MisconfigurationException(
            "The provided `parameters_to_prune` should either be list of tuple"
            " with 2 elements: (nn.Module, parameter_name_to_prune) or None"
        )

    missing_modules, missing_parameters = [], []
    for module, name in parameters_to_prune:
        if module not in current_modules:
            missing_modules.append(module)
            continue
        if not hasattr(module, name):
            missing_parameters.append(name)

    if missing_modules or missing_parameters:
        raise MisconfigurationException(
            "Some provided `parameters_to_prune` don't exist in the model."
            f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}"
        )

    return parameters_to_prune
def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]:
    """Return ``pl_module``'s state dict with pruning masks baked in and all tensors moved to CPU.

    For every pair of keys ``<name>_orig`` / ``<name>_mask`` (the naming used by
    ``torch.nn.utils.prune``), both entries are removed and replaced by a single
    ``<name>`` entry holding ``mask * orig`` (mask cast to the original dtype).
    Keys ending in ``_mask`` that have no ``_orig`` companion are left untouched.

    Args:
        pl_module: The module whose ``state_dict()`` is read. The module itself is not modified.

    Returns:
        The rewritten state dict, with every tensor moved to CPU.
    """
    mask_suffix, orig_suffix = "_mask", "_orig"
    state_dict = pl_module.state_dict()

    # Strip only the trailing suffix: ``str.replace`` would also remove "_mask" occurring
    # inside the module path (e.g. "attn_mask_proj.weight_mask"), yielding a wrong key.
    # Require the ``_orig`` companion so unrelated buffers named "*_mask" don't raise KeyError.
    pruned_names = {
        key[: -len(mask_suffix)]
        for key in state_dict
        if key.endswith(mask_suffix) and key[: -len(mask_suffix)] + orig_suffix in state_dict
    }

    for name in pruned_names:
        orig = state_dict.pop(name + orig_suffix)
        mask = state_dict.pop(name + mask_suffix)
        # make the pruned weights permanent
        state_dict[name] = mask.to(dtype=orig.dtype) * orig

    def move_to_cpu(tensor: torch.Tensor) -> torch.Tensor:
        # move each tensor in the (possibly nested) collection to CPU
        return tensor.cpu()

    return apply_to_collection(state_dict, torch.Tensor, move_to_cpu)
def backward(self, loss, optimizer, optimizer_idx):
    """Run ``LightningModule.backward`` on this instance with the same arguments and return its result."""
    base_backward = LightningModule.backward
    return base_backward(self, loss, optimizer, optimizer_idx)
def on_train_end(self, trainer: Trainer, model: LightningModule):
    """Assert that ``foo`` and ``foo_2`` equal the previous epoch index, then flag that this hook ran."""
    logged = self.get_metrics(trainer, model)
    # NOTE(review): compares against ``self.trainer`` rather than the ``trainer`` argument;
    # presumably they are the same object — confirm against the enclosing class.
    for metric_name in ("foo", "foo_2"):
        assert logged[metric_name] == self.trainer.current_epoch - 1
    model.callback_on_train_end_called = True