Example #1
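This example (it appears to come from the Hugging Face transformers DeepSpeed test suite) trains a tiny regression model under ZeRO with fp16, then uses load_state_dict_from_zero_checkpoint to reload the consolidated fp32 weights from the last checkpoint into the live model, asserting that the parameters and the trainer state are unchanged.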
    def test_load_state_dict_from_zero_checkpoint(self, stage):
        # test that we can load fp32 weights directly from the zero checkpoint into the current model

        output_dir = self.get_auto_remove_tmp_dir()  # pass a fixed dir and after=False to keep it around for debugging

        ds_config_dict = self.get_config_dict(stage)

        kwargs = dict(
            output_dir=output_dir,
            train_len=4,
            per_device_train_batch_size=4,
            num_train_epochs=1,
            save_strategy="steps",
            save_steps=1,
            learning_rate=0.1,
            fp16=True,
            deepspeed=ds_config_dict,
        )

        with mockenv_context(**self.dist_env_1_gpu):
            trainer = get_regression_trainer(**kwargs)
            trainer.train()
            (a, b) = trainer.model.a.item(), trainer.model.b.item()
            state = dataclasses.asdict(trainer.state)

            checkpoint_dir = get_last_checkpoint(output_dir)
            model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)

            (a1, b1) = model.a.item(), model.b.item()
            state1 = dataclasses.asdict(trainer.state)
            self.assertEqual(a, a1)
            self.assertEqual(b, b1)
            self.check_trainer_state_are_the_same(state, state1)
Example #2
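Here a small stack of linear layers is trained through deepspeed.initialize with explicit optimizer parameter groups, a ZeRO checkpoint is saved, and on rank 0 the fp32 weights reconstructed from that checkpoint are compared against the in-memory parameters with torch.allclose.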
    def _test_zero_to_fp32():
        # `config_dict`, `tmpdir` and `zero_stage` are captured from the enclosing test scope
        class MyModel(torch.nn.Module):
            def __init__(self, hidden_dim, n_layers):
                super().__init__()
                self.ll = torch.nn.ModuleList(
                    torch.nn.Linear(hidden_dim, hidden_dim)
                    for i in range(n_layers))
                self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, y):
                hidden = x
                for l in self.ll:
                    hidden = l(hidden)
                return self.cross_entropy_loss(hidden, y)

        hidden_dim = 3

        world_size = dist.get_world_size()
        n_layers = world_size * 2
        model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

        optim_groups = [
            {
                "params": [l.weight for l in model.ll],
                "weight_decay": 0.01,
            },
            {
                "params": [l.bias for l in model.ll],
                "weight_decay": 0.0
            },
        ]
        optim = torch.optim.SGD(optim_groups, lr=0.1)

        model, _, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            optimizer=optim,
            config=config_dict)
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        model.save_checkpoint(tmpdir)

        # make sure all ranks saved it
        dist.barrier()

        if zero_stage == 3:
            with deepspeed.zero.GatheredParameters(list(
                    model.module.parameters(recurse=True)),
                                                   modifier_rank=None):
                pass  # this forces gathering the model

        #dump_state_dict(model)

        orig_state_dict = {}
        for name, param in model.module.named_parameters():
            orig_state_dict[name] = param.detach().cpu()

        if dist.get_rank() == 0:
            fp32_model = load_state_dict_from_zero_checkpoint(
                model.module, tmpdir)
            #dump_state_dict(fp32_model)

            fp32_state_dict = fp32_model.state_dict()
            for name in orig_state_dict.keys():
                # float() workaround for torch<1.6
                assert torch.allclose(orig_state_dict[name].float(),
                                      fp32_state_dict[name].float())
Example #3
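A variant of the previous example: an extra, unused classifier layer makes the total parameter count uneven, which exercises the partition-padding case referenced in https://github.com/microsoft/DeepSpeed/pull/1372.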
    def _test_zero_to_fp32():
        # `config_dict`, `tmpdir` and `zero_stage` are captured from the enclosing test scope
        class MyModel(torch.nn.Module):
            def __init__(self, hidden_dim, n_layers):
                super().__init__()
                # to reproduce https://github.com/microsoft/DeepSpeed/pull/1372 it is important that
                # the number of total elements is uneven:
                # (1) 4 layers of 3*(3+1)=12 elements each, 48 in total
                self.ll = torch.nn.ModuleList(
                    torch.nn.Linear(hidden_dim, hidden_dim)
                    for i in range(n_layers))
                # (2) the following adds 4+1=5 elements
                self.classifier = torch.nn.Linear(4, 1)
                # total 48+5=53 (uneven as desired) elements
                self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, y):
                hidden = x
                for l in self.ll:
                    hidden = l(hidden)
                return self.cross_entropy_loss(hidden, y)

        hidden_dim = 3  # do not change

        world_size = dist.get_world_size()
        # we want at least 2x as many layers as there are GPUs to trigger the round_robin_fp16_groups reshuffle in ZeRO-2
        n_layers = world_size * 2
        model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

        model, _, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        model.save_checkpoint(tmpdir)

        # make sure all ranks saved it
        dist.barrier()

        if zero_stage == 3:
            with deepspeed.zero.GatheredParameters(list(
                    model.module.parameters(recurse=True)),
                                                   modifier_rank=None):
                pass  # this forces gathering the model

        #dump_state_dict(model)

        orig_state_dict = {}
        for name, param in model.module.named_parameters():
            orig_state_dict[name] = param.detach().cpu()

        if dist.get_rank() == 0:
            fp32_model = load_state_dict_from_zero_checkpoint(
                model.module, tmpdir)
            #dump_state_dict(fp32_model)

            fp32_state_dict = fp32_model.state_dict()
            for name in orig_state_dict.keys():
                # float() workaround for torch<1.6
                assert torch.allclose(orig_state_dict[name].float(),
                                      fp32_state_dict[name].float())
Example #4
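Similar to the examples above, but DeepSpeed is initialized from an argparse-style args object built with args_from_dict, and the fp32 reload is performed on every rank rather than only on rank 0.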
    def _test_zero_to_fp32():
        # `config_dict`, `tmpdir` and `zero_stage` are captured from the enclosing test scope
        class MyModel(torch.nn.Module):
            def __init__(self, hidden_dim, n_layers):
                super().__init__()
                self.ll = torch.nn.ModuleList(
                    torch.nn.Linear(hidden_dim, hidden_dim)
                    for i in range(n_layers))
                self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, y):
                hidden = x
                for l in self.ll:
                    hidden = l(hidden)
                return self.cross_entropy_loss(hidden, y)

        args = args_from_dict(tmpdir, config_dict)
        hidden_dim = 2

        world_size = dist.get_world_size()
        # we want at least 2x as many layers as there are GPUs to trigger the round_robin_fp16_groups reshuffle in ZeRO-2
        n_layers = world_size * 2
        model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

        model, _, _, _ = deepspeed.initialize(
            args=args, model=model, model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=16,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        model.save_checkpoint(tmpdir)

        # make sure all ranks saved it
        dist.barrier()

        def dump_state_dict(model):
            if dist.get_rank() != 0:
                return
            for name, param in model.named_parameters():
                print(f"{name} {param}")

        if zero_stage == 3:
            with deepspeed.zero.GatheredParameters(list(
                    model.module.parameters(recurse=True)),
                                                   modifier_rank=None):
                pass  # this forces gathering the model

        #dump_state_dict(model)

        orig_state_dict = {}
        for name, param in model.module.named_parameters():
            orig_state_dict[name] = param.detach().cpu()
        print(orig_state_dict)

        fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
        #dump_state_dict(fp32_model)

        fp32_state_dict = fp32_model.state_dict()
        for name in orig_state_dict.keys():
            # float() workaround for torch<1.6
            assert torch.allclose(orig_state_dict[name].float(),
                                  fp32_state_dict[name].float())