def test_deepspeed_collate_checkpoint(tmpdir):
    """Check that DeepSpeed Stage 3 sharded checkpoints can be collated into a single file.

    Fits a ``BoringModel`` with the Stage 3 plugin, saves a checkpoint under ``tmpdir``, and then,
    on the global-zero process only, converts that checkpoint with
    ``convert_zero_checkpoint_to_fp32_state_dict`` and compares the result against the model
    through ``_assert_checkpoint_equal``.
    """
    model = BoringModel()
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        plugins=[DeepSpeedPlugin(stage=3)],
        gpus=2,
        fast_dev_run=True,
        precision=16,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    # The path is routed through broadcast before saving (presumably so every process
    # ends up using the same location -- confirm against the accelerator API).
    sharded_ckpt = trainer.accelerator.broadcast(os.path.join(tmpdir, "model.pt"))
    trainer.save_checkpoint(sharded_ckpt)
    trainer.accelerator.barrier()

    # Only the global-zero process performs the conversion and the comparison.
    if not trainer.is_global_zero:
        return

    # Ensure the conversion call works and yields a checkpoint matching the model.
    collated_ckpt = os.path.join(tmpdir, "single_model.pt")
    convert_zero_checkpoint_to_fp32_state_dict(sharded_ckpt, collated_ckpt)
    _assert_checkpoint_equal(model, collated_ckpt)
Exemplo n.º 2
0
def test_deepspeed_collate_checkpoint(tmpdir):
    """Verify that sharded DeepSpeed Stage 3 checkpoints can be merged into one file.

    A ``BoringModel`` is fitted with ``DeepSpeedStrategy(stage=3)`` and a checkpoint is saved
    under ``tmpdir``. The global-zero process then runs
    ``convert_zero_checkpoint_to_fp32_state_dict`` on it and checks the output against the
    model with ``_assert_checkpoint_equal``.
    """
    model = BoringModel()
    stage3_strategy = DeepSpeedStrategy(stage=3)
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=stage3_strategy,
        accelerator="gpu",
        devices=2,
        fast_dev_run=True,
        precision=16,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model)

    # Pass the save location through the strategy's broadcast before using it
    # (presumably to keep all processes on an identical path -- confirm against the strategy API).
    save_target = os.path.join(tmpdir, "model.pt")
    save_target = trainer.strategy.broadcast(save_target)
    trainer.save_checkpoint(save_target)
    trainer.strategy.barrier()

    # Every process other than global zero is done at this point.
    if not trainer.is_global_zero:
        return

    # Ensure the conversion function call works, then compare its output with the model.
    merged_target = os.path.join(tmpdir, "single_model.pt")
    convert_zero_checkpoint_to_fp32_state_dict(save_target, merged_target)
    _assert_checkpoint_equal(model, merged_target)