def test_deepspeed_lightning_module(tmpdir):
    """Check that dtype casts on a `LightningDeepSpeedModule` reach the wrapped model.

    Casting the wrapper with ``.half()`` and then ``.to(torch.double)`` must be
    reflected in the ``dtype`` of both the wrapper and the underlying model.
    """
    wrapped = BoringModel()
    wrapper = LightningDeepSpeedModule(wrapped, precision=16)

    def _expect_dtype(expected):
        # Wrapper and wrapped model must always agree on the dtype.
        assert wrapper.dtype == expected
        assert wrapped.dtype == expected

    wrapper.half()
    _expect_dtype(torch.half)

    wrapper.to(torch.double)
    _expect_dtype(torch.double)


def test_deepspeed_lightning_module_precision(tmpdir):
    """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision is
    16.

    Requires a CUDA device; the test is skipped when none is available.
    """
    # Local import: ``unittest.SkipTest`` is stdlib and pytest reports it as a skip,
    # so CPU-only machines skip this test instead of erroring inside ``.cuda()``.
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest("test requires a CUDA device")

    model = BoringModel()
    module = LightningDeepSpeedModule(model, precision=16)

    module.cuda().half()
    assert module.dtype == torch.half
    assert model.dtype == torch.half

    # Feed a float32 input. With precision=16 the wrapper is expected to move it to
    # half before the forward pass, so the output must come back as half as well.
    x = torch.randn((1, 32), dtype=torch.float, device="cuda")
    out = module(x)

    assert out.dtype == torch.half

    module.to(torch.double)
    assert module.dtype == torch.double
    assert model.dtype == torch.double