def test_load_state_dict_from_zero_checkpoint(self, stage):
    """Check that fp32 weights loaded straight from a ZeRO checkpoint match the trained model."""
    output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False, before=False)
    training_args = dict(
        output_dir=output_dir,
        train_len=4,
        per_device_train_batch_size=4,
        num_train_epochs=1,
        save_strategy="steps",
        save_steps=1,
        learning_rate=0.1,
        fp16=True,
        deepspeed=self.get_config_dict(stage),
    )

    with mockenv_context(**self.dist_env_1_gpu):
        trainer = get_regression_trainer(**training_args)
        trainer.train()

        # Snapshot the regression weights and trainer state right after training.
        a_trained = trainer.model.a.item()
        b_trained = trainer.model.b.item()
        state_trained = dataclasses.asdict(trainer.state)

        # Restore fp32 weights from the last ZeRO checkpoint into the live model.
        checkpoint_dir = get_last_checkpoint(output_dir)
        reloaded = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)

        a_loaded = reloaded.a.item()
        b_loaded = reloaded.b.item()
        state_loaded = dataclasses.asdict(trainer.state)

        # The reload must be lossless: weights and trainer state are unchanged.
        self.assertEqual(a_trained, a_loaded)
        self.assertEqual(b_trained, b_loaded)
        self.check_trainer_state_are_the_same(state_trained, state_loaded)
def _test_zero_to_fp32():
    """Round-trip a ZeRO checkpoint: train, save, reload fp32 weights and compare."""

    class MyModel(torch.nn.Module):
        """Simple stack of square Linear layers ending in a cross-entropy loss."""

        def __init__(self, hidden_dim, n_layers):
            super().__init__()
            self.ll = torch.nn.ModuleList(
                torch.nn.Linear(hidden_dim, hidden_dim) for i in range(n_layers))
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, y):
            hidden = x
            for layer in self.ll:
                hidden = layer(hidden)
            return self.cross_entropy_loss(hidden, y)

    hidden_dim = 3
    world_size = dist.get_world_size()
    n_layers = world_size * 2
    model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

    # Separate weight-decay groups: decay on weights, none on biases.
    optim_groups = [
        {
            "params": [layer.weight for layer in model.ll],
            "weight_decay": 0.01,
        },
        {
            "params": [layer.bias for layer in model.ll],
            "weight_decay": 0.0
        },
    ]
    optim = torch.optim.SGD(optim_groups, lr=0.1)

    model, _, _, _ = deepspeed.initialize(model=model,
                                          model_parameters=model.parameters(),
                                          optimizer=optim,
                                          config=config_dict)
    data_loader = random_dataloader(model=model,
                                    total_samples=16,
                                    hidden_dim=hidden_dim,
                                    device=model.device)

    for step, batch in enumerate(data_loader):
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    model.save_checkpoint(tmpdir)

    # make sure all sides saved it
    dist.barrier()

    if zero_stage == 3:
        # this forces gathering the model
        with deepspeed.zero.GatheredParameters(list(
                model.module.parameters(recurse=True)),
                                               modifier_rank=None):
            pass

    #dump_state_dict(model)

    orig_state_dict = {
        name: param.detach().cpu()
        for name, param in model.module.named_parameters()
    }

    if dist.get_rank() == 0:
        fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
        #dump_state_dict(fp32_model)

        fp32_state_dict = fp32_model.state_dict()
        for name in orig_state_dict.keys():
            # float() workaround for torch<1.6
            assert torch.allclose(orig_state_dict[name].float(),
                                  fp32_state_dict[name].float())
def _test_zero_to_fp32():
    """Train with an unevenly-sized model, save a ZeRO checkpoint, and verify the
    reloaded fp32 weights match the in-memory parameters."""

    class MyModel(torch.nn.Module):
        def __init__(self, hidden_dim, n_layers):
            super().__init__()
            # to reproduce https://github.com/microsoft/DeepSpeed/pull/1372 it is important that
            # the number of total elements is uneven:
            # (1) 4 layers of 3*(3+1)=12 elements each, 48 in total
            self.ll = torch.nn.ModuleList(
                torch.nn.Linear(hidden_dim, hidden_dim) for i in range(n_layers))
            # (2) the following adds 4+1=5 elements
            self.classifier = torch.nn.Linear(4, 1)
            # total 48+5=53 (uneven as desired) elements
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, y):
            hidden = x
            for layer in self.ll:
                hidden = layer(hidden)
            return self.cross_entropy_loss(hidden, y)

    hidden_dim = 3  # do not change

    world_size = dist.get_world_size()
    # we want at least 2x layers as there are gpus to trigger round_robin_fp16_groups reshuffle in zero2
    n_layers = world_size * 2
    model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

    model, _, _, _ = deepspeed.initialize(config=config_dict,
                                          model=model,
                                          model_parameters=model.parameters())
    data_loader = random_dataloader(model=model,
                                    total_samples=16,
                                    hidden_dim=hidden_dim,
                                    device=model.device)

    for step, batch in enumerate(data_loader):
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    model.save_checkpoint(tmpdir)

    # make sure all sides saved it
    dist.barrier()

    if zero_stage == 3:
        # this forces gathering the model
        with deepspeed.zero.GatheredParameters(list(
                model.module.parameters(recurse=True)),
                                               modifier_rank=None):
            pass

    #dump_state_dict(model)

    orig_state_dict = {
        name: param.detach().cpu()
        for name, param in model.module.named_parameters()
    }

    if dist.get_rank() == 0:
        fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
        #dump_state_dict(fp32_model)

        fp32_state_dict = fp32_model.state_dict()
        for name in orig_state_dict.keys():
            # float() workaround for torch<1.6
            assert torch.allclose(orig_state_dict[name].float(),
                                  fp32_state_dict[name].float())
def _test_zero_to_fp32():
    """Train a small model, save a ZeRO checkpoint, then reload the consolidated
    fp32 weights and assert they match the live parameters.

    Fixes over the previous version:
    - removed the leftover debug ``print(orig_state_dict)`` which spammed
      every rank's output,
    - the fp32 reload + comparison now runs only on rank 0 (consistent with
      the other ``_test_zero_to_fp32`` variants in this file): one consolidated
      load is sufficient, and it avoids redundant disk work on every rank.
    """

    class MyModel(torch.nn.Module):
        """Stack of square Linear layers ending in a cross-entropy loss."""

        def __init__(self, hidden_dim, n_layers):
            super().__init__()
            self.ll = torch.nn.ModuleList(
                torch.nn.Linear(hidden_dim, hidden_dim) for i in range(n_layers))
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

        def forward(self, x, y):
            hidden = x
            for l in self.ll:
                hidden = l(hidden)
            return self.cross_entropy_loss(hidden, y)

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 2

    world_size = dist.get_world_size()
    # we want at least 2x layers as there are gpus to trigger round_robin_fp16_groups reshuffle in zero2
    n_layers = world_size * 2
    model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers)

    model, _, _, _ = deepspeed.initialize(args=args,
                                          model=model,
                                          model_parameters=model.parameters())
    data_loader = random_dataloader(model=model,
                                    total_samples=16,
                                    hidden_dim=hidden_dim,
                                    device=model.device)

    for i, batch in enumerate(data_loader):
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    model.save_checkpoint(tmpdir)

    # make sure all sides saved it
    dist.barrier()

    def dump_state_dict(model):
        # debug helper: print all parameters on rank 0 only
        if dist.get_rank() != 0:
            return
        for name, param in model.named_parameters():
            print(f"{name} {param}")

    if zero_stage == 3:
        # this forces gathering the model
        with deepspeed.zero.GatheredParameters(list(
                model.module.parameters(recurse=True)),
                                               modifier_rank=None):
            pass

    #dump_state_dict(model)

    orig_state_dict = {}
    for name, param in model.module.named_parameters():
        orig_state_dict[name] = param.detach().cpu()

    if dist.get_rank() == 0:
        fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
        #dump_state_dict(fp32_model)

        fp32_state_dict = fp32_model.state_dict()
        for name in orig_state_dict.keys():
            # float() workaround for torch<1.6
            assert torch.allclose(orig_state_dict[name].float(),
                                  fp32_state_dict[name].float())