Example No. 1
    def _test_lr_range_test(args, model, hidden_dim, min_lr, step_size, staircase):
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50,
                                                          step_size * 2),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)

        step_lrs = []
        for _, batch in enumerate(data_loader):
            step_lrs.append(lr_scheduler.get_lr())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting lr
        assert step_lrs[0] == min_lr

        if staircase:
            # Verify staircase increasing lr
            _verify_staircase_increase(step_lrs, step_size)
        else:
            # Verify continuous increasing lr
            _verify_continuous_increase(step_lrs)
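
The verification helpers used above are defined elsewhere in the test module. A minimal sketch of what `_verify_continuous_increase` and `_verify_staircase_increase` might look like (hypothetical implementations, written here only to make the assertions readable, not the suite's actual code):

def _verify_continuous_increase(step_lrs):
    # Every recorded value should be at least as large as the previous one.
    assert all(a <= b for a, b in zip(step_lrs, step_lrs[1:]))


def _verify_staircase_increase(step_lrs, step_size):
    # Values stay constant within each step_size window and grow across windows.
    num_intervals = len(step_lrs) // step_size
    for i in range(num_intervals):
        window = step_lrs[i * step_size:(i + 1) * step_size]
        assert all(value == window[0] for value in window)
        if i > 0:
            previous_window = step_lrs[(i - 1) * step_size:i * step_size]
            assert window[0] > previous_window[0]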
Example No. 2
    def _test_onecycle_mom(args, model, hidden_dim, min_mom, max_mom,
                           step_size, decay_rate):
        model, _, _, lr_scheduler = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, step_size * 3),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)

        step_moms = []
        for _, batch in enumerate(data_loader):
            step_moms.append(lr_scheduler.get_mom())
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        # Verify starting momentum
        assert step_moms[0][0][0] == max_mom

        # Verify momentum at the cycle peak (minimum momentum)
        assert step_moms[step_size][0][0] == min_mom

        # Verify decreasing phase
        _verify_continuous_decrease(step_moms[:step_size])

        # Verify increasing phase
        _verify_continuous_increase(step_moms[step_size:(step_size * 2)])

        # Verify decay phase
        if decay_rate > 0:
            _verify_continuous_increase(step_moms[(step_size * 2):])
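
For context, the momentum cycle exercised above is driven by the OneCycle scheduler in the DeepSpeed config. The parameter names below follow the documented OneCycle options; the concrete values are illustrative assumptions, not the config the test actually uses:

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "scheduler": {
        "type": "OneCycle",
        "params": {
            "cycle_min_lr": 1e-5,
            "cycle_max_lr": 1e-2,
            "cycle_first_step_size": 100,   # corresponds to step_size above
            "cycle_min_mom": 0.85,          # min_mom
            "cycle_max_mom": 0.99,          # max_mom
            "decay_step_size": 100,
            "decay_mom_rate": 0.001,        # decay_rate
        },
    },
}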
Example No. 3
    def _test_lr_warmup_decay_schedule(args,
                                       model,
                                       hidden_dim,
                                       schedule_params,
                                       num_steps):
        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                         model=model,
                                                         model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=num_steps * 2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)
        step_lrs = []
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.append(lr_scheduler.get_lr())

        # Verify initial lr
        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]

        # Verify lr at warmup completion
        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
        assert step_lrs[warmup_num_steps] == warmup_max_lr

        # Verify decay phase
        previous_lr = warmup_max_lr
        for lr in step_lrs[warmup_num_steps + 1:]:
            assert lr < previous_lr
            previous_lr = lr
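
`WARMUP_MIN_LR`, `WARMUP_MAX_LR` and `WARMUP_NUM_STEPS` are constants holding the scheduler's config keys. A `schedule_params` dict and matching config that would produce this warmup-then-decay curve might look as follows (key names per the WarmupDecayLR scheduler docs; the values are purely illustrative):

schedule_params = {
    "warmup_min_lr": 0.0,      # WARMUP_MIN_LR
    "warmup_max_lr": 0.001,    # WARMUP_MAX_LR
    "warmup_num_steps": 10,    # WARMUP_NUM_STEPS
    "total_num_steps": 40,     # assumed to cover the num_steps * 2 batches above
}

ds_config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "scheduler": {"type": "WarmupDecayLR", "params": schedule_params},
}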
Example No. 4
 def _helper():
     model = SimpleModel(hidden_dim=10)
     model, _, _, _ = deepspeed.initialize(model=model, config=config)
     data_loader = random_dataloader(model=model,
                                     total_samples=5,
                                     hidden_dim=10,
                                     device=model.device)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
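
`config` is supplied by the enclosing test. A minimal dict of the kind accepted by `deepspeed.initialize(config=...)` could be (illustrative values only):

config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "fp16": {"enabled": False},
}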
Example No. 5
    def _test_zero_unbalanced_gradients(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        run_unbalanced_gradients(model, data_loader)
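
`run_unbalanced_gradients` comes from the shared test utilities. A plausible sketch of such a helper, which freezes a different subset of parameters on alternating steps so the gradient partitions become unbalanced, is shown below (hypothetical, not the suite's actual implementation):

def run_unbalanced_gradients(model, data_loader):
    def drop_some_gradients(model, step):
        # Freeze every other parameter, alternating by step parity.
        odd_step = step % 2
        for i, p in enumerate(model.parameters()):
            p.requires_grad = (i % 2) == odd_step

    def enable_grads(model):
        for p in model.parameters():
            p.requires_grad = True

    for i, batch in enumerate(data_loader):
        drop_some_gradients(model, i + 1)
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()
        enable_grads(model)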
Example No. 6
 def _test_lamb_fp16_basic(args, model, hidden_dim):
     model, _, _, _ = deepspeed.initialize(
         args=args, model=model, model_parameters=model.parameters())
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
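
The FP16/LAMB combination being exercised here is selected entirely through the DeepSpeed config. An illustrative config (the `optimizer` and `fp16` sections are standard DeepSpeed config sections; the values are assumptions):

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Lamb", "params": {"lr": 0.00015}},
    "fp16": {"enabled": True},
}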
Example No. 7
    def _test_zero3_repeat_forward_loop(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
Example No. 8
 def _test_adam_fp32_empty_grad(args, model, hidden_dim):
     model, _, _, _ = deepspeed.initialize(
         args=args, model=model, model_parameters=model.parameters())
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device,
                                     dtype=torch.float)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
Example No. 9
    def _test_non_pld_model(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=1,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            with pytest.raises(TypeError):
                loss = model(batch[0], batch[1])
Example No. 10
 def _test_adam_amp_basic(args, model, hidden_dim):
     optimizer = torch.optim.Adam(params=model.parameters())
     model, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           optimizer=optimizer)
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
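
Because a client-side `torch.optim.Adam` is handed to `deepspeed.initialize`, mixed precision here comes from the `amp` section of the config rather than DeepSpeed's own FP16 optimizer wrapper. A sketch of such a config (the `amp` keys are forwarded to NVIDIA Apex; the values are illustrative):

ds_config = {
    "train_batch_size": 8,
    "amp": {"enabled": True, "opt_level": "O1"},
}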
Example No. 11
 def _test_dist_init_true(args, model, hidden_dim):
     model, _, _,_ = deepspeed.initialize(args=args,
                                          model=model,
                                          model_parameters=model.parameters(),
                                          dist_init_required=True)
     data_loader = random_dataloader(model=model,
                                     total_samples=5,
                                     hidden_dim=hidden_dim,
                                     device=model.device)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
Example No. 12
 def _test_scheduler_optimizer_parity(args, model, hidden_dim):
     model, _, _, lr_scheduler = deepspeed.initialize(
         args=args, model=model, model_parameters=model.parameters())
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device,
                                     dtype=torch.float)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
         assert lr_scheduler.get_lr() == model.get_lr()
Example No. 13
 def _test_get_lr_before_train(args, model, hidden_dim):
     model, _, _, lr_scheduler = deepspeed.initialize(
         args=args, model=model, model_parameters=model.parameters())
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device,
                                     dtype=torch.float)
     for n, batch in enumerate(data_loader):
         # get lr before training starts
         lr_scheduler.get_lr()
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
Example No. 14
    def _test_adam_fp16_zero_onecycle_compatibility(args, zero_stage,
                                                    hidden_dim):
        model = SimpleModel(hidden_dim)

        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
Example No. 15
 def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
     optimizer = torch.optim.AdamW(params=model.parameters())
     model, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           optimizer=optimizer,
                                           dist_init_required=False)
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
Example No. 16
 def _test_curriculum_scheduler_fixed_linear(args, model, hidden_dim):
     model, _, _, _ = deepspeed.initialize(
         args=args, model=model, model_parameters=model.parameters())
     data_loader = random_dataloader(model=model,
                                     total_samples=20,
                                     hidden_dim=hidden_dim,
                                     device=model.device)
     for n, batch in enumerate(data_loader):
         loss, seqlen = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
         if n + 1 in ground_truths:
             true_seqlen = ground_truths[n + 1]
             print('at step {} the seqlen is {}'.format(n + 1, seqlen))
             assert seqlen == true_seqlen, f"Incorrect curriculum schedule"
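
`ground_truths` maps selected step numbers to the sequence length the fixed_linear curriculum should have reached at that step. A sketch of a matching curriculum config and ground-truth table, with assumed values (the `curriculum_learning` keys follow DeepSpeed's curriculum-learning config; the numbers are illustrative and must be kept consistent with each other):

ds_config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "curriculum_learning": {
        "enabled": True,
        "curriculum_type": "seqlen",
        "min_difficulty": 2,
        "max_difficulty": 10,
        "schedule_type": "fixed_linear",
        "schedule_config": {"total_curriculum_step": 8, "difficulty_step": 2},
    },
}

# Hypothetical expected seqlen at selected steps for the schedule above.
ground_truths = {1: 2, 3: 4, 5: 6, 7: 8, 9: 10}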
Example No. 17
    def _go(args):
        model = SimpleModel(hidden_dim)

        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for _, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
Example No. 18
    def _test_flops_profiler_in_ds_training(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.half)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            if n == 3: break
        assert model.flops_profiler.flops == 100
        assert model.flops_profiler.params == 110
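
The exact flops/params numbers asserted above depend on `SimpleModel` and the batch shape used by `random_dataloader`. The profiler itself is enabled through the `flops_profiler` section of the config; an illustrative setup (keys per the flops-profiler docs, values assumed):

ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "fp16": {"enabled": True},
    "flops_profiler": {
        "enabled": True,
        "profile_step": 1,
        "module_depth": -1,
        "top_modules": 3,
        "detailed": True,
    },
}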
Example No. 19
 def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim):
     model, _, _, _ = deepspeed.initialize(
         args=args, model=model, model_parameters=model.parameters())
     data_loader = random_dataloader(model=model,
                                     total_samples=100,
                                     hidden_dim=hidden_dim,
                                     device=model.device)
     save_folder = os.path.join(tmpdir, 'saved_checkpoint')
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         if dist.get_rank() == 0 and n >= 10:
             loss = loss * 1000000.0
         model.backward(loss)
         dist.barrier()
         model.step()
         dist.barrier()
         model.save_checkpoint(save_folder, tag=None)
Example No. 20
    def _test_pld_model(args, model, hidden_dim, theta, gamma):
        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

            expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
            actual_theta = model.get_pld_theta()
            assert expected_theta == actual_theta
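
Progressive Layer Drop (PLD) is switched on through the `progressive_layer_drop` config section, and the assertion above reproduces the keep-probability schedule theta_i = (1 - theta) * exp(-gamma * i) + theta. An illustrative config (values assumed):

ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "fp16": {"enabled": True},
    "progressive_layer_drop": {"enabled": True, "theta": 0.5, "gamma": 0.001},
}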
Example No. 21
    def _test_zero_empty_partition(args):
        hidden_dim = 1
        model = SimpleModel(hidden_dim)
        # Ensure model has 2 parameters, to cause empty partition with DP=3
        assert len(list(model.parameters())) == 2
        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        # Now make sure things work.
        data_loader = random_dataloader(model=model,
                                        total_samples=1,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
Example No. 22
 def _go(hidden_dim):
     with deepspeed.zero.Init(enabled=zero_stage == 3,
                              config_dict_or_path=ds_config):
         model = SimpleModel(hidden_dim, nlayers=78)
     print('total number of parameters:',
           sum([p.numel() for p in model.parameters()]))
     see_memory_usage('pre-init', force=True)
     model, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
     see_memory_usage('post-init', force=True)
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device,
                                     dtype=torch.half)
     print(f"optimizer={model.optimizer}")
     for batch in data_loader:
         model(batch[0], batch[1])
     see_memory_usage('post-fwds', force=True)
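
`deepspeed.zero.Init` allocates the 78-layer model with its parameters already partitioned when ZeRO stage 3 is active, which is why memory is sampled before and after `deepspeed.initialize`. A minimal stage-3 `ds_config` for this kind of test could be (illustrative values):

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 3},
}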
Example No. 23
def checkpoint_correctness_verification(args,
                                        model,
                                        hidden_dim,
                                        tmpdir,
                                        load_optimizer_states=False,
                                        load_lr_scheduler_states=False,
                                        fp16=True):
    dtype = torch.half if fp16 else torch.float32
    ds_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    data_loader = random_dataloader(model=ds_model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=ds_model.device,
                                    dtype=dtype)
    for n, batch in enumerate(data_loader):
        loss = ds_model(batch[0], batch[1])
        ds_model.backward(loss)
        ds_model.step()

    trained_model = ds_model

    save_folder = os.path.join(tmpdir, 'saved_checkpoint')
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)

    loaded_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())

    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states,
                                 load_lr_scheduler_states=load_lr_scheduler_states)

    compare_model_states(trained_model, loaded_model)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

    if load_lr_scheduler_states:
        compare_lr_scheduler_states(trained_model, loaded_model)
Example No. 24
    def _test_zero_static_scale(args):
        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        model, optim, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        # Ensure the static scaler is configured.
        assert optim.dynamic_loss_scale == False
        assert optim.loss_scaler.loss_scale == 138.

        # Now make sure things work.
        data_loader = random_dataloader(model=model,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
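
The asserted loss scale of 138 implies a static loss scaler; in the DeepSpeed config that corresponds to a non-zero `loss_scale` in the `fp16` section (a value of 0 selects dynamic scaling). An illustrative config:

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 0.001}},
    "fp16": {"enabled": True, "loss_scale": 138},
    "zero_optimization": {"stage": 1},
}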
Example No. 25
 def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim):
     model, optimizer, _, _ = deepspeed.initialize(
         args=args,
         model=model,
         model_parameters=optimizer_grouped_parameters)
     data_loader = random_dataloader(model=model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
                                     device=model.device)
     for n, batch in enumerate(data_loader):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
     # Test whether the momentum mask works
     for v in optimizer.state.values():
         if v['exp_avg'].size() == mask1.size():
             assert torch.allclose(
                 v['exp_avg'],
                 v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)),
                 atol=1e-07), f"Momentum mask is not working properly"
Example No. 26
    def _helper():
        parser = argparse.ArgumentParser()
        args = parser.parse_args(args='')
        args.deepscale_config = config_path
        args.local_rank = 0

        hidden_dim = 10

        model = SimpleModel(hidden_dim=hidden_dim)

        model, _, _, _ = deepspeed.initialize(args=args, model=model)
        data_loader = random_dataloader(model=model,
                                        total_samples=5,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            with pytest.raises(AssertionError):
                model.backward(loss)
            with pytest.raises(AssertionError):
                model.step()
Example No. 27
    def _test_stage2_find_unused_parameters(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        def _loop():
            for n, batch in enumerate(data_loader):
                loss = model(batch[0], batch[1])
                model.backward(loss)
                model.step()

        if not find_unused_parameters:
            with pytest.raises(AssertionError) as e:
                _loop()
            assert e.value.args and 'find_unused_parameters' in e.value.args[0]
        else:
            _loop()
Example No. 28
def checkpoint_correctness_verification(args,
                                        model,
                                        hidden_dim,
                                        load_optimizer_states=True):

    ds_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())
    data_loader = random_dataloader(model=ds_model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=ds_model.device)
    for n, batch in enumerate(data_loader):
        loss = ds_model(batch[0], batch[1])
        ds_model.backward(loss)
        ds_model.step()

    trained_model = ds_model

    save_folder = 'saved_checkpoint'
    save_tag = '1'

    trained_model.save_checkpoint(save_folder, save_tag)

    loaded_model, _, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())

    loaded_model.load_checkpoint(save_folder,
                                 save_tag,
                                 load_optimizer_states=load_optimizer_states)

    if load_optimizer_states:
        compare_optimizer_states(trained_model, loaded_model, hidden_dim)
    else:
        compare_model_states(trained_model, loaded_model)
Example No. 29
    def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim):
        model_1, optimizer_1, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_1)
        data_loader = random_dataloader(model=model_1,
                                        total_samples=10,
                                        hidden_dim=hidden_dim,
                                        device=model_1.device)
        for n, batch in enumerate(data_loader):
            loss = model_1(batch[0], batch[1])
            model_1.backward(loss)
            model_1.step()
        # Test whether momentum mask still exist after saving checkpoint
        assert optimizer_1.optimizer.lamb_freeze_key is True
        mask1 = mask1.to(
            device=optimizer_1.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'],
                              mask1,
                              atol=1e-07), f"Incorrect momentum mask"
        scaling_coeff_1 = []
        for v in optimizer_1.state.values():
            assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
            scaling_coeff_1.append(v['scaling_coeff'])
        save_folder = os.path.join(tmpdir, 'saved_checkpoint')
        model_1.save_checkpoint(save_folder, tag=None)
        assert torch.allclose(
            optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07
        ), f"Momentum mask should not change after saving checkpoint"

        model_2, optimizer_2, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_2)
        # Test whether momentum mask stays the same after loading checkpoint
        mask2 = mask2.to(
            device=optimizer_2.param_groups[0]['exp_avg_mask'].device)
        assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'],
                              mask2,
                              atol=1e-07), f"Incorrect momentum mask"
        model_2.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert torch.allclose(
            optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07
        ), f"Momentum mask should not change after loading checkpoint"
        # Test whether worker & server errors are reset
        assert len(optimizer_2.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_2.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether scaling_coeffs is loaded correctly
        scaling_coeff_2 = []
        for v in optimizer_2.state.values():
            assert 'scaling_coeff' in v, f"Incorrect scaling_coeff"
            scaling_coeff_2.append(v['scaling_coeff'])
        assert list(sorted(scaling_coeff_2)) == list(
            sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs"
        assert optimizer_2.optimizer.lamb_freeze_key is True

        model_3, optimizer_3, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=optimizer_grouped_parameters_3)
        optimizer_3.optimizer.freeze_step = 20
        data_loader = random_dataloader(model=model_3,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model_3.device)
        for n, batch in enumerate(data_loader):
            loss = model_3(batch[0], batch[1])
            model_3.backward(loss)
            model_3.step()
        assert optimizer_3.optimizer.lamb_freeze_key is True
        # Test whether momentum mask stays the same after loading checkpoint
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], f"Incorrect momentum mask"
        model_3.load_checkpoint(save_folder,
                                tag=None,
                                load_optimizer_states=True,
                                load_lr_scheduler_states=True)
        assert 'exp_avg_mask' not in optimizer_3.param_groups[
            0], f"Momentum mask should not change after loading checkpoint"
        # Test whether worker & server errors are reset
        assert len(optimizer_3.optimizer.worker_errors
                   ) == 0, f"Incorrect worker error"
        assert len(optimizer_3.optimizer.server_errors
                   ) == 0, f"Incorrect server error"
        # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are reset
        for v in optimizer_3.state.values():
            assert v[
                'lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze"
            assert v['last_factor'] == 1.0, f"Incorrect last_factor"
            assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff"
        assert optimizer_3.optimizer.lamb_freeze_key is False
Example No. 30
    def _test_zero_to_fp32():
        class MyModel(torch.nn.Module):
            def __init__(self, hidden_dim, n_layers):
                super().__init__()
                self.ll = torch.nn.ModuleList(
                    torch.nn.Linear(hidden_dim, hidden_dim)
                    for i in range(n_layers))
                self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, y):
                hidden = x
                for l in self.ll:
                    hidden = l(hidden)
                return self.cross_entropy_loss(hidden, y)

        args = args_from_dict(tmpdir, config_dict)
        hidden_dim = 3

        world_size = dist.get_world_size()
        n_layers = world_size * 2
        model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

        optim_groups = [
            {
                "params": [l.weight for l in model.ll],
                "weight_decay": 0.01,
            },
            {
                "params": [l.bias for l in model.ll],
                "weight_decay": 0.0
            },
        ]
        optim = torch.optim.SGD(optim_groups, lr=0.1)

        model, _, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=model.parameters(),
            optimizer=optim,
        )
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        model.save_checkpoint(tmpdir)

        # make sure all sides saved it
        dist.barrier()

        if zero_stage == 3:
            with deepspeed.zero.GatheredParameters(list(
                    model.module.parameters(recurse=True)),
                                                   modifier_rank=None):
                pass  # this forces gathering the model

        #dump_state_dict(model)

        orig_state_dict = {}
        for name, param in model.module.named_parameters():
            orig_state_dict[name] = param.detach().cpu()

        if dist.get_rank() == 0:
            fp32_model = load_state_dict_from_zero_checkpoint(
                model.module, tmpdir)
            #dump_state_dict(fp32_model)

            fp32_state_dict = fp32_model.state_dict()
            for name in orig_state_dict.keys():
                # float() workaround for torch<1.6
                assert torch.allclose(orig_state_dict[name].float(),
                                      fp32_state_dict[name].float())
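
`load_state_dict_from_zero_checkpoint` consolidates the sharded ZeRO checkpoint into a full FP32 state dict and loads it into a plain `nn.Module`. It ships with DeepSpeed's zero_to_fp32 utilities; a hedged usage sketch outside the test (import path and arguments as in recent DeepSpeed releases; the model and path below are placeholders):

import torch
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

checkpoint_dir = "./saved_checkpoint"   # directory previously passed to save_checkpoint
model = torch.nn.Linear(3, 3)           # stand-in for the unwrapped model
model = load_state_dict_from_zero_checkpoint(model, checkpoint_dir)
fp32_weights = {name: p.detach().cpu() for name, p in model.named_parameters()}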