示例#1
0
    def test(self, stage=2):
        """Train a SimpleModel for a few steps in bfloat16 under ZeRO.

        A client-constructed ``torch.optim.Adam`` is passed to
        ``deepspeed.initialize``; the test passes if the forward / backward /
        step loop over the generated bf16 batches completes without raising.

        Args:
            stage: ZeRO stage written into ``zero_optimization.stage``.
        """
        if not bf16_required_version_check():
            pytest.skip(
                " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
            )

        hidden_dim = 10
        config_dict = dict(
            train_batch_size=1,
            steps_per_print=1,
            fp16=dict(enabled=False),
            bf16=dict(enabled=True),
            zero_optimization=dict(stage=stage),
        )

        net = SimpleModel(hidden_dim)
        client_optimizer = torch.optim.Adam(net.parameters())
        engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                               model=net,
                                               optimizer=client_optimizer)
        batches = random_dataloader(model=engine,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=engine.device,
                                    dtype=torch.bfloat16)
        for batch in batches:
            step_loss = engine(batch[0], batch[1])
            engine.backward(step_loss)
            engine.step()
示例#2
0
    def test(self, zero_stage=2, use_cpu_offload=False):
        """Train under bf16 + ZeRO with a config-defined Adam and OneCycle schedule.

        Skips when the bf16 prerequisites are missing, or when CPU offload is
        requested but the CPU Adam op is reported as incompatible.

        Args:
            zero_stage: ZeRO stage written into ``zero_optimization.stage``.
            use_cpu_offload: value written into ``zero_optimization.cpu_offload``.
        """
        if not bf16_required_version_check():
            pytest.skip(
                " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
            )

        if use_cpu_offload:
            # The op registry is only consulted when offload is requested.
            cpu_adam_compatible = deepspeed.ops.__compatible_ops__[
                CPUAdamBuilder.NAME]
            if not cpu_adam_compatible:
                pytest.skip("cpu-adam is not compatible")

        one_cycle_params = dict(
            cycle_first_step_size=16000,
            cycle_first_stair_count=8000,
            decay_step_size=16000,
            cycle_min_lr=1e-06,
            cycle_max_lr=3e-05,
            decay_lr_rate=1e-07,
            cycle_min_mom=0.85,
            cycle_max_mom=0.99,
            decay_mom_rate=0.0,
        )
        config_dict = dict(
            steps_per_print=1,
            optimizer=dict(type="Adam", params=dict(lr=0.00015)),
            scheduler=dict(type="OneCycle", params=one_cycle_params),
            fp16=dict(enabled=False),
            bf16=dict(enabled=True),
            zero_optimization=dict(stage=zero_stage,
                                   cpu_offload=use_cpu_offload),
        )

        hidden_dim = 10
        net = SimpleModel(hidden_dim)
        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=net,
            model_parameters=net.parameters())
        batches = random_dataloader(model=engine,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=engine.device,
                                    dtype=torch.bfloat16)
        for batch in batches:
            step_loss = engine(batch[0], batch[1])
            engine.backward(step_loss)
            engine.step()
示例#3
0
    def test(self, zero_stage=2, use_cpu_offload=False):
        """Run one bf16 ZeRO step on a 2-parameter model with tiny buckets.

        The model is built with ``hidden_dim = 1`` and asserted to have exactly
        two parameters, so that with 3 data-parallel ranks some rank gets an
        empty partition. ZeRO stage 3 is currently skipped.

        Args:
            zero_stage: ZeRO stage written into ``zero_optimization.stage``.
            use_cpu_offload: value written into ``zero_optimization.cpu_offload``.
        """
        if not bf16_required_version_check():
            pytest.skip(
                " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
            )

        if use_cpu_offload:
            cpu_adam_compatible = deepspeed.ops.__compatible_ops__[
                CPUAdamBuilder.NAME]
            if not cpu_adam_compatible:
                pytest.skip("cpu-adam is not compatible")

        if zero_stage == 3:
            pytest.skip("skip for now")

        zero_section = dict(
            stage=zero_stage,
            cpu_offload=use_cpu_offload,
            reduce_bucket_size=100,
            allgather_bucket_size=100,
        )
        config_dict = dict(
            train_micro_batch_size_per_gpu=1,
            gradient_accumulation_steps=1,
            fp16=dict(enabled=False),
            bf16=dict(enabled=True),
            optimizer=dict(type="Adam", params=dict(lr=0.00015)),
            zero_optimization=zero_section,
        )

        hidden_dim = 1
        net = SimpleModel(hidden_dim)

        # Exactly 2 parameters -> an empty partition when DP=3.
        param_count = len(list(net.parameters()))
        assert param_count == 2
        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=net,
            model_parameters=net.parameters())

        # Smoke-run a single sample through the engine.
        batches = random_dataloader(model=engine,
                                    total_samples=1,
                                    hidden_dim=hidden_dim,
                                    device=engine.device,
                                    dtype=torch.bfloat16)
        for batch in batches:
            step_loss = engine(batch[0], batch[1])
            engine.backward(step_loss)
            engine.step()
示例#4
0
    def test(self, comp_type, comm_type):
        """Check that ZeRO-2 reductions use the configured communication dtype.

        ``dist.reduce`` is temporarily replaced with a wrapper that asserts each
        reduced tensor has dtype ``comm_type`` before delegating to the original
        function. The original is restored in a ``finally`` block, so a failing
        assertion or training error cannot leave the patch installed for later
        tests in the same process.

        Args:
            comp_type: dtype of the generated batches; also selects whether the
                ``fp16`` or ``bf16`` config section is enabled.
            comm_type: expected dtype of communicated tensors; must be a key of
                ``type_str`` below (``torch.float16`` or ``torch.bfloat16``).
        """
        if comp_type == torch.bfloat16 or comm_type == torch.bfloat16:
            if not bf16_required_version_check():
                pytest.skip(
                    " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
                )

        # Config spellings accepted for "communication_data_type".
        type_str = {torch.float16: "fp16", torch.bfloat16: "bfp16"}

        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "fp16": {
                "enabled": comp_type == torch.float16
            },
            "bf16": {
                "enabled": comp_type == torch.bfloat16
            },
            "zero_optimization": {
                "stage": 2
            },
            "communication_data_type": type_str[comm_type]
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim)
        optimizer = torch.optim.Adam(model.parameters())
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              optimizer=optimizer)
        data_loader = random_dataloader(model=model,
                                        total_samples=2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=comp_type)

        orig_torch_reduce = dist.reduce

        def custom_reduce(tensor,
                          dst,
                          op=dist.ReduceOp.SUM,
                          group=None,
                          async_op=False):
            # Every reduction issued by the engine must already be in the
            # requested communication dtype.
            assert tensor.dtype == comm_type
            return orig_torch_reduce(tensor, dst, op, group, async_op)

        dist.reduce = custom_reduce
        try:
            for batch in data_loader:
                loss = model(batch[0], batch[1])
                model.backward(loss)
                model.step()
        finally:
            # Always undo the monkey-patch, even if a step or assertion fails.
            dist.reduce = orig_torch_reduce
示例#5
0
    def test_overflow(self, tmpdir):
        """Run OneBitLamb with an inflated loss on rank 0 while checkpointing each step.

        From step index 10 onward, rank 0 multiplies its loss by 1000000.0
        (presumably to force fp16 gradient overflow and loss-scale skips). Ranks
        are synchronised with barriers around ``step()``, and a checkpoint is
        saved to ``tmpdir/saved_checkpoint`` after every step.

        Args:
            tmpdir: directory under which the checkpoint folder is created.
        """
        lamb_params = dict(
            lr=0.00015,
            weight_decay=0.01,
            max_coeff=0.3,
            min_coeff=0.01,
            freeze_step=2,
            cuda_aware=False,
            comm_backend_name="nccl",
            coeff_beta=0.9,
            factor_max=1.0,
            factor_min=0.5,
            factor_threshold=0.1,
        )
        config_dict = dict(
            train_batch_size=2,
            steps_per_print=1,
            optimizer=dict(type="OneBitLamb", params=lamb_params),
            gradient_clipping=1.0,
            fp16=dict(enabled=True, loss_scale=0, initial_scale_power=16),
        )
        hidden_dim = 10

        net = SimpleModel(hidden_dim)
        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=net,
            model_parameters=net.parameters())
        batches = random_dataloader(model=engine,
                                    total_samples=100,
                                    hidden_dim=hidden_dim,
                                    device=engine.device)
        save_folder = os.path.join(tmpdir, "saved_checkpoint")
        for step_idx, batch in enumerate(batches):
            step_loss = engine(batch[0], batch[1])
            amplify = dist.get_rank() == 0 and step_idx >= 10
            if amplify:
                step_loss = step_loss * 1000000.0
            engine.backward(step_loss)
            dist.barrier()
            engine.step()
            dist.barrier()
            engine.save_checkpoint(save_folder, tag=None)
示例#6
0
    def test(self):
        """Train bf16 under ZeRO-2 with gradient clipping and explicit bucket settings.

        Uses a config-defined Adam optimizer, contiguous gradients, and
        disables both ``overlap_comm`` and ``reduce_scatter``; passes if the
        training loop completes without raising.
        """
        if not bf16_required_version_check():
            pytest.skip(
                " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
            )

        zero_section = dict(
            stage=2,
            contiguous_gradients=True,
            allgather_bucket_size=2000000000,
            reduce_bucket_size=200000000,
            overlap_comm=False,
            reduce_scatter=False,
        )
        config_dict = dict(
            train_batch_size=2,
            steps_per_print=1,
            optimizer=dict(type="Adam", params=dict(lr=0.00015)),
            gradient_clipping=1.0,
            zero_optimization=zero_section,
            fp16=dict(enabled=False),
            bf16=dict(enabled=True),
        )
        hidden_dim = 10

        net = SimpleModel(hidden_dim)
        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=net,
            model_parameters=net.parameters())
        batches = random_dataloader(model=engine,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=engine.device,
                                    dtype=torch.bfloat16)
        for batch in batches:
            step_loss = engine(batch[0], batch[1])
            engine.backward(step_loss)
            engine.step()
示例#7
0
    def test(self, dtype):
        """Train with the OneBitLamb optimizer using batches of the given dtype.

        The ``fp16`` section is enabled only when ``dtype`` is
        ``torch.float16``; the test passes if training completes.

        Args:
            dtype: dtype of the generated batches.
        """
        use_fp16 = dtype == torch.float16
        lamb_params = dict(
            lr=0.00015,
            weight_decay=0.01,
            max_coeff=0.3,
            min_coeff=0.01,
            freeze_step=2,
            cuda_aware=False,
            comm_backend_name="nccl",
            coeff_beta=0.9,
            factor_max=1.0,
            factor_min=0.5,
            factor_threshold=0.1,
        )
        config_dict = dict(
            train_batch_size=2,
            steps_per_print=1,
            optimizer=dict(type="OneBitLamb", params=lamb_params),
            gradient_clipping=1.0,
            fp16=dict(enabled=use_fp16, loss_scale=0, initial_scale_power=16),
        )
        hidden_dim = 10

        net = SimpleModel(hidden_dim)
        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=net,
            model_parameters=net.parameters())
        batches = random_dataloader(
            model=engine,
            total_samples=50,
            hidden_dim=hidden_dim,
            device=engine.device,
            dtype=dtype,
        )
        for batch in batches:
            step_loss = engine(batch[0], batch[1])
            engine.backward(step_loss)
            engine.step()
示例#8
0
    def test(self):
        """Check the flops profiler's FLOP and parameter counts for SimpleModel.

        Runs four fp16 training steps (indices 0-3) with the profiler enabled,
        then asserts the measured FLOPs are within ``TOLERANCE`` of 200 and the
        profiled parameter count is exactly 110.
        """
        config_dict = dict(
            train_batch_size=1,
            steps_per_print=1,
            optimizer=dict(type="Adam", params=dict(lr=0.001)),
            zero_optimization=dict(stage=0),
            fp16=dict(enabled=True),
            flops_profiler=dict(enabled=True,
                                step=1,
                                module_depth=-1,
                                top_modules=3),
        )
        hidden_dim = 10
        net = SimpleModel(hidden_dim, empty_grad=False)

        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=net,
            model_parameters=net.parameters())

        batches = random_dataloader(model=engine,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=engine.device,
                                    dtype=torch.half)
        last_step = 3
        for step_idx, batch in enumerate(batches):
            step_loss = engine(batch[0], batch[1])
            engine.backward(step_loss)
            engine.step()
            if step_idx == last_step:
                break

        profiler = engine.flops_profiler
        assert within_range(profiler.flops,
                            200,
                            tolerance=TOLERANCE)
        assert profiler.params == 110
示例#9
0
    def test_pld_model(self, theta):
        """Check the progressive-layer-drop theta schedule on a PLD-aware model.

        After training step ``i`` (0-based), ``model.get_pld_theta()`` must match
        ``(1 - theta) * exp(-gamma * i) + theta`` with ``gamma = 0.001``.
        The comparison uses a tolerance rather than exact float equality, so
        harmless differences in how the engine evaluates the formula do not
        fail the test.

        Args:
            theta: PLD ``theta`` written into the config.
        """
        gamma = 0.001
        config_dict = {
            "train_batch_size": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.0001
                }
            },
            "fp16": {
                "enabled": True
            },
            "progressive_layer_drop": {
                "enabled": True,
                "theta": theta,
                "gamma": gamma
            }
        }
        hidden_dim = 10

        model = PLD_SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

            expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
            actual_theta = model.get_pld_theta()
            # Tolerance-based comparison: exact `==` on floats only holds if
            # the engine performs exactly the same operations in the same order.
            assert np.isclose(actual_theta, expected_theta), (
                f"step {i}: expected PLD theta {expected_theta}, "
                f"got {actual_theta}")
示例#10
0
    def test_non_pld_model(self):
        """Check that enabling PLD with a model lacking PLD support raises TypeError.

        Progressive layer drop is enabled in the config, but the plain
        ``SimpleModel`` is wrapped. The forward call is expected to raise
        ``TypeError`` (presumably because the engine passes PLD state as extra
        keyword arguments that ``SimpleModel.forward`` does not accept).
        """
        pld_gamma = 0.001
        pld_theta = 0.5
        config_dict = dict(
            train_batch_size=1,
            steps_per_print=1,
            optimizer=dict(type='Adam', params=dict(lr=0.0001)),
            fp16=dict(enabled=True),
            progressive_layer_drop=dict(enabled=True,
                                        theta=pld_theta,
                                        gamma=pld_gamma),
        )
        hidden_dim = 10

        net = SimpleModel(hidden_dim, empty_grad=False)
        engine, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=net,
            model_parameters=net.parameters())

        batches = random_dataloader(model=engine,
                                    total_samples=1,
                                    hidden_dim=hidden_dim,
                                    device=engine.device)

        for batch in batches:
            with pytest.raises(TypeError):
                engine(batch[0], batch[1])
示例#11
0
    def test(self, tmpdir):
        """Check OneBitLamb momentum-mask and compression-state handling across checkpoints.

        Three engines are built in sequence over the same ``SimpleModel``:

        1. ``model_1`` (group 0 masked with ``mask1``) is trained and then
           checked: the optimizer is in its frozen (``lamb_freeze_key``) stage,
           the mask is attached, and every state entry has ``scaling_coeff``.
           Its ``scaling_coeff`` values are recorded, and a checkpoint is saved
           to ``tmpdir/saved_checkpoint``.
        2. ``model_2`` (group 0 masked with ``mask2``) loads that checkpoint.
           The test asserts its own ``mask2`` is kept, worker/server error
           buffers are empty, the recorded ``scaling_coeff`` values are
           restored, and ``lamb_freeze_key`` is True.
        3. ``model_3`` (no mask) is trained with the wrapped optimizer's
           ``freeze_step`` set to 20 and then loads the checkpoint. The test
           asserts that no mask appears, the error buffers are empty,
           per-param ``lamb_coeff_freeze`` is 0.0, ``last_factor`` is 1.0,
           ``scaling_coeff`` is absent, and ``lamb_freeze_key`` is False.

        Args:
            tmpdir: directory under which the checkpoint folder is created
                (presumably the pytest ``tmpdir`` fixture).
        """
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "OneBitLamb",
                "params": {
                    "lr": 0.00015,
                    "weight_decay": 0.01,
                    "max_coeff": 0.3,
                    "min_coeff": 0.01,
                    "freeze_step": 2,
                    "cuda_aware": False,
                    "comm_backend_name": "nccl",
                    "coeff_beta": 0.9,
                    "factor_max": 1.0,
                    "factor_min": 0.5,
                    "factor_threshold": 0.1,
                },
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 16
            },
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim)
        param_optimizer = list(model.named_parameters())
        # Masks shaped like the first parameter: mask1 keeps only row 0 and
        # mask2 keeps only row 1 (those rows become 1, everything else stays 0).
        mask1 = torch.zeros_like(param_optimizer[0][1].data)
        mask2 = torch.zeros_like(param_optimizer[0][1].data)
        for col in range(mask1.size()[1]):
            mask1[0][col] += 1
            mask2[1][col] += 1

        # Engine 1: the first parameter's group carries mask1.
        optimizer_grouped_parameters_1 = [
            {
                "params": [param_optimizer[0][1]],
                "weight_decay": 0.01,
                "exp_avg_mask": mask1,
            },
            {
                "params": [param_optimizer[1][1]],
                "weight_decay": 0.01
            },
        ]

        # Engine 2: same grouping but with a different mask (mask2), so the test
        # can tell whether loading the checkpoint overwrites the mask.
        optimizer_grouped_parameters_2 = [
            {
                "params": [param_optimizer[0][1]],
                "weight_decay": 0.01,
                "exp_avg_mask": mask2,
            },
            {
                "params": [param_optimizer[1][1]],
                "weight_decay": 0.01
            },
        ]

        # Engine 3: same grouping with no mask at all.
        optimizer_grouped_parameters_3 = [
            {
                "params": [param_optimizer[0][1]],
                "weight_decay": 0.01
            },
            {
                "params": [param_optimizer[1][1]],
                "weight_decay": 0.01
            },
        ]

        model_1, optimizer_1, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=optimizer_grouped_parameters_1,
        )
        data_loader = random_dataloader(
            model=model_1,
            total_samples=10,
            hidden_dim=hidden_dim,
            device=model_1.device,
        )
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # After training: the freeze stage was reached and the momentum mask is
        # still attached to group 0 (it is re-checked after saving below).
        assert optimizer_1.optimizer.lamb_freeze_key is True
        # Move the reference mask to the device the optimizer's copy lives on.
        mask1 = mask1.to(
            device=optimizer_1.param_groups[0]["exp_avg_mask"].device)
        assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"],
                              mask1,
                              atol=1e-07), f"Incorrect momentum mask"
        # Record scaling coefficients to compare with what model_2 loads.
        scaling_coeff_1 = []
        for v in optimizer_1.state.values():
            assert "scaling_coeff" in v, f"Incorrect scaling_coeff"
            scaling_coeff_1.append(v["scaling_coeff"])
        save_folder = os.path.join(tmpdir, "saved_checkpoint")
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(
            optimizer_1.param_groups[0]["exp_avg_mask"], mask1, atol=1e-07
        ), f"Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=optimizer_grouped_parameters_2,
        )
        # Before loading, group 0 carries this engine's own mask2; after loading
        # the checkpoint saved with mask1, it must still be mask2.
        mask2 = mask2.to(
            device=optimizer_2.param_groups[0]["exp_avg_mask"].device)
        assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"],
                              mask2,
                              atol=1e-07), f"Incorrect momentum mask"
        model_2.load_checkpoint(
            save_folder,
            tag=None,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        assert torch.allclose(
            optimizer_2.param_groups[0]["exp_avg_mask"], mask2, atol=1e-07
        ), f"Momentum mask should not change after loading checkpoint"
        # Test whether worker&server error is reset
        assert len(optimizer_2.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_2.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether scaling_coeffs is loaded correctly
        # (compared as sorted lists, so the order of state entries is ignored).
        scaling_coeff_2 = []
        for v in optimizer_2.state.values():
            assert "scaling_coeff" in v, f"Incorrect scaling_coeff"
            scaling_coeff_2.append(v["scaling_coeff"])
        assert list(sorted(scaling_coeff_2)) == list(
            sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs"
        assert optimizer_2.optimizer.lamb_freeze_key is True

        model_3, optimizer_3, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=optimizer_grouped_parameters_3,
        )
        # Raise the wrapped optimizer's freeze_step from the configured 2 to 20
        # for this engine's training run.
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(
            model=model_3,
            total_samples=50,
            hidden_dim=hidden_dim,
            device=model_3.device,
        )
        for n, batch in enumerate(data_loader):
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        assert optimizer_3.optimizer.lamb_freeze_key is True
        # model_3 was built without a mask; it must stay absent both before and
        # after loading the checkpoint saved by model_1 (which had one).
        assert ("exp_avg_mask"
                not in optimizer_3.param_groups[0]), f"Incorrect momentum mask"
        model_3.load_checkpoint(
            save_folder,
            tag=None,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        assert ("exp_avg_mask" not in optimizer_3.param_groups[0]
                ), f"Momentum mask should not change after loading checkpoint"
        # Test whether worker&server error is reset
        assert len(optimizer_3.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_3.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are reset
        for v in optimizer_3.state.values():
            assert v[
                "lamb_coeff_freeze"] == 0.0, f"Incorrect lamb_coeff_freeze"
            assert v["last_factor"] == 1.0, f"Incorrect last_factor"
            assert "scaling_coeff" not in v, f"Incorrect scaling_coeff"
        assert optimizer_3.optimizer.lamb_freeze_key is False
示例#12
0
    def test(self):
        """Check that OneBitLamb applies a per-group momentum mask to ``exp_avg``.

        Param group 0 (the model's first parameter) gets an ``exp_avg_mask``
        in which only row 0 is 1. After training, every optimizer state entry
        whose ``exp_avg`` has the mask's shape must already be masked, meaning
        ``exp_avg`` equals ``exp_avg * mask``. The test also requires at least one
        such entry, so the check cannot pass vacuously.
        """
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "OneBitLamb",
                "params": {
                    "lr": 0.00015,
                    "weight_decay": 0.01,
                    "max_coeff": 0.3,
                    "min_coeff": 0.01,
                    "freeze_step": 2,
                    "cuda_aware": False,
                    "comm_backend_name": "nccl",
                    "coeff_beta": 0.9,
                    "factor_max": 1.0,
                    "factor_min": 0.5,
                    "factor_threshold": 0.1,
                },
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 16
            },
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim)
        param_optimizer = list(model.named_parameters())
        # Mask shaped like the first parameter, keeping only its row 0.
        mask1 = torch.zeros_like(param_optimizer[0][1].data)
        mask1[0, :] = 1
        optimizer_grouped_parameters = [
            {
                "params": [param_optimizer[0][1]],
                "weight_decay": 0.01,
                "exp_avg_mask": mask1,
            },
            {
                "params": [param_optimizer[1][1]],
                "weight_decay": 0.01
            },
        ]

        model, optimizer, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=optimizer_grouped_parameters,
        )
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Test whether the momentum mask works.
        checked = 0
        for state in optimizer.state.values():
            exp_avg = state["exp_avg"]
            if exp_avg.size() != mask1.size():
                continue
            # Out-of-place `mul`: the previous in-place `mul_` overwrote
            # exp_avg before the comparison, so allclose compared the tensor
            # with itself and could never fail.
            masked = exp_avg.mul(mask1.to(device=exp_avg.device))
            assert torch.allclose(
                exp_avg,
                masked,
                atol=1e-07,
            ), "Momentum mask is not working properly"
            checked += 1
        assert checked > 0, "No optimizer state matched the momentum mask shape"
示例#13
0
    def test(self, tmpdir):
        """Check ZeroOneAdam momentum-mask and error-buffer handling across checkpoints.

        Three engines are built in sequence over the same ``SimpleModel``:

        1. ``model_1`` (group 0 masked with flattened ``mask1``) is trained.
           The test checks that the mask is attached, then saves a checkpoint
           to ``tmpdir/saved_checkpoint``.
        2. ``model_2`` (group 0 masked with flattened ``mask2``) loads that
           checkpoint. The test asserts its own ``mask2`` is kept and that no
           state entry contains ``worker_error`` / ``server_error``.
        3. ``model_3`` (no mask) is trained and then loads the checkpoint. The
           test asserts that no mask appears and that no state entry contains
           ``worker_error`` / ``server_error``.

        Args:
            tmpdir: directory under which the checkpoint folder is created
                (presumably the pytest ``tmpdir`` fixture).
        """
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "ZeroOneAdam",
                "params": {
                    "lr": 0.00015,
                    "weight_decay": 0.01,
                    "var_freeze_step": 4,
                    "var_update_scaler": 1,
                    "local_step_scaler": 1,
                    "local_step_clipper": 2,
                    "cuda_aware": False,
                    "comm_backend_name": "nccl",
                },
            },
            "gradient_clipping": 1.0,
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "initial_scale_power": 16
            },
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim)
        param_optimizer = list(model.named_parameters())
        # mask1 keeps only row 0 and mask2 only row 1 of the first parameter.
        mask1 = torch.zeros_like(param_optimizer[0][1].data)
        mask2 = torch.zeros_like(param_optimizer[0][1].data)
        for col in range(mask1.size()[1]):
            mask1[0][col] += 1
            mask2[1][col] += 1
        # The masks are flattened to 1-D. Presumably the wrapped optimizer sees
        # flattened parameters for this optimizer type; TODO confirm.
        mask1 = torch.flatten(mask1)
        mask2 = torch.flatten(mask2)

        # Engine 1: the first parameter's group carries mask1.
        optimizer_grouped_parameters_1 = [
            {
                "params": [param_optimizer[0][1]],
                "weight_decay": 0.01,
                "exp_avg_mask": mask1,
            },
            {
                "params": [param_optimizer[1][1]],
                "weight_decay": 0.01
            },
        ]

        # Engine 2: same grouping with a different mask (mask2), so the test can
        # tell whether loading the checkpoint overwrites the mask.
        optimizer_grouped_parameters_2 = [
            {
                "params": [param_optimizer[0][1]],
                "weight_decay": 0.01,
                "exp_avg_mask": mask2,
            },
            {
                "params": [param_optimizer[1][1]],
                "weight_decay": 0.01
            },
        ]

        # Engine 3: same grouping with no mask at all.
        optimizer_grouped_parameters_3 = [
            {
                "params": [param_optimizer[0][1]],
                "weight_decay": 0.01
            },
            {
                "params": [param_optimizer[1][1]],
                "weight_decay": 0.01
            },
        ]

        model_1, optimizer_1, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=optimizer_grouped_parameters_1,
        )
        data_loader = random_dataloader(
            model=model_1,
            total_samples=10,
            hidden_dim=hidden_dim,
            device=model_1.device,
        )
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # After training the momentum mask is still attached to group 0 (it is
        # re-checked after saving below). The reference mask is first moved to
        # the optimizer copy's device.
        mask1 = mask1.to(
            device=optimizer_1.param_groups[0]["exp_avg_mask"].device)
        assert torch.allclose(optimizer_1.param_groups[0]["exp_avg_mask"],
                              mask1,
                              atol=1e-07), f"Incorrect momentum mask"
        save_folder = os.path.join(tmpdir, "saved_checkpoint")
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(
            optimizer_1.param_groups[0]["exp_avg_mask"], mask1, atol=1e-07
        ), f"Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=optimizer_grouped_parameters_2,
        )
        # Before loading, group 0 carries this engine's own mask2; after loading
        # the checkpoint saved with mask1, it must still be mask2.
        mask2 = mask2.to(
            device=optimizer_2.param_groups[0]["exp_avg_mask"].device)
        assert torch.allclose(optimizer_2.param_groups[0]["exp_avg_mask"],
                              mask2,
                              atol=1e-07), f"Incorrect momentum mask"
        model_2.load_checkpoint(
            save_folder,
            tag=None,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        assert torch.allclose(
            optimizer_2.param_groups[0]["exp_avg_mask"], mask2, atol=1e-07
        ), f"Momentum mask should not change after loading checkpoint"
        # Test whether worker&server error is reset
        for v in optimizer_2.state.values():
            assert "worker_error" not in v, f"Incorrect worker error"
            assert "server_error" not in v, f"Incorrect server error"

        model_3, optimizer_3, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=optimizer_grouped_parameters_3,
        )
        # NOTE(review): this config sets "var_freeze_step", not "freeze_step".
        # Whether ZeroOneAdam reads a `freeze_step` attribute is not visible
        # here; confirm this assignment actually has the intended effect.
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(
            model=model_3,
            total_samples=50,
            hidden_dim=hidden_dim,
            device=model_3.device,
        )
        for n, batch in enumerate(data_loader):
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        # model_3 was built without a mask; it must stay absent both before and
        # after loading the checkpoint saved by model_1 (which had one).
        assert ("exp_avg_mask"
                not in optimizer_3.param_groups[0]), f"Incorrect momentum mask"
        model_3.load_checkpoint(
            save_folder,
            tag=None,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        assert ("exp_avg_mask" not in optimizer_3.param_groups[0]
                ), f"Momentum mask should not change after loading checkpoint"
        # Test whether worker&server error is reset
        for v in optimizer_3.state.values():
            assert "worker_error" not in v, f"Incorrect worker error"
            assert "server_error" not in v, f"Incorrect server error"