def test_onebitlamb_checkpointing(tmpdir):
    """Checkpoint save/load behavior of the OneBitLamb optimizer.

    Verifies, across three separately initialized engines sharing one
    checkpoint, that:
      * the per-group ``exp_avg_mask`` is preserved by ``save_checkpoint``
        and is NOT overwritten by ``load_checkpoint`` (the mask belongs to
        the param groups, not the checkpoint);
      * compression ``worker_errors`` / ``server_errors`` are reset after
        loading a checkpoint;
      * per-param ``scaling_coeff`` values survive a save/load round trip
        when the optimizer is past its freeze step, and are cleared
        (along with ``lamb_coeff_freeze`` / ``last_factor``) when the
        loading optimizer is still in warmup (``freeze_step`` not reached).

    Args:
        tmpdir: pytest-provided temporary directory for the DeepSpeed
            config file and the saved checkpoint.
    """
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {
                "lr": 0.00015,
                "weight_decay": 0.01,
                "max_coeff": 0.3,
                "min_coeff": 0.01,
                # Low freeze_step so the first engine enters the
                # compression (frozen) phase within a few steps.
                "freeze_step": 2,
                "cuda_aware": False,
                "comm_backend_name": "nccl",
                "coeff_beta": 0.9,
                "factor_max": 1.0,
                "factor_min": 0.5,
                "factor_threshold": 0.1
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    # Two different masks over the SAME (first) parameter: mask1 keeps
    # row 0, mask2 keeps row 1 — distinguishable on save vs. load.
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    mask2 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
        mask2[1][col] += 1

    optimizer_grouped_parameters_1 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    optimizer_grouped_parameters_2 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask2
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]
    # Third grouping carries no mask at all: loading a masked checkpoint
    # must not introduce one.
    optimizer_grouped_parameters_3 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim):
        model_1, optimizer_1, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_1)
        data_loader = random_dataloader(model=model_1,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model_1.device)
        for batch in data_loader:
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # Test whether momentum mask still exist after saving checkpoint
        assert optimizer_1.optimizer.lamb_freeze_key is True
        mask1 = mask1.to(
            device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'],
                              mask1,
                              atol=1e-07), "Incorrect momentum mask"
        scaling_coeff_1 = []
        for v in optimizer_1.state.values():
            assert 'scaling_coeff' in v, "Incorrect scaling_coeff"
            scaling_coeff_1.append(v['scaling_coeff'])
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(
            optimizer_1.param_groups[0]['exp_avg_mask'],
            mask1,
            atol=1e-07
        ), "Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_2)
        # Test whether momentum mask stays the same after loading checkpoint
        mask2 = mask2.to(
            device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'],
                              mask2,
                              atol=1e-07), "Incorrect momentum mask"
        model_2.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert torch.allclose(
            optimizer_2.param_groups[0]['exp_avg_mask'],
            mask2,
            atol=1e-07
        ), "Momentum mask should not change after loading checkpoint"
        # Test whether worker&server error is resetted
        assert len(optimizer_2.optimizer.worker_errors
                   ) == 0, "Incorrect worker error"
        assert len(optimizer_2.optimizer.server_errors
                   ) == 0, "Incorrect server error"
        # Test whether scaling_coeffs is loaded correctly
        scaling_coeff_2 = []
        for v in optimizer_2.state.values():
            assert 'scaling_coeff' in v, "Incorrect scaling_coeff"
            scaling_coeff_2.append(v['scaling_coeff'])
        # Order may differ across param groups/ranks, so compare sorted.
        assert sorted(scaling_coeff_2) == sorted(
            scaling_coeff_1), "Incorrect scaling_coeffs"
        assert optimizer_2.optimizer.lamb_freeze_key is True

        model_3, optimizer_3, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_3)
        # Push freeze_step beyond the sample count so this optimizer stays
        # in warmup; loading the (frozen) checkpoint must reset its
        # compression-phase state.
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(model=model_3,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model_3.device)
        for batch in data_loader:
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        assert optimizer_3.optimizer.lamb_freeze_key is True
        # Test whether momentum mask stays the same after loading checkpoint
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], "Incorrect momentum mask"
        model_3.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], "Momentum mask should not change after loading checkpoint"
        # Test whether worker&server error is resetted
        assert len(optimizer_3.optimizer.worker_errors
                   ) == 0, "Incorrect worker error"
        assert len(optimizer_3.optimizer.server_errors
                   ) == 0, "Incorrect server error"
        # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are resetted
        for v in optimizer_3.state.values():
            assert v[
                'lamb_coeff_freeze'] == 0.0, "Incorrect lamb_coeff_freeze"
            assert v['last_factor'] == 1.0, "Incorrect last_factor"
            assert 'scaling_coeff' not in v, "Incorrect scaling_coeff"
        assert optimizer_3.optimizer.lamb_freeze_key is False

    _test_onebitlamb_checkpointing(mask1,
                                   mask2,
                                   args=args,
                                   model=model,
                                   hidden_dim=hidden_dim)
def test_onebitlamb_exp_avg_mask(tmpdir):
    """Momentum masking in the OneBitLamb optimizer.

    Trains with an ``exp_avg_mask`` that keeps only row 0 of the first
    parameter's momentum, then asserts that after training the stored
    ``exp_avg`` is zero everywhere the mask is zero (i.e. multiplying by
    the mask is a no-op on the already-masked momentum).

    Bug fix: the original assertion used the in-place ``mul_``, so
    ``torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask))`` compared the
    tensor with itself (``mul_`` returns the mutated operand) and could
    never fail. The out-of-place ``mul`` makes the check meaningful and
    leaves optimizer state untouched.

    Args:
        tmpdir: pytest-provided temporary directory for the DeepSpeed
            config file.
    """
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {
                "lr": 0.00015,
                "weight_decay": 0.01,
                "max_coeff": 0.3,
                "min_coeff": 0.01,
                "freeze_step": 2,
                "cuda_aware": False,
                "comm_backend_name": "nccl",
                "coeff_beta": 0.9,
                "factor_max": 1.0,
                "factor_min": 0.5,
                "factor_threshold": 0.1
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    # Mask keeps only row 0 of the first parameter's momentum.
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
    optimizer_grouped_parameters = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim):
        model, optimizer, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
        # Test whether the momentum mask works: masked-out entries of
        # exp_avg must already be zero, so an out-of-place multiply by
        # the mask must not change anything.
        for v in optimizer.state.values():
            if v['exp_avg'].size() == mask1.size():
                assert torch.allclose(
                    v['exp_avg'],
                    v['exp_avg'].mul(mask1.to(device=v['exp_avg'].device)),
                    atol=1e-07), "Momentum mask is not working properly"

    _test_onebitlamb_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim)