def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator,
                                     tmpdir):
    """Train the same model once through LightningLite and once through a plain
    ``torch.multiprocessing`` spawn loop, then check both end with identical weights."""
    LightningLite.seed_everything(42)
    dataloader = DataLoader(RandomDataset(32, 8))
    model = BoringModel()
    epochs = 1
    initial_weights = deepcopy(model.state_dict())

    runner = LiteRunner(precision=precision,
                        strategy=strategy,
                        devices=devices,
                        accelerator=accelerator)
    ckpt_path = runner.run(model, dataloader, num_epochs=epochs, tmpdir=tmpdir)
    lite_weights = torch.load(ckpt_path)

    # The Lite run must have moved every tensor away from its seeded initial value.
    for before, after in zip(initial_weights.values(), lite_weights.values()):
        assert not torch.equal(before.cpu(), after.cpu())

    # Reset to the initial weights and repeat the training with a raw spawn.
    model.load_state_dict(initial_weights)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_network_port())
    mp.spawn(run,
             args=(model, dataloader, epochs, precision, accelerator, tmpdir),
             nprocs=2)
    pure_weights = torch.load(os.path.join(tmpdir, "model_spawn.pt"))

    # Lite and the pure-PyTorch spawn run must produce the same parameters.
    for pure, lite in zip(pure_weights.values(), lite_weights.values()):
        assert torch.equal(pure.cpu(), lite.cpu())
def test_raise_when_peer_endpoint_unsuccessful(caplog):
    """The strategy must retry the peer endpoint the configured number of times,
    log each failure, and finally raise a ``MisconfigurationException``."""
    port = find_free_network_port()
    with pytest.raises(MisconfigurationException, match="Unable to get peers"), \
            mock.patch("requests.get", wraps=requests.get) as get_spy:
        CollaborativeStrategy(
            target_batch_size=1,
            peer_endpoint=f"localhost:{port}",
            retry_endpoint_attempts=10,
            retry_endpoint_sleep_duration=0,
        )
    # One GET per attempt, and every failed attempt leaves a log line behind.
    assert "Failed to get peers, retrying" in caplog.text
    assert get_spy.call_count == 10
def single_process_pg():
    """Initialize the default process group with only the current process for testing purposes.

    Returns a context manager; the process group is destroyed and the original
    environment variables are restored when the ``with`` block is exited.

    Raises:
        RuntimeError: If a default process group is already initialized.
    """
    # Local imports keep this helper self-contained; the enclosing test file
    # does not otherwise need these modules at the top level.
    import contextlib
    import socket

    if torch.distributed.is_initialized():
        raise RuntimeError("Can't use `single_process_pg` when the default process group is already initialized.")

    def _free_port() -> int:
        # Binding to port 0 makes the OS choose an unused port for us.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("localhost", 0))
            return sock.getsockname()[1]

    # The original code yielded from an undecorated function, which cannot be
    # used in a `with` statement as the docstring promises. Building the
    # context manager here (instead of decorating `single_process_pg` itself)
    # avoids needing a module-level `contextlib` import.
    @contextlib.contextmanager
    def _single_process_pg():
        orig_environ = os.environ.copy()
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(_free_port())
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        torch.distributed.init_process_group("gloo")
        try:
            yield
        finally:
            # Tear down the group and restore the environment exactly as it was.
            torch.distributed.destroy_process_group()
            os.environ.clear()
            os.environ.update(orig_environ)

    return _single_process_pg()
@mock.patch("pytorch_lightning.strategies.collaborative.DHTManager._get_peers", autospec=True)
@pytest.mark.parametrize(
    "initial_peers,peer_endpoint",
    [(["TEST"], None), (None, "localhost:153")],
)
def test_logging_disabled_when_second_peer(mock_dht, mock_http, initial_peers, peer_endpoint):
    """Test when we are a second peer (passing initial peers or peer endpoint) we warn the user that
    logging/checkpointing will be disabled."""
    # NOTE(review): this test takes two mock parameters (``mock_dht``,
    # ``mock_http``) but only one ``mock.patch`` decorator is visible above —
    # a second patch decorator appears to be missing here; confirm against
    # the original test module before relying on this test.
    with pytest.warns(UserWarning, match="This machine is not a persistent machine"):
        CollaborativeStrategy(target_batch_size=1, initial_peers=initial_peers, peer_endpoint=peer_endpoint)


@RunIf(hivemind=True)
@mock.patch.dict(
    os.environ,
    {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor", "PL_PORT": str(find_free_network_port())},
    clear=True,
)
@pytest.mark.parametrize(
    "endpoint,expected_message",
    [(False, "INITIAL_PEERS"), (True, "Sidecar endpoint enabled to serve peers.")],
)
def test_initial_peer_message(caplog, endpoint, expected_message):
    """A fast-dev run with the collaborative strategy should emit the expected
    peer/endpoint message for the given ``endpoint`` setting."""
    strategy = CollaborativeStrategy(target_batch_size=1, endpoint=endpoint)
    trainer = pl.Trainer(strategy=strategy, fast_dev_run=True)
    trainer.fit(BoringModel())
    assert expected_message in caplog.text


@RunIf(hivemind=True)
@mock.patch.dict(os.environ, {"HIVEMIND_MEMORY_SHARING_STRATEGY": "file_descriptor"}, clear=True)