Example #1
    def _test_lamb_legacy_optimizer_step(args, hidden_dim, fused_lamb_legacy):
        # initialize MoE process groups
        groups.initialize_model_parallel(1)
        groups.initialize_expert_parallel(2)
        model = SimpleMoEModel(hidden_dim)
        engine, optimizer, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            model_parameters=model.parameters(),
            dist_init_required=False)
        # monkeypatch and mock_unscale_and_clip_grads come from the enclosing pytest test
        monkeypatch.setattr(optimizer, 'unscale_and_clip_grads',
                            mock_unscale_and_clip_grads)
        optimizer.fused_lamb_legacy = fused_lamb_legacy
        data_loader = sequence_dataloader(model=engine,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=engine.device)
        for batch in data_loader:
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()
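A possible driver for the helper above, shown as a hedged sketch: it builds a DeepSpeed JSON config that selects the Lamb optimizer with fp16 enabled, writes it to disk, and hands it over via args.deepspeed_config. The write_config helper and the concrete hyperparameters are illustrative only, and the helper itself is expected to run under a two-process distributed launcher (e.g. DeepSpeed's distributed test decorator), which this sketch does not set up.

import argparse
import json
import tempfile

def write_config(config_dict):
    # persist the config dict as JSON so deepspeed.initialize can read it via args.deepspeed_config
    f = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
    json.dump(config_dict, f)
    f.flush()
    return f.name

config_dict = {
    "train_batch_size": 8,
    "steps_per_print": 1,
    "optimizer": {"type": "Lamb", "params": {"lr": 0.00015}},
    "fp16": {"enabled": True},
}
args = argparse.Namespace(local_rank=0, deepspeed_config=write_config(config_dict))
# hidden_dim and fused_lamb_legacy values here are arbitrary illustrations
_test_lamb_legacy_optimizer_step(args, hidden_dim=16, fused_lamb_legacy=True)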
Example #2
    def _test_clip_grad_norm_() -> None:
        param1 = torch.nn.Parameter(torch.Tensor([0]))
        param1.grad = torch.Tensor([1])
        param2 = torch.nn.Parameter(torch.Tensor([0]))
        param2.grad = torch.Tensor([dist.get_rank() + 1])
        # setting allreduce=False marks param2 as an MoE (expert) parameter
        param2.allreduce = False

        parameters = [param1, param2]

        groups.initialize_model_parallel(1)
        groups.initialize_expert_parallel(2)

        norm = ds_utils.clip_grad_norm_(parameters, max_norm=0.1)
        norm = torch.Tensor([norm]).to(dist.get_rank())

        world_size = dist.get_world_size()
        gathered_norm = [torch.zeros(1).cuda() for _ in range(world_size)]

        dist.all_gather(gathered_norm, norm)

        assert gathered_norm[0] == gathered_norm[1], "norm at rank 0 does not match the norm at rank 1"
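The assertion holds because the expert parameter's gradient contribution (allreduce=False) is combined across the expert-parallel group, so every rank arrives at the same global norm. Below is a minimal pure-Python check of the value that aggregation would produce for this two-rank setup; the exact combination rule is an assumption about the intended semantics, since the test itself only asserts that the ranks agree.

import math

# gradients in the two-rank scenario above:
#   param1.grad = 1 on both ranks (replicated parameter, counted once)
#   param2.grad = 1 on rank 0 and 2 on rank 1 (expert shards, each counted)
replicated_sq = 1.0 ** 2
expert_sq = 1.0 ** 2 + 2.0 ** 2
expected_norm = math.sqrt(replicated_sq + expert_sq)  # sqrt(6) ~= 2.449
assert abs(expected_norm - 2.449) < 1e-3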
Example #3
    def _test_CheckOverflow(check_using_norm: bool):
        groups.initialize_model_parallel(1)
        groups.initialize_expert_parallel(2)

        param1 = torch.nn.Parameter(torch.Tensor([0]))
        param1.grad = torch.Tensor([1])
        param2 = torch.nn.Parameter(torch.Tensor([0]))
        if dist.get_rank() == 0:
            param2.grad = torch.Tensor([1])
        else:
            param2.grad = torch.Tensor([float("inf")])
        param2.allreduce = False
        # param2 is now an MoE (expert) parameter
        parameters = [param1, param2]
        if check_using_norm:
            # _flatten_dense_tensors is provided by torch._utils
            grads_group_flat = [_flatten_dense_tensors([p.grad for p in parameters])]
            norm = ds_utils.get_weight_norm(grads_group_flat)
            overflow_checker = ds_utils.CheckOverflow([parameters])
            overflow = overflow_checker.check_using_norm([norm], reduce_overflow=False)
        else:
            overflow_checker = ds_utils.CheckOverflow([parameters])
            overflow = overflow_checker.check()
        assert overflow
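Even though only rank 1 plants an inf, the assert is expected to pass on every rank because CheckOverflow reduces the overflow flag across the process groups that own the expert parameter. The condition detected on each rank is simply whether any local gradient is non-finite; here is a minimal pure-PyTorch sketch of that local check (not DeepSpeed's implementation).

import torch

def has_local_overflow(parameters):
    # True if any gradient on this rank contains inf or NaN
    return any(p.grad is not None and not torch.isfinite(p.grad).all()
               for p in parameters)

# In the test above, rank 1 sets param2.grad to inf, so this local check is
# True there, and the reduced result makes `assert overflow` hold everywhere.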