Example #1
import os

import torch
# `mpu` (the project's model-parallel utilities, e.g. Megatron's mpu) and `groups`
# (deepspeed.utils.groups), as well as set_deepspeed_activation_checkpointing(),
# are provided by the surrounding project.


def initialize_distributed(args):
    """Initialize torch.distributed."""

    # Manually set the device ids.
    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)

    # Call init_process_group with a TCP init method built from the
    # MASTER_ADDR / MASTER_PORT environment variables.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(backend=args.distributed_backend,
                                         world_size=args.world_size,
                                         rank=args.rank,
                                         init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)
    groups.initialize(ep_size=args.expert_parallel_size, mpu=mpu)

    # Optionally enable DeepSpeed activation checkpointing features.
    if args.deepspeed and args.deepspeed_activation_checkpointing:
        set_deepspeed_activation_checkpointing(args)
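
The helper above expects an `args` namespace with rank, world size, and parallelism settings already filled in. A minimal driver sketch follows; the flag names and the use of the RANK / WORLD_SIZE environment variables are assumptions based on common `torch.distributed` launch conventions, not part of the original example, and the script is assumed to live in the same module as `initialize_distributed`.

import argparse


def get_args():
    # Hypothetical argument parser; the flag names below are assumptions.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None)
    parser.add_argument('--distributed-backend', default='nccl')
    parser.add_argument('--model-parallel-size', type=int, default=1)
    parser.add_argument('--expert-parallel-size', type=int, default=1)
    parser.add_argument('--deepspeed', action='store_true')
    parser.add_argument('--deepspeed-activation-checkpointing', action='store_true')
    args = parser.parse_args()
    # RANK / WORLD_SIZE are exported by the launcher (e.g. torchrun or the
    # DeepSpeed launcher) before this script starts.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv('WORLD_SIZE', '1'))
    return args


if __name__ == '__main__':
    initialize_distributed(get_args())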
Example #2
    def _test_moe(args, hidden_dim, ep_size):
        # E+D -- ep_size = 2
        # E only -- ep_size = 4
        #groups.initialize_model_parallel(1)
        #groups.initialize_expert_parallel(2)
        groups.initialize(ep_size=ep_size)
        model = SimpleMoEModel(hidden_dim)
        optimizer = torch.optim.AdamW(params=model.parameters())
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
        #dist_init_required=False -- parameterize to True/False?

        assert dist.get_world_size() == groups.get_data_parallel_world_size(), \
            "incorrect data parallel world size"
        assert ep_size == groups.get_expert_parallel_world_size(
            groups.get_max_expert_size_name()), "incorrect expert parallel world size"

        data_loader = sequence_dataloader(model=model,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=model.device)

        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
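
In these tests `deepspeed.initialize` reads its engine configuration through `args` (typically a JSON config referenced by `args.deepspeed_config`). A minimal config sketch consistent with the later examples is shown below; only `fp16.enabled` is actually referenced by them, and the remaining values are illustrative assumptions.

# Illustrative DeepSpeed config for these MoE tests; values are assumptions.
config_dict = {
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 1,
    "fp16": {
        "enabled": True
    }
}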
Example #3
 def _helper(args):
     # Set up the expert-parallel process groups before constructing the models.
     groups.initialize(ep_size=ep_size)
     models = [SimpleMoEModel(hidden_dim=hidden_dim) for _ in range(2)]
     optimizers = [torch.optim.AdamW(params=model.parameters()) for model in models]
     # Save and reload a checkpoint, verifying that model and optimizer state round-trips.
     checkpoint_correctness_verification(args,
                                         models=models,
                                         hidden_dim=hidden_dim,
                                         tmpdir=tmpdir,
                                         load_optimizer_states=True,
                                         load_lr_scheduler_states=False,
                                         fp16=config_dict["fp16"]["enabled"],
                                         empty_tag=True,
                                         base_optimizers=optimizers,
                                         seq_dataloader=True)
Example #4
 def _helper(args):
     # Set up the expert-parallel process groups before constructing the models.
     groups.initialize(ep_size=ep_size)
     models = [
         SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size)
         for _ in range(2)
     ]
     # Move MoE (expert) parameters into their own optimizer groups so they are
     # handled per expert-parallel group during optimization and checkpointing.
     params = [
         split_params_into_different_moe_groups_for_optimizer(
             create_param_groups(model)) for model in models
     ]
     optimizers = [torch.optim.AdamW(params=param) for param in params]
     checkpoint_correctness_verification(
         args,
         models=models,
         hidden_dim=hidden_dim,
         tmpdir=tmpdir,
         load_optimizer_states=load_optim_states,
         load_lr_scheduler_states=False,
         fp16=config_dict["fp16"]["enabled"],
         empty_tag=True,
         base_optimizers=optimizers,
         seq_dataloader=True)
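
`create_param_groups` is not shown in these examples. A minimal sketch of what such a helper might do, assuming a simple weight-decay split, follows; `split_params_into_different_moe_groups_for_optimizer` (from `deepspeed.moe.utils`) then moves the MoE expert parameters out of these groups into dedicated ones so the optimizer and checkpointing logic can treat them per expert-parallel group.

def create_param_groups(model, weight_decay=0.01):
    # Hypothetical helper; the weight-decay split below is an assumption.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Keep biases and normalization parameters out of weight decay.
        if name.endswith('.bias') or 'norm' in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]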