def test_ddp_sharded_plugin_finetune(tmpdir):
    """
    Check that a model trained with the sharded plugin can be checkpointed,
    reloaded and trained again (a stand-in for a fine-tuning workflow).

    The first fit uses two GPUs with the ``ddp_spawn`` accelerator and
    ``DDPShardedPlugin``; the second fit runs the reloaded model on a
    Trainer with no plugin configured. Both runs use ``fast_dev_run``.
    """
    model = BoringModel()
    sharded_trainer = Trainer(
        gpus=2,
        accelerator='ddp_spawn',
        plugins=[DDPShardedPlugin()],
        fast_dev_run=True,
    )
    sharded_trainer.fit(model)

    # Persist the trained weights inside the pytest-provided temp directory.
    ckpt_file = os.path.join(tmpdir, 'model.pt')
    sharded_trainer.save_checkpoint(ckpt_file)

    # Restart training from the checkpoint with a fresh Trainer.
    restored = BoringModel.load_from_checkpoint(ckpt_file)
    finetune_trainer = Trainer(fast_dev_run=True)
    finetune_trainer.fit(restored)
def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir):
    """
    Test to ensure that checkpoint is saved correctly.

    Fits a ``BoringModel`` with the ``ddp_cpu`` accelerator and
    ``DDPShardedPlugin`` (``fast_dev_run``), saves a checkpoint under
    ``tmpdir``, reloads it with ``BoringModel.load_from_checkpoint`` and
    asserts the reloaded parameters match the trained ones exactly.
    """
    model = BoringModel()
    trainer = Trainer(
        accelerator='ddp_cpu',
        plugins=[DDPShardedPlugin()],
        fast_dev_run=True,
    )

    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, 'model.pt')
    trainer.save_checkpoint(checkpoint_path)
    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

    # zip() stops at the shorter iterable, so a reloaded model that is
    # missing parameters would otherwise slip through the loop below
    # unnoticed. Compare the counts explicitly first.
    trained_params = list(model.parameters())
    loaded_params = list(saved_model.parameters())
    assert len(trained_params) == len(loaded_params)

    # Assert model parameters are identical after loading
    for ddp_param, shard_param in zip(trained_params, loaded_params):
        assert torch.equal(ddp_param, shard_param)