def test_deepspeed_lightning_module(tmpdir):
    """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly."""
    inner = BoringModel()
    wrapped = LightningDeepSpeedModule(inner, precision=16)

    def _check_dtype(expected):
        # Casting the wrapper must also cast the model it wraps, so both are checked.
        for candidate in (wrapped, inner):
            assert candidate.dtype == expected

    wrapped.half()
    _check_dtype(torch.half)

    wrapped.to(torch.double)
    _check_dtype(torch.double)
def test_deepspeed_lightning_module_precision(tmpdir):
    """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision 16.

    The wrapper and its input are moved to a CUDA device, so the test is skipped
    when no GPU is available instead of failing on CUDA initialization.
    """
    if not torch.cuda.is_available():
        # Local import: this test is the only visible user of pytest in this chunk.
        import pytest

        pytest.skip("requires GPU machine")

    model = BoringModel()
    module = LightningDeepSpeedModule(model, precision=16)

    module.cuda().half()
    assert module.dtype == torch.half
    assert model.dtype == torch.half

    # The input is deliberately created as float32; the assertion below checks that the
    # precision=16 wrapper still produces a half-precision output.
    x = torch.randn((1, 32), dtype=torch.float).cuda()
    out = module(x)
    assert out.dtype == torch.half

    module.to(torch.double)
    assert module.dtype == torch.double
    assert model.dtype == torch.double