def _test_lr_range_test(args, model, hidden_dim, min_lr, step_size, staircase):
    """Run a short training loop and check the LR range-test schedule:
    the LR starts at ``min_lr`` and only ever increases, either in
    staircase jumps of ``step_size`` or continuously."""
    model, _, _, lr_scheduler = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    data_loader = random_dataloader(model=model,
                                    total_samples=max(50, step_size * 2),
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=torch.float)

    # Record the LR observed at the start of every optimizer step.
    observed_lrs = []
    for batch in data_loader:
        observed_lrs.append(lr_scheduler.get_lr())
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    # Schedule must begin at the configured minimum LR.
    assert observed_lrs[0] == min_lr

    if staircase:
        # Step-wise (staircase) increase expected.
        _verify_staircase_increase(observed_lrs, step_size)
    else:
        # Monotonic continuous increase expected.
        _verify_continuous_increase(observed_lrs)
def _test_onecycle_mom(args, model, hidden_dim, min_mom, max_mom, step_size, decay_rate):
    """Train through a OneCycle momentum schedule and verify its phases:
    start at ``max_mom``, decrease to ``min_mom`` over ``step_size`` steps,
    increase back over the next ``step_size`` steps, then (optionally) decay."""
    model, _, _, lr_scheduler = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    data_loader = random_dataloader(model=model,
                                    total_samples=max(50, step_size * 3),
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=torch.float)

    # Momentum observed at the start of each optimizer step.
    mom_history = []
    for batch in data_loader:
        mom_history.append(lr_scheduler.get_mom())
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    # Starting momentum is the configured maximum.
    assert mom_history[0][0][0] == max_mom
    # At the cycle midpoint the momentum reaches the configured minimum.
    assert mom_history[step_size][0][0] == min_mom
    # First phase: momentum decreases continuously.
    _verify_continuous_decrease(mom_history[:step_size])
    # Second phase: momentum increases continuously.
    _verify_continuous_increase(mom_history[step_size:(step_size * 2)])
    # Optional decay phase after the cycle completes.
    if decay_rate > 0:
        _verify_continuous_increase(mom_history[(step_size * 2):])
def _test_lr_warmup_decay_schedule(args, model, hidden_dim, schedule_params, num_steps):
    """Check the WarmupDecayLR schedule: LR warms from WARMUP_MIN_LR to
    WARMUP_MAX_LR over WARMUP_NUM_STEPS, then strictly decreases."""
    model, _, _, lr_scheduler = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    data_loader = random_dataloader(model=model,
                                    total_samples=num_steps * 2,
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=torch.float)

    # Record the LR after every optimizer step.
    step_lrs = []
    for batch in data_loader:
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()
        step_lrs.append(lr_scheduler.get_lr())

    # Schedule starts at the configured warmup minimum.
    assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

    # At the end of warmup the LR hits the configured maximum.
    warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
    warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
    assert step_lrs[warmup_num_steps] == warmup_max_lr

    # Every LR after warmup must be strictly lower than its predecessor.
    decay_lrs = step_lrs[warmup_num_steps:]
    for earlier, later in zip(decay_lrs, decay_lrs[1:]):
        assert later < earlier
def _helper():
    """Forward-only smoke test: initialize a small model with the closure
    ``config`` and push a handful of batches through it."""
    net = SimpleModel(hidden_dim=10)
    net, _, _, _ = deepspeed.initialize(model=net, config=config)
    loader = random_dataloader(model=net,
                               total_samples=5,
                               hidden_dim=10,
                               device=net.device)
    # Only the forward pass is exercised; no backward/step here.
    for batch in loader:
        net(batch[0], batch[1])
def _test_zero_unbalanced_gradients(args, model, hidden_dim):
    """Initialize the engine and delegate to the shared unbalanced-gradients
    training routine."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=16,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    run_unbalanced_gradients(engine, loader)
def _test_lamb_fp16_basic(args, model, hidden_dim):
    """Basic smoke test: run fifty forward/backward/step iterations."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_zero3_repeat_forward_loop(args, model, hidden_dim):
    """Smoke test the full train loop (forward/backward/step) for 16 samples."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=16,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_adam_fp32_empty_grad(args, model, hidden_dim):
    """Train in fp32 for fifty iterations to exercise the empty-grad path."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device,
                               dtype=torch.float)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_non_pld_model(args, model, hidden_dim):
    """A model without PLD support must raise TypeError on forward
    (presumably because of PLD-specific forward arguments — see caller)."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=1,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        with pytest.raises(TypeError):
            engine(batch[0], batch[1])
def _test_adam_amp_basic(args, model, hidden_dim):
    """Smoke test AMP with a client-provided torch Adam optimizer."""
    base_optimizer = torch.optim.Adam(params=model.parameters())
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           optimizer=base_optimizer)
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_dist_init_true(args, model, hidden_dim):
    """Verify training works when deepspeed.initialize is asked to perform
    distributed init itself (``dist_init_required=True``)."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters(),
                                           dist_init_required=True)
    loader = random_dataloader(model=engine,
                               total_samples=5,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_scheduler_optimizer_parity(args, model, hidden_dim):
    """After every step, the scheduler's LR and the engine's LR must agree."""
    engine, _, _, lr_scheduler = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device,
                               dtype=torch.float)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
        # Scheduler and engine must report the same LR at every step.
        assert lr_scheduler.get_lr() == engine.get_lr()
def _test_get_lr_before_train(args, model, hidden_dim):
    """Querying the scheduler's LR before any training step must not raise."""
    engine, _, _, lr_scheduler = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device,
                               dtype=torch.float)
    for batch in loader:
        # Query the LR before this step's training work starts.
        lr_scheduler.get_lr()
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_adam_fp16_zero_onecycle_compatibility(args, zero_stage, hidden_dim):
    """Smoke test fp16 + ZeRO + OneCycle together: build a fresh model and
    train for fifty iterations (``zero_stage`` is carried in ``args``)."""
    net = SimpleModel(hidden_dim)
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=net,
                                           model_parameters=net.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
    """Smoke test fp16 training with a client torch AdamW optimizer."""
    base_optimizer = torch.optim.AdamW(params=model.parameters())
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           optimizer=base_optimizer,
                                           dist_init_required=False)
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_curriculum_scheduler_fixed_linear(args, model, hidden_dim):
    """Train and verify the reported sequence length matches the closure
    ``ground_truths`` mapping (step number -> expected seqlen)."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=20,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for n, batch in enumerate(loader):
        loss, seqlen = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
        step = n + 1
        if step in ground_truths:
            # Curriculum seqlen at this step must match the expected value.
            true_seqlen = ground_truths[step]
            print('at step {} the seqlen is {}'.format(step, seqlen))
            assert seqlen == true_seqlen, f"Incorrect curriculum schedule"
def _go(args):
    """Build a fresh model (using the closure ``hidden_dim``) and run a
    short ten-sample training loop."""
    net = SimpleModel(hidden_dim)
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=net,
                                           model_parameters=net.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=10,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_flops_profiler_in_ds_training(args, model, hidden_dim):
    """Run four training steps, then check the flops profiler's counters
    against the expected values for this fixed model/config."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device,
                               dtype=torch.half)
    for n, batch in enumerate(loader):
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
        # Four steps are enough for the profiler to have measured a pass.
        if n == 3:
            break
    assert engine.flops_profiler.flops == 100
    assert engine.flops_profiler.params == 110
def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim):
    """Deliberately inflate the loss on rank 0 after step 10 to trigger
    overflow, and checkpoint every step (uses the closure ``tmpdir``)."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=100,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    for n, batch in enumerate(loader):
        loss = engine(batch[0], batch[1])
        # Blow up the loss on one rank to force an fp16 overflow condition.
        if dist.get_rank() == 0 and n >= 10:
            loss = loss * 1000000.0
        engine.backward(loss)
        dist.barrier()
        engine.step()
        dist.barrier()
        engine.save_checkpoint(save_folder, tag=None)
def _test_pld_model(args, model, hidden_dim, theta, gamma):
    """Train with progressive layer drop and check at every step that the
    engine's theta follows the closed-form schedule
    (1 - theta) * exp(-gamma * i) + theta."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for step, batch in enumerate(loader):
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
        # Closed-form expected keep-probability at this step.
        expected_theta = (1. - theta) * np.exp(-gamma * step) + theta
        actual_theta = engine.get_pld_theta()
        assert expected_theta == actual_theta
def _test_zero_empty_partition(args):
    """With only 2 parameters and DP=3, one rank gets an empty partition;
    training must still work."""
    hidden_dim = 1
    net = SimpleModel(hidden_dim)
    # Ensure model has 2 parameters, to cause empty partition with DP=3
    assert len(list(net.parameters())) == 2
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=net,
                                           model_parameters=net.parameters())
    # Now make sure things work..
    loader = random_dataloader(model=engine,
                               total_samples=1,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _go(hidden_dim):
    """Construct a 78-layer model (partitioned at construction time when the
    closure ``zero_stage`` is 3), run forward-only passes, and print memory
    usage before/after init and after the forwards."""
    # ZeRO-3 partitions parameters as they are created inside this context.
    with deepspeed.zero.Init(enabled=zero_stage == 3,
                             config_dict_or_path=ds_config):
        net = SimpleModel(hidden_dim, nlayers=78)
    print('total number of parameters:',
          sum(p.numel() for p in net.parameters()))
    see_memory_usage('pre-init', force=True)
    engine, _, _, _ = deepspeed.initialize(model=net, config=ds_config)
    see_memory_usage('post-init', force=True)
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device,
                               dtype=torch.half)
    print(f"optimizer={engine.optimizer}")
    # Forward-only: no backward/step, just measure memory behaviour.
    for batch in loader:
        engine(batch[0], batch[1])
    see_memory_usage('post-fwds', force=True)
def checkpoint_correctness_verification(args,
                                        model,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True):
    """Train an engine, save a checkpoint under ``tmpdir``, reload it into a
    second engine, and compare model (and optionally optimizer / scheduler)
    state between the two."""
    dtype = torch.half if fp16 else torch.float32
    trained_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    loader = random_dataloader(model=trained_model,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=trained_model.device,
                               dtype=dtype)
    for batch in loader:
        loss = trained_model(batch[0], batch[1])
        trained_model.backward(loss)
        trained_model.step()

    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = '1'
    trained_model.save_checkpoint(save_folder, save_tag)

    # Fresh engine over the same module, restored from the checkpoint.
    loaded_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states)

    compare_model_states(trained_model, loaded_model)
    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)
    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
def _test_zero_static_scale(args):
    """Verify the static loss scaler is active (scale fixed at 138.) and that
    training still runs."""
    hidden_dim = 10
    net = SimpleModel(hidden_dim)
    engine, optim, _, _ = deepspeed.initialize(args=args,
                                               model=net,
                                               model_parameters=net.parameters())
    # Ensure the static scaler is configured.
    assert optim.dynamic_loss_scale == False
    assert optim.loss_scaler.loss_scale == 138.

    # Now make sure things work..
    loader = random_dataloader(model=engine,
                               total_samples=10,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim):
    """Train with 1-bit LAMB and check that the momentum mask (closure
    ``mask1``) zeroes the matching entries of exp_avg."""
    engine, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=optimizer_grouped_parameters)
    loader = random_dataloader(model=engine,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()
    # Test whether the momentum mask works: a masked exp_avg must already be
    # identical to itself multiplied by the mask.
    for state in optimizer.state.values():
        if state['exp_avg'].size() == mask1.size():
            assert torch.allclose(
                state['exp_avg'],
                state['exp_avg'].mul_(mask1.to(device=state['exp_avg'].device)),
                atol=1e-07), f"Momentum mask is not working properly"
def _helper():
    """Initialize via an argparse namespace pointing at the closure
    ``config_path`` and verify that backward/step raise AssertionError
    (forward alone succeeds)."""
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepscale_config = config_path
    args.local_rank = 0

    hidden_dim = 10
    net = SimpleModel(hidden_dim=hidden_dim)
    engine, _, _, _ = deepspeed.initialize(args=args, model=net)
    loader = random_dataloader(model=engine,
                               total_samples=5,
                               hidden_dim=hidden_dim,
                               device=engine.device)
    for batch in loader:
        loss = engine(batch[0], batch[1])
        # Backward and step are both expected to assert in this configuration.
        with pytest.raises(AssertionError):
            engine.backward(loss)
        with pytest.raises(AssertionError):
            engine.step()
def _test_stage2_find_unused_parameters(args, model, hidden_dim):
    """With the closure flag ``find_unused_parameters`` unset, a model with
    unused parameters must trip an AssertionError mentioning the flag;
    with it set, training must run cleanly."""
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    loader = random_dataloader(model=engine,
                               total_samples=10,
                               hidden_dim=hidden_dim,
                               device=engine.device)

    def _loop():
        # Plain train loop shared by both branches below.
        for batch in loader:
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()

    if find_unused_parameters:
        _loop()
    else:
        with pytest.raises(AssertionError) as e:
            _loop()
        # The error message must point the user at the config knob.
        assert e.value.args and 'find_unused_parameters' in e.value.args[0]
def checkpoint_correctness_verification(args, model, hidden_dim, load_optimizer_states=True):
    """Train an engine, checkpoint to the relative folder 'saved_checkpoint',
    reload into a fresh engine, and compare optimizer or model state."""
    trained_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    loader = random_dataloader(model=trained_model,
                               total_samples=50,
                               hidden_dim=hidden_dim,
                               device=trained_model.device)
    for batch in loader:
        loss = trained_model(batch[0], batch[1])
        trained_model.backward(loss)
        trained_model.step()

    save_folder = 'saved_checkpoint'
    save_tag = '1'
    trained_model.save_checkpoint(save_folder, save_tag)

    # Fresh engine restored from the checkpoint just written.
    loaded_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim)
    else:
        compare_model_states(trained_model, loaded_model)
def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim):
    """Checkpoint/restore behaviour of the 1-bit LAMB optimizer.

    Three engines over the same module:
      1. ``model_1`` trains past the freeze step, saves a checkpoint, and its
         momentum mask / scaling coefficients are recorded.
      2. ``model_2`` (different mask via ``optimizer_grouped_parameters_2``)
         loads the checkpoint: its own mask must survive, worker/server
         errors must be reset, scaling coefficients must round-trip.
      3. ``model_3`` (no mask, from ``optimizer_grouped_parameters_3``) loads
         the checkpoint after re-entering the warmup phase: masks stay
         absent and the compression state is fully reset.

    NOTE(review): relies on closure variables ``tmpdir`` and the three
    ``optimizer_grouped_parameters_*`` lists defined in the enclosing test.
    """
    model_1, optimizer_1, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=optimizer_grouped_parameters_1)
    data_loader = random_dataloader(model=model_1,
                                    total_samples=10,
                                    hidden_dim=hidden_dim,
                                    device=model_1.device)
    for n, batch in enumerate(data_loader):
        loss = model_1(batch[0], batch[1])
        model_1.backward(loss)
        model_1.step()
    # Test whether momentum mask still exist after saving checkpoint
    assert optimizer_1.optimizer.lamb_freeze_key is True
    mask1 = mask1.to(
        device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
    assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'],
                          mask1,
                          atol=1e-07), f"Incorrect momentum mask"
    # Record the per-parameter scaling coefficients before checkpointing so
    # they can be compared after restore.
    scaling_coeff_1 = []
    for v in optimizer_1.state.values():
        assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
        scaling_coeff_1.append(v['scaling_coeff'])
    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    model_1.save_checkpoint(save_folder, tag=None)
    assert torch.allclose(
        optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07
    ), f"Momentum mask should not change after saving checkpoint"

    model_2, optimizer_2, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=optimizer_grouped_parameters_2)
    # Test whether momentum mask stays the same after loading checkpoint
    mask2 = mask2.to(
        device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
    assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'],
                          mask2,
                          atol=1e-07), f"Incorrect momentum mask"
    model_2.load_checkpoint(save_folder,
                            tag=None,
                            load_optimizer_states=True,
                            load_lr_scheduler_states=True)
    assert torch.allclose(
        optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07
    ), f"Momentum mask should not change after loading checkpoint"
    # Test whether worker&server error is resetted
    assert len(optimizer_2.optimizer.worker_errors
               ) == 0, f"Incorrect worker error"
    assert len(optimizer_2.optimizer.server_errors
               ) == 0, f"Incorrect server error"
    # Test whether scaling_coeffs is loaded correctly
    scaling_coeff_2 = []
    for v in optimizer_2.state.values():
        assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
        scaling_coeff_2.append(v['scaling_coeff'])
    # Compare as multisets: ordering of optimizer state is not significant.
    assert list(sorted(scaling_coeff_2)) == list(
        sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs"
    assert optimizer_2.optimizer.lamb_freeze_key is True

    model_3, optimizer_3, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=optimizer_grouped_parameters_3)
    # Push the freeze step beyond the training length so model_3 stays in
    # the warmup (non-compressed) phase before loading.
    optimizer_3.optimizer.freeze_step = 20
    data_loader = random_dataloader(model=model_3,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=model_3.device)
    for n, batch in enumerate(data_loader):
        loss = model_3(batch[0], batch[1])
        model_3.backward(loss)
        model_3.step()
    assert optimizer_3.optimizer.lamb_freeze_key is True
    # Test whether momentum mask stays the same after loading checkpoint
    assert 'exp_avg_mask' not in optimizer_3.param_groups[
        0], f"Incorrect momentum mask"
    model_3.load_checkpoint(save_folder,
                            tag=None,
                            load_optimizer_states=True,
                            load_lr_scheduler_states=True)
    assert 'exp_avg_mask' not in optimizer_3.param_groups[
        0], f"Momentum mask should not change after loading checkpoint"
    # Test whether worker&server error is resetted
    assert len(optimizer_3.optimizer.worker_errors
               ) == 0, f"Incorrect worker error"
    assert len(optimizer_3.optimizer.server_errors
               ) == 0, f"Incorrect server error"
    # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are resetted
    for v in optimizer_3.state.values():
        assert v[
            'lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze"
        assert v['last_factor'] == 1.0, f"Incorrect last_factor"
        assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff"
    assert optimizer_3.optimizer.lamb_freeze_key is False
def _test_zero_to_fp32():
    """Round-trip a ZeRO checkpoint through the zero-to-fp32 consolidation
    utility and verify the reconstructed fp32 weights match the live model.

    NOTE(review): relies on closure variables ``tmpdir``, ``config_dict`` and
    ``zero_stage`` from the enclosing test.
    """
    class MyModel(torch.nn.Module):
        # Stack of identical Linear layers; with n_layers = 2 * world_size
        # every rank owns at least one full layer's worth of parameters.
        def __init__(self, hidden_dim, n_layers):
            super().__init__()
            self.ll = torch.nn.ModuleList(
                torch.nn.Linear(hidden_dim, hidden_dim)
                for i in range(n_layers))
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, y):
            hidden = x
            for l in self.ll:
                hidden = l(hidden)
            return self.cross_entropy_loss(hidden, y)

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 3
    world_size = dist.get_world_size()
    n_layers = world_size * 2
    model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

    # Separate weight/bias groups so the checkpoint exercises multiple
    # param groups with differing weight decay.
    optim_groups = [
        {
            "params": [l.weight for l in model.ll],
            "weight_decay": 0.01,
        },
        {
            "params": [l.bias for l in model.ll],
            "weight_decay": 0.0
        },
    ]
    optim = torch.optim.SGD(optim_groups, lr=0.1)

    model, _, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
        optimizer=optim,
    )
    data_loader = random_dataloader(model=model,
                                    total_samples=16,
                                    hidden_dim=hidden_dim,
                                    device=model.device)
    for i, batch in enumerate(data_loader):
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    model.save_checkpoint(tmpdir)

    # make sure all sides saved it
    dist.barrier()

    if zero_stage == 3:
        with deepspeed.zero.GatheredParameters(list(
                model.module.parameters(recurse=True)),
                                               modifier_rank=None):
            pass  # this forces gathering the model
    #dump_state_dict(model)

    # Snapshot the live (trained) weights for comparison below.
    orig_state_dict = {}
    for name, param in model.module.named_parameters():
        orig_state_dict[name] = param.detach().cpu()

    if dist.get_rank() == 0:
        # Rebuild fp32 weights from the ZeRO shards on a single rank.
        fp32_model = load_state_dict_from_zero_checkpoint(
            model.module, tmpdir)
        #dump_state_dict(fp32_model)

        fp32_state_dict = fp32_model.state_dict()
        for name in orig_state_dict.keys():
            # float() workaround for torch<1.6
            assert torch.allclose(orig_state_dict[name].float(),
                                  fp32_state_dict[name].float())