예제 #1
0
def test_unwrap_lightning_module():
    """Check that ``unwrap_lightning_module`` gives back the innermost model.

    A ``BoringModel`` is wrapped, innermost first, in the precision wrapper,
    the module wrapper and ``DataParallel``; unwrapping the outermost object
    must compare equal to the model that was originally created.
    """
    model = BoringModel()

    # Wrappers are applied in order, so the last entry ends up outermost.
    wrapper_chain = (
        _LightningPrecisionModuleWrapperBase,
        _LightningModuleWrapperBase,
        DataParallel,
    )
    wrapped_model = model
    for wrapper_cls in wrapper_chain:
        wrapped_model = wrapper_cls(wrapped_model)

    assert unwrap_lightning_module(wrapped_model) == model
예제 #2
0
    def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
        """Unwrap a model that may be wrapped in ``ShardedDataParallel``.

        Args:
            wrapped_model: The possibly wrapped module.

        Returns:
            Whatever ``unwrap_lightning_module`` returns for the module once
            an outer ``ShardedDataParallel`` layer, if present, is removed.
        """
        # Strip the sharded wrapper (if any) before delegating the remaining
        # unwrapping to the generic helper.
        inner = (
            wrapped_model.module
            if isinstance(wrapped_model, ShardedDataParallel)
            else wrapped_model
        )
        return unwrap_lightning_module(inner)
예제 #3
0
 def lightning_module(self):
     """Return ``self._model`` after passing it through ``unwrap_lightning_module``."""
     wrapped = self._model
     return unwrap_lightning_module(wrapped)
 def lightning_module(self) -> LightningModule:
     """Return the plain LightningModule with any wrappers removed.

     The stored ``self._model`` is handed to ``unwrap_lightning_module`` and
     its result is returned unchanged.
     """
     wrapped = self._model
     return unwrap_lightning_module(wrapped)
예제 #5
0
 def lightning_module(self) -> "pl.LightningModule":
     """Return the module obtained by unwrapping ``self._model``.

     If ``self._model`` is a ``BaguaDistributedDataParallel`` instance, its
     ``.module`` attribute is used instead; the result is then passed to
     ``unwrap_lightning_module``.
     """
     wrapped = self._model
     # Drop the Bagua layer first; the generic helper handles the rest.
     inner = wrapped.module if isinstance(wrapped, BaguaDistributedDataParallel) else wrapped
     return unwrap_lightning_module(inner)  # type: ignore[arg-type]
예제 #6
0
 def lightning_module(self) -> Optional["pl.LightningModule"]:
     """Return the unwrapped ``self.model``, or ``None`` when no model is set."""
     if self.model is None:
         return None
     return unwrap_lightning_module(self.model)
예제 #7
0
 def lightning_module(self) -> Optional["pl.LightningModule"]:
     """Return the pure LightningModule without potential wrappers.

     Returns:
         ``None`` when ``self.model`` is ``None``; otherwise the result of
         ``unwrap_lightning_module(self.model)``.
     """
     if self.model is None:
         return None
     return unwrap_lightning_module(self.model)
예제 #8
0
 def lightning_module(self) -> Optional["pl.LightningModule"]:
     """Return the unwrapped model, or ``None`` when there is nothing to unwrap.

     If ``self.model`` is a ``BaguaDistributedDataParallel`` instance, its
     ``.module`` attribute is used instead. A resulting ``None`` is returned
     as-is; anything else is passed to ``unwrap_lightning_module``.
     """
     candidate = self.model
     if isinstance(candidate, BaguaDistributedDataParallel):
         # Drop the Bagua layer before the generic unwrapping.
         candidate = candidate.module
     if candidate is None:
         return None
     return unwrap_lightning_module(candidate)