Example #1
0
    def _test_moe(args, hidden_dim, ep_size):
        """Run a short MoE training loop under expert parallelism.

        Initializes the parallel process groups with the given expert-parallel
        size, wraps SimpleMoEModel with ``deepspeed.initialize``, verifies the
        resulting data/expert parallel world sizes, then trains on 50
        synthetic samples.

        Args:
            args: Config namespace forwarded to ``deepspeed.initialize``.
            hidden_dim: Hidden dimension of the model and the synthetic data.
            ep_size: Expert-parallel world size (e.g. 2 for E+D, 4 for E only).
        """
        groups.initialize(ep_size=ep_size)
        model = SimpleMoEModel(hidden_dim)
        optimizer = torch.optim.AdamW(params=model.parameters())
        # dist_init_required=False: the process groups were set up above.
        # TODO(review): parameterize dist_init_required to True/False?
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)

        assert dist.get_world_size() == groups.get_data_parallel_world_size(
        ), "incorrect data parallel world size"
        assert ep_size == groups.get_expert_parallel_world_size(
        ), "incorrect expert parallel world size"

        data_loader = sequence_dataloader(model=model,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=model.device)

        # Forward/backward/step over every batch; the index was unused.
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
Example #2
0
 def _test_lamb_legacy_optimizer_step(args, hidden_dim, fused_lamb_legacy):
     """Run a short training loop exercising the legacy LAMB optimizer step.

     Sets up model/expert parallel groups, builds SimpleMoEModel via
     ``deepspeed.initialize``, patches ``optimizer.unscale_and_clip_grads``
     with a mock, toggles the ``fused_lamb_legacy`` flag, and trains on 50
     synthetic samples.

     Args:
         args: Config namespace forwarded to ``deepspeed.initialize``.
         hidden_dim: Hidden dimension of the model and the synthetic data.
         fused_lamb_legacy: Whether the optimizer should take the fused
             legacy LAMB code path.

     NOTE(review): ``monkeypatch`` and ``mock_unscale_and_clip_grads`` are
     free variables resolved from the enclosing scope — confirm with caller.
     """
     # Set up model (size 1) and expert (size 2) parallel groups first.
     groups.initialize_model_parallel(1)
     groups.initialize_expert_parallel(2)
     model = SimpleMoEModel(hidden_dim)
     engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                    model=model,
                                                    model_parameters=model.parameters(),
                                                    dist_init_required=False)
     # Stub out grad unscale/clip so the optimizer step can run under test.
     monkeypatch.setattr(optimizer,
                         'unscale_and_clip_grads',
                         mock_unscale_and_clip_grads)
     optimizer.fused_lamb_legacy = fused_lamb_legacy
     data_loader = sequence_dataloader(model=engine,
                                       total_samples=50,
                                       hidden_dim=hidden_dim,
                                       device=engine.device)
     # Forward/backward/step over every batch; the index was unused.
     for batch in data_loader:
         loss = engine(batch[0], batch[1])
         engine.backward(loss)
         engine.step()