def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator,
                                     tmpdir):
    """Compare a ``LiteRunner`` spawn run against a manual ``mp.spawn`` run.

    The model is first trained through ``LiteRunner`` and the weights it
    checkpointed are loaded back.  The same model is then reset to its
    initial weights and trained by calling ``run`` in two spawned processes.
    Both sets of trained weights must be identical, and the Lite weights
    must differ from the untrained ones.
    """
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 8))
    model = BoringModel()
    num_epochs = 1
    # Independent copy of the untrained weights: used as the "before"
    # reference and to reset the model ahead of the manual run.
    initial_weights = deepcopy(model.state_dict())

    runner_kwargs = {
        "precision": precision,
        "strategy": strategy,
        "devices": devices,
        "accelerator": accelerator,
    }
    lite = LiteRunner(**runner_kwargs)
    checkpoint_path = lite.run(
        model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir)
    lite_weights = torch.load(checkpoint_path)

    # Training through Lite must have changed every tensor.
    for before, after in zip(initial_weights.values(), lite_weights.values()):
        assert not torch.equal(before.cpu(), after.cpu())

    # Reset, then repeat the training with hand-spawned workers.
    model.load_state_dict(initial_weights)
    # NOTE(review): presumably read by `run` when it sets up the process
    # group -- confirm against its definition.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_network_port())
    spawn_args = (model, train_dataloader, num_epochs, precision, accelerator,
                  tmpdir)
    mp.spawn(run, args=spawn_args, nprocs=2)
    manual_checkpoint = os.path.join(tmpdir, "model_spawn.pt")
    manual_weights = torch.load(manual_checkpoint)

    # Parity: the manual run and the Lite run end with the same weights.
    for manual, via_lite in zip(manual_weights.values(),
                                lite_weights.values()):
        assert torch.equal(manual.cpu(), via_lite.cpu())
def test_boring_lite_model_ddp(precision, strategy, devices, accelerator,
                               tmpdir):
    """Compare a ``LiteRunner`` run against a direct call to ``run``.

    A seeded model is trained through ``LiteRunner``.  The seed, dataloader
    and model are then re-created identically and trained by calling ``run``
    with the Lite runner's global rank.  Both trained models must hold
    identical weights, and the Lite weights must differ from the untrained
    ones.
    """
    num_epochs = 1

    # --- Lite-driven training -------------------------------------------
    LightningLite.seed_everything(42)
    lite_loader = DataLoader(RandomDataset(32, 4))
    lite_model = BoringModel()
    initial_weights = deepcopy(lite_model.state_dict())

    runner_kwargs = {
        "precision": precision,
        "strategy": strategy,
        "devices": devices,
        "accelerator": accelerator,
    }
    lite = LiteRunner(**runner_kwargs)
    lite.run(lite_model, lite_loader, num_epochs=num_epochs, tmpdir=tmpdir)
    lite_weights = lite_model.state_dict()

    # Training through Lite must have changed every tensor.
    for before, after in zip(initial_weights.values(), lite_weights.values()):
        assert not torch.equal(before.cpu(), after.cpu())

    # --- Reference training ---------------------------------------------
    # Re-seed and rebuild so the reference starts from the same state.
    LightningLite.seed_everything(42)
    reference_loader = DataLoader(RandomDataset(32, 4))
    reference_model = BoringModel()
    run(lite.global_rank, reference_model, reference_loader, num_epochs,
        precision, accelerator, tmpdir)
    reference_weights = reference_model.state_dict()

    # Parity: both code paths end with the same weights.
    for reference, via_lite in zip(reference_weights.values(),
                                   lite_weights.values()):
        assert torch.equal(reference.cpu(), via_lite.cpu())
def test_boring_lite_model_single_device(precision, strategy, devices,
                                         accelerator, tmpdir):
    """Compare a single-device ``LiteRunner`` run against ``main``.

    The model is trained through ``LiteRunner`` and its weights are
    snapshotted.  The same model object is then reset to its initial
    weights and trained again through ``main`` inside ``precision_context``.
    The two trained weight sets must be identical, and the Lite weights
    must differ from the untrained ones.
    """
    LightningLite.seed_everything(42)
    train_dataloader = DataLoader(RandomDataset(32, 8))
    model = BoringModel()
    num_epochs = 1
    state_dict = deepcopy(model.state_dict())

    lite = LiteRunner(precision=precision,
                      strategy=strategy,
                      devices=devices,
                      accelerator=accelerator)
    lite.run(model, train_dataloader, num_epochs=num_epochs)
    # `Module.state_dict()` returns references to the live parameters, not
    # copies.  The same `model` is reset with `load_state_dict` below and
    # handed to `main`, so without a deep copy this snapshot would be
    # overwritten in place and the parity check at the end would compare
    # the reference run with itself.
    lite_state_dict = deepcopy(model.state_dict())

    with precision_context(precision, accelerator):
        # Restart from the untrained weights for the reference run.
        model.load_state_dict(state_dict)
        pure_state_dict = main(lite.to_device,
                               model,
                               train_dataloader,
                               num_epochs=num_epochs)

    # Move the initial weights through `lite.to_device` so they can be
    # compared with the trained tensors without an explicit `.cpu()`.
    state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device)
    # Training through Lite must have changed every tensor.
    for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()):
        assert not torch.equal(w_pure, w_lite)

    # Parity: the reference run and the Lite run end with the same weights.
    for w_pure, w_lite in zip(pure_state_dict.values(),
                              lite_state_dict.values()):
        assert torch.equal(w_pure, w_lite)