def test_checkpoint_fused_optimizer(self, tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

    # Load & verify optimizer states
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True)

    # Ignore optimizer states
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=False)
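# Not part of the test above -- a minimal sketch of the save/load round trip that
# checkpoint_correctness_verification presumably performs, assuming the documented
# deepspeed.initialize / save_checkpoint / load_checkpoint APIs. The helper name and
# tag are made up for illustration; a real verification would also compare the
# restored parameters and optimizer state against the originals.
def _checkpoint_round_trip_sketch(config_dict, models, tmpdir, load_optimizer_states=True):
    import deepspeed

    # Build an engine around the first model and write a checkpoint.
    engine_a, _, _, _ = deepspeed.initialize(config=config_dict,
                                             model=models[0],
                                             model_parameters=models[0].parameters())
    engine_a.save_checkpoint(str(tmpdir), tag="sketch")

    # Build a second engine around the second model and restore from that checkpoint.
    engine_b, _, _, _ = deepspeed.initialize(config=config_dict,
                                             model=models[1],
                                             model_parameters=models[1].parameters())
    engine_b.load_checkpoint(str(tmpdir),
                             tag="sketch",
                             load_optimizer_states=load_optimizer_states)
    return engine_a, engine_b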
def test_checkpoint_moe(self, tmpdir, ep_size):
    if not required_torch_version():
        pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

    config_dict = {
        "train_batch_size": 8,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        }
    }
    hidden_dim = 16

    models = [
        SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size)
        for _ in range(2)
    ]
    optimizers = [torch.optim.AdamW(params=model.parameters()) for model in models]
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True,
                                        load_lr_scheduler_states=False,
                                        fp16=config_dict["fp16"]["enabled"],
                                        empty_tag=True,
                                        base_optimizers=optimizers,
                                        seq_dataloader=True)
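# Hedged sketch of the required_torch_version() guard used by the MoE tests above;
# the real helper lives in the shared test utilities, but it amounts to a torch
# version check (>= 1.8 here).
def _required_torch_version_sketch(min_major=1, min_minor=8):
    from packaging import version
    import torch

    # Drop any local build suffix such as "+cu117" before comparing.
    torch_version = version.parse(torch.__version__.split("+")[0])
    return torch_version >= version.parse(f"{min_major}.{min_minor}")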
def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states):
    if not required_torch_version():
        pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

    config_dict = {
        "train_batch_size": 8,
        "steps_per_print": 1,
        "optimizer": {
            "type": 'Adam',
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True,
            "initial_scale_power": 8
        },
        "zero_optimization": {
            "stage": 2,
        }
    }
    hidden_dim = 16

    models = [
        SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size)
        for _ in range(2)
    ]
    # param group must have a random unique name (for now)
    # TODO: clean-up this requirement, the unique name should not be required here
    param_groups = [{
        'params': [p for p in model.parameters()],
        'name': 'random-unique-name'
    } for model in models]
    params = [
        split_params_into_different_moe_groups_for_optimizer(group)
        for group in param_groups
    ]
    optimizers = [torch.optim.AdamW(params=param) for param in params]

    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=load_optim_states,
                                        load_lr_scheduler_states=False,
                                        fp16=config_dict["fp16"]["enabled"],
                                        empty_tag=True,
                                        base_optimizers=optimizers,
                                        seq_dataloader=True)
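# Illustration only: split_params_into_different_moe_groups_for_optimizer is the real
# DeepSpeed utility used above. This sketch conveys the general idea -- expert
# parameters get their own optimizer group so ZeRO and expert parallelism can treat
# them differently -- but the `is_expert_param` flag is a stand-in, not the attribute
# DeepSpeed actually inspects.
def _split_moe_params_sketch(param_group):
    dense_group = dict(param_group,
                       params=[p for p in param_group['params']
                               if not getattr(p, 'is_expert_param', False)])
    expert_group = dict(param_group,
                        params=[p for p in param_group['params']
                                if getattr(p, 'is_expert_param', False)],
                        moe=True)
    return [dense_group, expert_group]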
def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload):
    if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
        pytest.skip("cpu-adam is not compatible")

    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": 'Adam',
            "params": {
                "lr": 0.00015,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "fp16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": zero_stage,
            "cpu_offload": use_cpu_offload
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000
            }
        }
    }
    hidden_dim = 10

    if zero_stage == 3:
        global DeepSpeedZeroOptimizer_Stage3
        from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
        with deepspeed.zero.Init():
            models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
    else:
        models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

    checkpoint_correctness_verification(config_dict,
                                        models,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=True)
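# Hedged sketch of the WarmupLR schedule configured above: the learning rate ramps
# from warmup_min_lr to warmup_max_lr over warmup_num_steps and then holds. Linear
# warmup is assumed here for simplicity; DeepSpeed also supports other warmup curves.
def _warmup_lr_sketch(step, warmup_min_lr=0.0, warmup_max_lr=0.001, warmup_num_steps=1000):
    progress = min(step / float(warmup_num_steps), 1.0)
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * progress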
def test_checkpoint_unfused_optimizer(self, tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 1000,
                "cycle_first_stair_count": 500,
                "cycle_second_step_size": 1000,
                "cycle_second_stair_count": 500,
                "decay_step_size": 1000,
                "cycle_min_lr": 0.0001,
                "cycle_max_lr": 0.0010,
                "decay_lr_rate": 0.001,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

    # Load & verify optimizer states
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True)

    # Ignore optimizer states
    checkpoint_correctness_verification(config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=False)
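# Hedged sketch of the OneCycle LR shape configured above: ramp up for the first
# phase, ramp back down for the second, then decay. Stair-count rounding, momentum
# cycling, and DeepSpeed's exact decay formula are intentionally left out; this is
# only meant to convey the shape of the cycle.
def _one_cycle_lr_sketch(step,
                         cycle_min_lr=0.0001,
                         cycle_max_lr=0.0010,
                         cycle_first_step_size=1000,
                         cycle_second_step_size=1000,
                         decay_lr_rate=0.001):
    if step <= cycle_first_step_size:
        frac = step / float(cycle_first_step_size)
        return cycle_min_lr + (cycle_max_lr - cycle_min_lr) * frac
    if step <= cycle_first_step_size + cycle_second_step_size:
        frac = (step - cycle_first_step_size) / float(cycle_second_step_size)
        return cycle_max_lr - (cycle_max_lr - cycle_min_lr) * frac
    # Simplified post-cycle decay from cycle_min_lr.
    decay_steps = step - cycle_first_step_size - cycle_second_step_size
    return cycle_min_lr * (1.0 - decay_lr_rate) ** decay_steps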
def test_checkpoint_pipe_engine(self, zero_stage, tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-5
            }
        },
        "zero_optimization": {
            "stage": zero_stage
        },
        "fp16": {
            "enabled": zero_stage > 0
        },
        "scheduler": {
            "type": "OneCycle",
            "params": {
                "cycle_first_step_size": 1000,
                "cycle_first_stair_count": 500,
                "cycle_second_step_size": 1000,
                "cycle_second_stair_count": 500,
                "decay_step_size": 1000,
                "cycle_min_lr": 0.0001,
                "cycle_max_lr": 0.0010,
                "decay_lr_rate": 0.001,
                "cycle_min_mom": 0.85,
                "cycle_max_mom": 0.99,
                "decay_mom_rate": 0.0
            }
        }
    }

    models = [LinearStackPipe(num_stages=2) for _ in range(2)]
    checkpoint_correctness_verification(config_dict=config_dict,
                                        models=models,
                                        hidden_dim=models[0].hidden_dim,
                                        tmpdir=tmpdir,
                                        fp16=config_dict['fp16']['enabled'],
                                        load_optimizer_states=True,
                                        load_lr_scheduler_states=True,
                                        train_batch=True)
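# Hedged sketch of what a LinearStackPipe-style model could look like with
# DeepSpeed's pipeline API (PipelineModule / LayerSpec). The layer count, sizes,
# loss function, and hidden_dim attribute are assumptions for illustration, not the
# definition the test above imports; building an instance also requires an
# initialized distributed environment, as with any PipelineModule.
import torch
from deepspeed.pipe import PipelineModule, LayerSpec

class _LinearStackPipeSketch(PipelineModule):
    def __init__(self, hidden_dim=32, num_layers=4, num_stages=2):
        # A stack of identical Linear layers, partitioned across num_stages stages.
        layers = [LayerSpec(torch.nn.Linear, hidden_dim, hidden_dim) for _ in range(num_layers)]
        super().__init__(layers=layers, loss_fn=torch.nn.MSELoss(), num_stages=num_stages)
        self.hidden_dim = hidden_dim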
def test_existing_latest(self, tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        }
    }
    hidden_dim = 10

    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
    checkpoint_correctness_verification(config_dict=config_dict,
                                        models=models,
                                        hidden_dim=hidden_dim,
                                        tmpdir=tmpdir,
                                        load_optimizer_states=True,
                                        load_lr_scheduler_states=False,
                                        fp16=False,
                                        empty_tag=True)
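# Hedged sketch of the bookkeeping test_existing_latest exercises: DeepSpeed records
# the most recent checkpoint tag in a small `latest` file inside the checkpoint
# directory, and this test checks that save/load still behaves when such a file is
# already present. Pre-creating one for a test could look like this (the tag value
# is arbitrary).
import os

def _write_existing_latest_sketch(checkpoint_dir, tag="global_step0"):
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, "latest"), "w") as fd:
        fd.write(tag)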