def test_ddp_sharded_plugin_finetune(tmpdir, ray_start_2_cpus, seed):
    """Tests if we can save and restart training."""
    # Run one fast-dev-run fit with the sharded plugin across 2 workers.
    original_model = BoringModel()
    sharded_trainer = Trainer(
        plugins=[RayShardedPlugin(num_workers=2)],
        fast_dev_run=True,
    )
    sharded_trainer.fit(original_model)

    # Persist a checkpoint, then reload it into a fresh model instance.
    ckpt_file = os.path.join(tmpdir, "model.pt")
    sharded_trainer.save_checkpoint(ckpt_file)
    restored_model = BoringModel.load_from_checkpoint(ckpt_file)

    # Training must be resumable from the restored weights with a plain Trainer.
    resume_trainer = Trainer(fast_dev_run=True)
    resume_trainer.fit(restored_model)
def test_early_stop(tmpdir, ray_start_2_cpus):
    """Tests if early stopping callback works correctly."""
    model = BoringModel()
    ray_accelerator = RayAccelerator(num_workers=1, use_gpu=False)
    stopper = EarlyStopping(monitor="val_loss", patience=2, verbose=True)
    # max_epochs is set far higher than needed so that only the
    # early-stopping callback can terminate the run.
    trainer = get_trainer(
        tmpdir,
        max_epochs=500,
        accelerator=ray_accelerator,
        callbacks=[stopper],
        limit_train_batches=1.0,
        limit_val_batches=1.0,
        progress_bar_refresh_rate=1)
    trainer.fit(model)

    # Reload from the best checkpoint and confirm training stopped at the
    # epoch recorded by the model (patience of 2 without improvement).
    best_model = BoringModel.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path)
    assert best_model.val_epoch == 2, best_model.val_epoch
def test_ddp_sharded_plugin_checkpoint(tmpdir, ray_start_2_cpus, seed):
    """Tests if checkpoint is saved correctly."""
    # Fit briefly with the sharded plugin on 2 workers.
    model = BoringModel()
    trainer = Trainer(
        plugins=[RayShardedPlugin(num_workers=2)],
        fast_dev_run=True,
    )
    trainer.fit(model)

    # Save a checkpoint and reload it into a new model instance.
    ckpt_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(ckpt_path)
    reloaded_model = BoringModel.load_from_checkpoint(ckpt_path)

    # Assert model parameters are identical after loading.
    for trained_param, loaded_param in zip(model.parameters(),
                                           reloaded_model.parameters()):
        assert torch.equal(trained_param, loaded_param)