Example #1
import pytest


def test_old_lightningmodule_path():
    from pytorch_lightning.core.lightning import LightningModule

    with pytest.deprecated_call(
        match="pytorch_lightning.core.lightning.LightningModule has been deprecated in v1.7"
        " and will be removed in v1.9."
    ):
        LightningModule()
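The test relies on pytest.deprecated_call, which only passes if the wrapped code emits a DeprecationWarning (or PendingDeprecationWarning) whose message matches the given pattern. A minimal self-contained sketch of the same pattern, with a made-up _old_api helper standing in for the deprecated Lightning import:

import warnings

import pytest


def _old_api() -> None:
    # Hypothetical helper that emits the kind of warning the test above expects.
    warnings.warn("old_api has been deprecated and will be removed", DeprecationWarning)


def test_deprecated_call_example():
    # deprecated_call fails the test unless a matching DeprecationWarning is raised.
    with pytest.deprecated_call(match="has been deprecated"):
        _old_api()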
Example #2
    def sanitize_parameters_to_prune(
        pl_module: LightningModule,
        parameters_to_prune: _PARAM_LIST = (),
        parameter_names: Sequence[str] = ()
    ) -> _PARAM_LIST:
        """This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. If
        ``parameters_to_prune is None``, it will be generated with all parameters of the model.

        Raises:
            MisconfigurationException:
                If ``parameters_to_prune`` doesn't exist in the model, or
                if ``parameters_to_prune`` is neither a list nor a tuple.
        """
        parameters = parameter_names or ModelPruning.PARAMETER_NAMES

        current_modules = [
            m for m in pl_module.modules()
            if not isinstance(m, _MODULE_CONTAINERS)
        ]

        if not parameters_to_prune:
            parameters_to_prune = [(m, p) for p in parameters
                                   for m in current_modules
                                   if getattr(m, p, None) is not None]
        elif (
            isinstance(parameters_to_prune, (list, tuple))
            and len(parameters_to_prune) > 0
            and all(len(p) == 2 for p in parameters_to_prune)
            and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune)
        ):
            missing_modules, missing_parameters = [], []
            for module, name in parameters_to_prune:
                if module not in current_modules:
                    missing_modules.append(module)
                    continue
                if not hasattr(module, name):
                    missing_parameters.append(name)

            if missing_modules or missing_parameters:
                raise MisconfigurationException(
                    "Some provided `parameters_to_tune` don't exist in the model."
                    f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}"
                )
        else:
            raise MisconfigurationException(
                "The provided `parameters_to_prune` should either be list of tuple"
                " with 2 elements: (nn.Module, parameter_name_to_prune) or None"
            )

        return parameters_to_prune
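Assuming this is the static helper exposed as ModelPruning.sanitize_parameters_to_prune in pytorch_lightning.callbacks (as the reference to ModelPruning.PARAMETER_NAMES suggests), a minimal usage sketch with a hypothetical one-layer model:

import torch.nn as nn
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelPruning


class TinyModel(LightningModule):
    # Hypothetical one-layer model, just to have parameters to collect.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)


model = TinyModel()
# With parameters_to_prune left empty, the helper walks model.modules() and
# returns every (module, "weight") pair it finds.
pairs = ModelPruning.sanitize_parameters_to_prune(model, parameter_names=["weight"])
print(pairs)  # expected: [(Linear(in_features=4, out_features=2, bias=True), 'weight')]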
Example #3
    def _make_pruning_permanent_on_state_dict(
            self, pl_module: LightningModule) -> Dict[str, Any]:
        state_dict = pl_module.state_dict()

        # find the mask and the original weights.
        map_pruned_params = {
            k.replace("_mask", "")
            for k in state_dict.keys() if k.endswith("_mask")
        }
        for tensor_name in map_pruned_params:
            orig = state_dict.pop(tensor_name + "_orig")
            mask = state_dict.pop(tensor_name + "_mask")
            # make weights permanent
            state_dict[tensor_name] = mask.to(dtype=orig.dtype) * orig

        def move_to_cpu(tensor: torch.Tensor) -> torch.Tensor:
            # move each tensor to the CPU
            return tensor.cpu()

        return apply_to_collection(state_dict, torch.Tensor, move_to_cpu)
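For context, the ``*_orig`` / ``*_mask`` pairs that this method collapses are produced by torch.nn.utils.prune. A standalone sketch of the same "make permanent" step on a plain linear layer:

from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(3, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)

sd = layer.state_dict()
print(sorted(sd.keys()))  # ['bias', 'weight_mask', 'weight_orig']

# Same step as in the method above: multiply the mask into the original
# weights and store the result under the plain parameter name, on CPU.
orig = sd.pop("weight_orig")
mask = sd.pop("weight_mask")
sd["weight"] = (mask.to(dtype=orig.dtype) * orig).cpu()
print(sorted(sd.keys()))  # ['bias', 'weight']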
Example #4
    def backward(self, loss, optimizer, optimizer_idx):
        return LightningModule.backward(self, loss, optimizer, optimizer_idx)
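The same delegation is usually written with super(). A hypothetical module that scales the loss before handing it back to the default backward (the class name and the 0.5 factor are made up for illustration):

from pytorch_lightning import LightningModule


class ScaledBackwardModule(LightningModule):
    # Hypothetical override: adjust the loss, then delegate to the default
    # LightningModule.backward, the same pattern as the snippet above.
    def backward(self, loss, optimizer, optimizer_idx):
        super().backward(loss * 0.5, optimizer, optimizer_idx)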
Example #5
    def on_train_end(self, trainer: Trainer, model: LightningModule):
        metrics = self.get_metrics(trainer, model)
        assert metrics["foo"] == trainer.current_epoch - 1
        assert metrics["foo_2"] == trainer.current_epoch - 1
        model.callback_on_train_end_called = True
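The snippet does not show get_metrics; a plausible reading is that it collects whatever the model logged, which normally lands in trainer.callback_metrics. A hypothetical stand-in under that assumption:

from pytorch_lightning import Callback, LightningModule, Trainer


class MetricsSnapshot(Callback):
    # Hypothetical stand-in for the callback class the snippet above belongs to.
    def get_metrics(self, trainer: Trainer, model: LightningModule) -> dict:
        # Assumption: values logged via self.log(...) are available in
        # trainer.callback_metrics at the end of training.
        return dict(trainer.callback_metrics)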