def test_lightning_parallel_module_device_access(nest, unnest):
    """ Test that self.device returns the correct value in replicas of DataParallel. """
    class DeviceAccessModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(2, 3)

        @auto_move_data
        def training_step(self, batch, batch_idx):
            batch = unnest(batch)
            assert batch.shape == torch.Size([1, 1])
            assert self.device.index == batch.item()
            assert self.device == self.layer.weight.device
            return torch.tensor(1, device=self.device)

    pl_module = DeviceAccessModel()
    # required for redirecting the forward call to training_step
    pl_module.trainer = Mock()
    pl_module.trainer.state.stage = RunningStage.TRAINING

    root_device = torch.device("cuda", 0)
    wrapped_module = LightningParallelModule(pl_module).to(root_device)
    model = DataParallel(wrapped_module, device_ids=[0, 1])

    data = torch.tensor([0.0, 1.0], device=root_device).view(2, 1)  # one value per gpu
    data = nest(data)
    output = model(data, 0)
    assert output.device == root_device
    assert pl_module.device == root_device
    assert torch.all(output.cpu().eq(torch.tensor([1, 1])))
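The nest and unnest arguments are supplied by a pytest parametrization in the original test module (which also guards on at least two GPUs being available); they wrap the input batch in different container types and pull the tensor back out inside training_step, so device inference is exercised on nested inputs as well. A minimal sketch of such a parametrization; the concrete cases below are an illustration, not taken from this snippet:

@pytest.mark.parametrize("nest, unnest", [
    (lambda x: x, lambda x: x),                    # plain tensor, no nesting
    (lambda x: {"data": x}, lambda x: x["data"]),  # tensor nested in a dict
    (lambda x: [x, (x, x)], lambda x: x[1][0]),    # tensor nested in a list/tuple
])
def test_lightning_parallel_module_device_access(nest, unnest):
    ...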
Example #2
def test_lightning_parallel_module_unsqueeze_scalar():
    """ Test that LightningParallelModule takes care of un-squeezing 0-dim tensors. """
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            output = super().training_step(batch, batch_idx)
            loss = output["loss"]
            loss = loss.squeeze()
            assert loss.dim() == 0
            # PyTorch usually warns about 0-dim tensors returned in DP
            return {"loss": loss}

    model = TestModel()
    model.trainer = Mock()
    model.trainer.state.stage = RunningStage.TRAINING
    batch = torch.rand(2, 32).cuda()
    batch_idx = 0

    wrapped_model = LightningParallelModule(model).cuda()
    dp_module = DataParallel(wrapped_model, device_ids=[0, 1])

    output = wrapped_model(batch, batch_idx)
    assert output["loss"].dim() == 1

    with pytest.warns(None) as record:
        output = dp_module(batch, batch_idx)

    assert output["loss"].dim() == 1
    assert not record
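Note that these test snippets rely on module-level imports from their original files. A header along the following lines covers the names used above; the exact module paths are an assumption and vary between Lightning releases (auto_move_data, for example, only exists in older 1.x versions):

from unittest.mock import Mock

import pytest
import torch
from torch import nn
from torch.nn import DataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.trainer.states import RunningStage
from tests.helpers.boring_model import BoringModel  # BoringModel lives in Lightning's test helpers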
Example #3
def setup(self, trainer: "pl.Trainer") -> None:
    # model needs to be moved to the device before it is wrapped
    self.model_to_device()
    assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
    self.model = self._setup_model(LightningParallelModule(self.model))
    super().setup(trainer)
Example #4
    def __init_torch_data_parallel(self, model):
        # create list of device ids
        device_ids = self.trainer.data_parallel_device_ids
        if isinstance(device_ids, int):
            device_ids = list(range(device_ids))

        # set dp device
        torch.cuda.set_device(self.trainer.root_gpu)
        model = torch.nn.DataParallel(LightningParallelModule(model),
                                      device_ids=device_ids)
        return model
Example #5
def test_lightning_parallel_module_python_scalar_conversion(device):
    """ Test that LightningParallelModule can convert Python scalars to tensors. """
    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            output = super().training_step(batch, batch_idx)
            # PyTorch DP does not support Python scalars, Lightning converts them to tensors
            output.update({"python scalar": 12.3})
            return output

    model = TestModel().to(device)
    model.trainer = Mock()
    model.trainer.state.stage = RunningStage.TRAINING
    batch = torch.rand(2, 32).to(device)
    batch_idx = 0

    wrapped_model = LightningParallelModule(model)
    output = wrapped_model(batch, batch_idx)
    assert output["python scalar"] == torch.tensor([12.3], device=device)
Example #6
def test_lightning_parallel_module_device_access_warning():
    """ Test that we show a warning when the device can't be inferred from the input. """
    class DeviceAccessModel(LightningModule):
        def training_step(self, batch, batch_idx):
            pass

    pl_module = DeviceAccessModel()
    # required for redirecting the forward call to training_step
    pl_module.trainer = Mock()
    pl_module.trainer.state.stage = RunningStage.TRAINING

    wrapped_module = LightningParallelModule(pl_module).cuda()
    model = DataParallel(wrapped_module, device_ids=[0, 1])

    data = dict(x=1)  # contains no tensors
    with pytest.warns(
            UserWarning,
            match="Could not determine on which device the inputs are."):
        _ = model(data, 0)
Example #7
def setup(self, model):
    # model needs to be moved to the device before it is wrapped
    model.to(self.root_device)
    self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
Example #8
def setup(self, trainer: "pl.Trainer") -> None:
    # model needs to be moved to the device before it is wrapped
    self.model_to_device()
    self.model = self._setup_model(LightningParallelModule(self.model))
    super().setup(trainer)
Example #9
def setup(self) -> None:
    # model needs to be moved to the device before it is wrapped
    self.model_to_device()
    self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices)
Example #10
def setup(self, model):
    self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
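All of these setup variants follow the same pattern: move the LightningModule to the root device, wrap it in LightningParallelModule so that DataParallel's forward call is redirected to the appropriate *_step method, and keep the resulting DataParallel instance as the plugin's model. From the user's side this code path is reached through the Trainer; a minimal sketch, assuming a Lightning 1.x style API:

import pytorch_lightning as pl

model = MyLightningModule()  # any LightningModule; the name here is hypothetical
trainer = pl.Trainer(gpus=2, strategy="dp")  # older releases select DP via accelerator="dp"
trainer.fit(model)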