Example #1
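Examples #1, #3, and #4 are written for a two-process run (the asserts reference ranks 0 and 1). They rely on the imports below; the exact module paths are an assumption based on the DeepSpeed source layout.

# Assumed imports for the reduce_scatter_coalesced tests; the module paths
# follow the DeepSpeed source tree and should be treated as assumptions.
import torch
import deepspeed.comm as dist
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced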
def test_reduce_scatter_coalesced_single_input():
    input = torch.full((6, ),
                       dist.get_rank(),
                       dtype=torch.half,
                       device=torch.cuda.current_device())

    (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())

    assert output.shape == (3, )
    assert torch.allclose(output, torch.full_like(output, 0.5))
Example #2
    def has_overflow(self, params, has_moe_params=None):
        if has_moe_params is None:
            has_moe_params = self.has_moe_params
        overflow = self.has_overflow_serial(params)
        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        overflow_gpu = torch.cuda.ByteTensor([overflow])
        # deepspeed.comm.all_reduce(overflow_gpu,
        #                             op=deepspeed.comm.ReduceOp.MAX,
        #                             group=mpu.get_model_parallel_group())
        if has_moe_params:
            # All reduce this across expert_parallel_group, so that if an expert
            # overflows, we detect it here
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=groups._get_max_expert_parallel_group())
        if self.zero_reduce_scatter:
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=dist.get_world_group())
        elif self.mpu is not None:
            if self.deepspeed is not None:
                using_pipeline = hasattr(self.deepspeed,
                                         'pipeline_enable_backward_allreduce')
                if (using_pipeline
                        and self.deepspeed.pipeline_enable_backward_allreduce
                        is False) or (
                            not using_pipeline and
                            self.deepspeed.enable_backward_allreduce is False):
                    dist.all_reduce(overflow_gpu,
                                    op=dist.ReduceOp.MAX,
                                    group=self.mpu.get_data_parallel_group())
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=self.mpu.get_model_parallel_group())
        elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=dist.get_world_group())

        overflow = overflow_gpu[0].item()
        return bool(overflow)
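The method above synchronizes a locally detected overflow flag with a MAX all-reduce over whichever process group applies (expert-parallel, data-parallel, model-parallel, or the whole world), so one overflowing rank makes every rank see the overflow. A minimal sketch of that pattern, assuming an initialized deepspeed.comm backend and a CUDA device; the helper name is hypothetical:

import torch
import deepspeed.comm as dist

def sync_overflow_flag(local_overflow: bool, group=None) -> bool:
    # Encode the local flag as a byte tensor and take a MAX all-reduce:
    # if any rank in the group overflowed, every rank sees 1 afterwards.
    flag = torch.cuda.ByteTensor([local_overflow])
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag[0].item())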
Example #3
def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz():
    input = torch.zeros((1, ),
                        dtype=torch.half,
                        device=torch.cuda.current_device())

    (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())

    if dist.get_rank() == 0:
        assert output.shape == (1, )
        assert torch.allclose(output, torch.zeros_like(output))
    elif dist.get_rank() == 1:
        assert output.shape == (0, )
Example #4
def test_reduce_scatter_coalesced_two_inputs():
    tensor_kwargs = {
        "device": torch.cuda.current_device(),
        "dtype": torch.half
    }
    inputs = [
        dist.get_rank() * torch.arange(0, 6, **tensor_kwargs),
        dist.get_rank() * torch.arange(6, 9, **tensor_kwargs),
    ]

    output1, output2 = reduce_scatter_coalesced(inputs, dist.get_world_group())

    if dist.get_rank() == 0:
        assert output1.shape == (3, )
        assert torch.allclose(output1, torch.arange(0, 3, **tensor_kwargs) / 2)
        assert output2.shape == (2, )
        assert torch.allclose(output2, torch.arange(6, 8, **tensor_kwargs) / 2)
    elif dist.get_rank() == 1:
        assert output1.shape == (3, )
        assert torch.allclose(output1, torch.arange(3, 6, **tensor_kwargs) / 2)
        assert output2.shape == (1, )
        assert torch.allclose(output2, torch.arange(8, 9, **tensor_kwargs) / 2)
Example #5
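top1gating uses several names defined elsewhere in its module: Tensor, F, einsum, dist, plus helpers such as _capacity, _top_idx, _one_hot_to_float, gumbel_rsample, exp_selection_uniform_map, and the optional tutel_moe import. The imports below are an assumption about that surrounding module (the original may define its own einsum wrapper; torch.einsum is used as a stand-in):

# Assumed imports for top1gating; the gating helpers (_capacity, _top_idx,
# _one_hot_to_float, gumbel_rsample) and exp_selection_uniform_map live in
# the same module and are not reproduced here.
from typing import Optional, Tuple
import torch
from torch import Tensor, einsum  # the source module may use its own einsum wrapper
import torch.nn.functional as F
import deepspeed.comm as dist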
def top1gating(
        logits: Tensor,
        capacity_factor: float,
        min_capacity: int,
        used_token: Tensor = None,
        noisy_gate_policy: Optional[str] = None,
        drop_tokens: bool = True,
        use_rts: bool = True,
        use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
        logits_w_noise = logits + gumbel_rsample(logits.shape,
                                                 device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor),
                         torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    # noisy gating
    indices1_s = torch.argmax(
        logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        dist.all_reduce(new_capacity,
                        op=dist.ReduceOp.MAX,
                        group=dist.get_world_group())
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device),
                high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, \
        "No. of tokens (batch-size) should be greater than min_capacity. " \
        "Either set min_capacity to 0 or increase your batch size."

    top_idx = _top_idx(mask1_rand, capacity)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    if use_tutel:
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
    else:
        locations1 = torch.cumsum(mask1, dim=0) - 1

    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [indices1_s], [locations1_s], [gates1_s], exp_counts

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
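For orientation, the per-expert capacity computed by _capacity in top-1 gating is conventionally ceil(num_tokens / num_experts * capacity_factor), floored at min_capacity. The sketch below states that rule explicitly; it is an assumption about the helper, not a copy of it.

import math

def capacity_sketch(gates, capacity_factor, min_capacity):
    # gates: [num_tokens, num_experts]. Standard top-1 gating capacity rule;
    # an assumption about what _capacity computes, shown only for reference.
    num_tokens, num_experts = gates.shape
    cap = math.ceil(num_tokens / num_experts * capacity_factor)
    return max(cap, min_capacity)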