Example #1
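This example (from DeepSpeed's unit tests) checks checkpointing behavior of the 1-bit LAMB (OneBitLamb) optimizer: a per-parameter-group momentum mask (exp_avg_mask) must survive saving and must not be overwritten by the checkpoint on load, the worker/server compression error buffers must be reset after loading, and per-parameter state such as scaling_coeff, lamb_coeff_freeze, and last_factor must be restored or reset depending on whether the loading engine is past its freeze_step.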
# NOTE: these imports assume DeepSpeed's unit-test helpers (the tests/unit
# directory of the DeepSpeed repository); module paths may differ across versions.
import os

import torch
import deepspeed

from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict


def test_onebitlamb_checkpointing(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {
                "lr": 0.00015,
                "weight_decay": 0.01,
                "max_coeff": 0.3,
                "min_coeff": 0.01,
                "freeze_step": 2,
                "cuda_aware": False,
                "comm_backend_name": "nccl",
                "coeff_beta": 0.9,
                "factor_max": 1.0,
                "factor_min": 0.5,
                "factor_threshold": 0.1
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    mask2 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
        mask2[1][col] += 1
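    # A zero entry in exp_avg_mask freezes that entry's momentum (exp_avg):
    # mask1 keeps only row 0 of the first parameter's momentum, mask2 only row 1.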

    optimizer_grouped_parameters_1 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    optimizer_grouped_parameters_2 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask2
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    optimizer_grouped_parameters_3 = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim):
        model_1, optimizer_1, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_1)
        data_loader = random_dataloader(model=model_1,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model_1.device)
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # Test whether the momentum mask still exists after saving the checkpoint
        assert optimizer_1.optimizer.lamb_freeze_key is True
        mask1 = mask1.to(
            device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'],
                              mask1,
                              atol=1e-07), f"Incorrect momentum mask"
        scaling_coeff_1 = []
        for v in optimizer_1.state.values():
            assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
            scaling_coeff_1.append(v['scaling_coeff'])
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(
            optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07
        ), f"Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_2)
        # Test whether momentum mask stays the same after loading checkpoint
        mask2 = mask2.to(
            device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'],
                              mask2,
                              atol=1e-07), f"Incorrect momentum mask"
        model_2.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert torch.allclose(
            optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07
        ), f"Momentum mask should not change after loading checkpoint"
        # Test whether the worker and server error buffers are reset
        assert len(optimizer_2.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_2.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether the scaling coefficients are loaded correctly
        scaling_coeff_2 = []
        for v in optimizer_2.state.values():
            assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
            scaling_coeff_2.append(v['scaling_coeff'])
        assert sorted(scaling_coeff_2) == sorted(
            scaling_coeff_1), f"Incorrect scaling_coeffs"
        assert optimizer_2.optimizer.lamb_freeze_key is True

        model_3, optimizer_3, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_3)
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(model=model_3,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model_3.device)
        for n, batch in enumerate(data_loader):
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        assert optimizer_3.optimizer.lamb_freeze_key is True
        # Test that no momentum mask is present, and that it stays absent after loading the checkpoint
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], f"Incorrect momentum mask"
        model_3.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], f"Momentum mask should not change after loading checkpoint"
        # Test whether the worker and server error buffers are reset
        assert len(optimizer_3.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_3.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether scaling_coeffs, lamb_coeff_freeze, and last_factor are reset
        for v in optimizer_3.state.values():
            assert v[
                'lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze"
            assert v['last_factor'] == 1.0, f"Incorrect last_factor"
            assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff"
        assert optimizer_3.optimizer.lamb_freeze_key is False

    _test_onebitlamb_checkpointing(mask1,
                                   mask2,
                                   args=args,
                                   model=model,
                                   hidden_dim=hidden_dim)
Example #2
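This example checks that the exp_avg_mask supplied in a parameter group is actually applied during training: after several steps, the masked parameter's momentum (exp_avg) should be zero wherever the mask is zero.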
# NOTE: these imports assume DeepSpeed's unit-test helpers (the tests/unit
# directory of the DeepSpeed repository); module paths may differ across versions.
import torch
import deepspeed

from common import distributed_test
from simple_model import SimpleModel, random_dataloader, args_from_dict


def test_onebitlamb_exp_avg_mask(tmpdir):
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "OneBitLamb",
            "params": {
                "lr": 0.00015,
                "weight_decay": 0.01,
                "max_coeff": 0.3,
                "min_coeff": 0.01,
                "freeze_step": 2,
                "cuda_aware": False,
                "comm_backend_name": "nccl",
                "coeff_beta": 0.9,
                "factor_max": 1.0,
                "factor_min": 0.5,
                "factor_threshold": 0.1
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim)
    param_optimizer = list(model.named_parameters())
    mask1 = torch.zeros_like(param_optimizer[0][1].data)
    for col in range(mask1.size()[1]):
        mask1[0][col] += 1
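    # mask1 keeps momentum only in row 0 of the first parameter; all other
    # entries of its exp_avg should stay zero.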
    optimizer_grouped_parameters = [{
        'params': [param_optimizer[0][1]],
        'weight_decay': 0.01,
        'exp_avg_mask': mask1
    }, {
        'params': [param_optimizer[1][1]],
        'weight_decay': 0.01
    }]

    @distributed_test(world_size=[2])
    def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim):
        model, optimizer, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters)
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
        # Test whether the momentum mask works
        for v in optimizer.state.values():
            if v['exp_avg'].size() == mask1.size():
                assert torch.allclose(
                    v['exp_avg'],
                    v['exp_avg'] * mask1.to(device=v['exp_avg'].device),
                    atol=1e-07), f"Momentum mask is not working properly"

    _test_onebitlamb_exp_avg_mask(args=args,
                                  model=model,
                                  hidden_dim=hidden_dim)
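The pattern both tests exercise boils down to: build the parameter groups yourself, attach an exp_avg_mask tensor (same shape as the parameter) to any group whose momentum should be partially frozen, and hand the groups to deepspeed.initialize as model_parameters. Below is a minimal sketch of that pattern, assuming a DeepSpeed version that accepts a config dict via the config argument; the model, config values, and variable names are illustrative, not part of the DeepSpeed API, and OneBitLamb still has to run under a distributed launcher because it uses the NCCL communication backend.

import torch
import deepspeed

# Illustrative OneBitLamb config (values as in the examples above).
ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "OneBitLamb",
        "params": {
            "lr": 0.00015,
            "freeze_step": 2,
            "cuda_aware": False,
            "comm_backend_name": "nccl"
        }
    },
    "fp16": {"enabled": True}
}

net = torch.nn.Linear(10, 10)  # stand-in for a real model

# Zero entries in exp_avg_mask freeze the corresponding momentum entries;
# here only the first output row of the weight keeps its momentum.
mask = torch.zeros_like(net.weight.data)
mask[0, :] = 1

param_groups = [
    {'params': [net.weight], 'weight_decay': 0.01, 'exp_avg_mask': mask},
    {'params': [net.bias], 'weight_decay': 0.01},
]

engine, optimizer, _, _ = deepspeed.initialize(config=ds_config,
                                               model=net,
                                               model_parameters=param_groups)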