def run_checkpoint_test(tmpdir: str,
                        save_full_weights: bool,
                        automatic_optimization: bool = True,
                        accumulate_grad_batches: int = 2):
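    """Train a ZeRO Stage 3 model across two GPUs, then verify that checkpoints can be
    reloaded for testing, resumed from, and used for prediction."""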
    seed_everything(1)
    if automatic_optimization:
        model = ModelParallelClassificationModel()
    else:
        model = ManualModelParallelClassificationModel()
    dm = ClassifDataModule()
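    # Keep every checkpoint (save_top_k=-1) plus the last one, so the best val_acc
    # checkpoint can be reloaded and resumed from below.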
    ck = ModelCheckpoint(monitor="val_acc",
                         mode="max",
                         save_last=True,
                         save_top_k=-1)
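    # ZeRO Stage 3 shards parameters, gradients and optimizer state across the two GPUs;
    # save_full_weights toggles between writing a single consolidated checkpoint and
    # per-rank shards.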
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=10,
        plugins=[
            DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)
        ],
        gpus=2,
        precision=16,
        accumulate_grad_batches=accumulate_grad_batches,
        callbacks=[ck],
    )
    trainer.fit(model, datamodule=dm)

    results = trainer.test(model, datamodule=dm)
    assert results[0]["test_acc"] > 0.7

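    # Reload the best checkpoint from disk and re-run the test stage; the metrics should
    # match the in-memory model exactly.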
    saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm)
    assert saved_results[0]["test_acc"] > 0.7
    assert saved_results == results

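    # A fresh Trainer resuming from the best checkpoint verifies that the saved state can
    # be restored for further use.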
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=10,
        plugins=[
            DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)
        ],
        gpus=2,
        precision=16,
        accumulate_grad_batches=accumulate_grad_batches,
        callbacks=[ck],
        resume_from_checkpoint=ck.best_model_path,
    )
    results = trainer.test(model, datamodule=dm)
    assert results[0]["test_acc"] > 0.7

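    # Reuse the test set for prediction; the model's predict step is assumed to return an
    # accuracy value, so the final prediction is checked against the convergence threshold.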
    dm.predict_dataloader = dm.test_dataloader
    results = trainer.predict(datamodule=dm)
    assert results[-1] > 0.7


def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
    """
    Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
    and see convergence.
    """
    seed_everything(42)
    model = ModelParallelClassificationModel()
    dm = ClassifDataModule()
    ck = ModelCheckpoint(monitor="val_acc",
                         mode="max",
                         save_last=True,
                         save_top_k=-1)
    trainer = Trainer(max_epochs=10,
                      plugins=[DeepSpeedPlugin(stage=3)],
                      default_root_dir=tmpdir,
                      gpus=2,
                      precision=16,
                      accumulate_grad_batches=2,
                      callbacks=[ck])
    trainer.fit(model, datamodule=dm)

    results = trainer.test(model, datamodule=dm)
    assert results[0]['test_acc'] > 0.7

    saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm)
    assert saved_results[0]['test_acc'] > 0.7
    assert saved_results == results

    trainer = Trainer(max_epochs=10,
                      plugins=[DeepSpeedPlugin(stage=3)],
                      default_root_dir=tmpdir,
                      gpus=2,
                      precision=16,
                      accumulate_grad_batches=2,
                      callbacks=[ck],
                      resume_from_checkpoint=ck.best_model_path)
    results = trainer.test(model, datamodule=dm)
    assert results[0]['test_acc'] > 0.7

    dm.predict_dataloader = dm.test_dataloader
    results = trainer.predict(datamodule=dm)
    assert results[-1] > 0.7