def test_ddp_sharded_plugin_finetune(tmpdir, ray_start_2_cpus, seed):
    """Tests if we can save and restart training."""
    model = BoringModel()
    trainer = Trainer(
        plugins=[RayShardedPlugin(num_workers=2)],
        fast_dev_run=True,
    )
    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

    trainer = Trainer(fast_dev_run=True)
    trainer.fit(saved_model)
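

# The `ray_start_2_cpus` fixture used by these tests is assumed to come from
# the shared test utilities (e.g. a conftest.py). A minimal sketch of what it
# is assumed to do -- start a local two-CPU Ray cluster per test and shut it
# down afterwards (hypothetical; the real fixture may differ):
#
#     @pytest.fixture
#     def ray_start_2_cpus():
#         address_info = ray.init(num_cpus=2)
#         yield address_info
#         ray.shutdown()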


def test_early_stop(tmpdir, ray_start_2_cpus):
    """Tests if early stopping callback works correctly."""
    model = BoringModel()
    accelerator = RayAccelerator(num_workers=1, use_gpu=False)
    early_stop = EarlyStopping(monitor="val_loss", patience=2, verbose=True)
    trainer = get_trainer(tmpdir,
                          max_epochs=500,
                          accelerator=accelerator,
                          callbacks=[early_stop],
                          limit_train_batches=1.0,
                          limit_val_batches=1.0,
                          progress_bar_refresh_rate=1)
    trainer.fit(model)
    trained_model = BoringModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)
    assert trained_model.val_epoch == 2, trained_model.val_epoch
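

# `BoringModel` and `get_trainer` are assumed to come from the shared test
# utilities: the model is expected to log a `val_loss` metric and count
# validation epochs in a `val_epoch` attribute, and `get_trainer` is expected
# to build a Trainer rooted at `tmpdir` with checkpointing enabled. A rough
# sketch of the assumed validation hook (hypothetical):
#
#     def validation_epoch_end(self, outputs):
#         avg_loss = torch.stack([o["x"] for o in outputs]).mean()
#         self.log("val_loss", avg_loss)
#         self.val_epoch += 1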


def test_ddp_sharded_plugin_checkpoint(tmpdir, ray_start_2_cpus, seed):
    """Tests if checkpoint is saved correctly."""
    model = BoringModel()
    trainer = Trainer(
        plugins=[RayShardedPlugin(num_workers=2)],
        fast_dev_run=True,
    )

    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

    # Assert model parameters are identical after loading.
    for ddp_param, shard_param in zip(model.parameters(),
                                      saved_model.parameters()):
        assert torch.equal(ddp_param, shard_param)