Example #1
    def test(self):
        x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
        sum_of_ranks = (dist.get_world_size() *
                        (dist.get_world_size() + 1)) // 2
        result = torch.ones(1, 3).cuda() * sum_of_ranks
        dist.all_reduce(x)
        assert torch.all(x == result)
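
A note on what this test exercises: `dist` here is DeepSpeed's communication module, which mirrors the `torch.distributed` API, and `all_reduce` sums the tensor in place across all ranks by default, so each rank ends up holding 1 + 2 + ... + world_size. Below is a minimal sketch of the same check written directly against `torch.distributed`; it assumes the script is launched with `torchrun` on a machine where CUDA and NCCL are available, and is not DeepSpeed code.

# Minimal sketch (assumed launch: torchrun --nproc_per_node=N this_file.py).
# Shows the same default-SUM, in-place behavior of all_reduce as the test above.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
    dist.all_reduce(x)  # default op is SUM; result overwrites x on every rank

    world = dist.get_world_size()
    expected = torch.ones(1, 3).cuda() * (world * (world + 1) // 2)
    assert torch.all(x == expected)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
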
Example #2
def train_cifar(model,
                config,
                num_steps=400,
                average_dp_losses=True,
                fp16=True,
                seed=123):
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        config['local_rank'] = dist.get_rank()

        engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            model_parameters=[p for p in model.parameters()],
            training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

        if average_dp_losses:
            loss_tensor = torch.tensor(losses).cuda()
            dist.all_reduce(loss_tensor)
            loss_tensor /= dist.get_world_size()
            losses = loss_tensor.tolist()

    return losses
Example #3
    def forward(self, input):
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.mp_group is not None:
            dist.all_reduce(output, group=self.mp_group)
        if self.bias is not None:
            output += self.bias
        return output
Example #4
def _reduce(input_):
    """All-reduce the the input tensor across model parallel group."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    dist.all_reduce(input_, group=group)

    return input_
Example #5
def get_full_hp_param(self, optim_state_key=None):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start,
                                       lp_frag_address.numel)
        if optim_state_key is None:
            hp_fragment = self._hp_mapping.hp_fragment
        else:
            hp_fragment = self._hp_mapping.get_optim_state_fragment(
                optim_state_key)

        reduce_fragment.data.copy_(hp_fragment.data)
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)
Example #6
    def forward(ctx,
                input,
                residual,
                residual_norm,
                bias,
                inter_w,
                inter_b,
                attn_nw,
                attn_nb,
                config,
                mp_group,
                output_b,
                output_w,
                q_scales,
                q_groups,
                merge_count,
                mlp_gemm_func,
                fused_gemm_gelu,
                vector_matmul_func,
                bias_residual_func,
                activation_func_type=ActivationFuncType.GELU):

        if config.q_int8:
            (intermediate, residual_add) = inference_cuda_module.mlp_gemm_int8(
                input, residual, bias, inter_w, inter_b, attn_nw, attn_nb,
                config.epsilon, q_scales[2], (q_groups * (2**merge_count)),
                config.pre_layer_norm)
            output = inference_cuda_module.vector_matmul_int8(
                intermediate, output_w, q_scales[3], q_groups, (merge_count))
        else:
            if attn_nw is None:
                output = fused_gemm_gelu(residual_norm, inter_w, inter_b,
                                         output_w, config.epsilon,
                                         config.pre_layer_norm, False)
            else:
                intermediate, residual_add = mlp_gemm_func(
                    input, residual, bias, inter_w, inter_b, attn_nw, attn_nb,
                    config.epsilon, config.pre_layer_norm,
                    config.mlp_after_attn, config.mlp_act_func_type)
                output = vector_matmul_func(intermediate, output_w, False)

        inference_cuda_module.residual_add(
            output, residual if config.pre_layer_norm else residual_add, input,
            output_b, bias if bias is not None else output_b, config.mp_size,
            config.mlp_after_attn, bias is not None, config.pre_layer_norm)
        if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
            dist.all_reduce(output, group=mp_group)
        return output
Example #7
    def _get_norm_with_moe_layers(self, all_groups_norm):
        #all_groups_norm_old = all_groups_norm
        # Need to allreduce (avg) the norms across different ranks because
        # moe params will not be synced during allreduce
        if self.using_pipeline:
            pg = self.deepspeed.mpu.get_data_parallel_group()
        else:
            pg = groups._get_data_parallel_group()
        scaled_norm = all_groups_norm * 1.0 / float(
            dist.get_world_size(group=pg))
        scaled_norm_tensor = torch.tensor(
            scaled_norm,
            device=self.fp32_groups_flat[0].device,
            dtype=torch.float)
        dist.all_reduce(scaled_norm_tensor, group=pg)
        all_groups_norm = scaled_norm_tensor.item()
        #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
        return all_groups_norm
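
Several of the examples on this page compute an average the same way: all-reduce with the default SUM and scale by the group size (before the reduce here, after it in Examples #2 and #15), since an AVG reduce op is not uniformly available across backends. A small sketch of the idiom in isolation, assuming an already-initialized `torch.distributed` process group; the helper name `allreduce_mean` and the `pg` argument are illustrative, not part of DeepSpeed.

# Sketch: averaging a scalar across a process group by pre-scaling + SUM.
# Assumes torch.distributed is already initialized; pass group=None to use
# the default (world) group.
import torch
import torch.distributed as dist

def allreduce_mean(value: float, pg=None, device="cuda") -> float:
    world_size = dist.get_world_size(group=pg)
    t = torch.tensor(value / world_size, dtype=torch.float, device=device)
    dist.all_reduce(t, group=pg)  # SUM of the pre-scaled values == mean
    return t.item()
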
Example #8
    def forward(ctx, input, inter_w, inter_b, config, output_b, output_w,
                q_scales, q_groups, merge_count, mp_group, async_op):
        if config.q_int8:
            intermediate = inference_cuda_module.fused_gemm_gelu_int8(
                input, inter_w, inter_b, config.epsilon, q_scales[2],
                (q_groups * (2**merge_count)), config.pre_layer_norm)
            output = inference_cuda_module.vector_matmul_int8(
                intermediate, output_w, q_scales[3], q_groups, (merge_count))
        else:
            mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
                                    inference_cuda_module.fused_gemm_gelu_fp32

            output = mlp_gemm_func(input, inter_w, inter_b, output_w,
                                   config.epsilon, config.pre_layer_norm,
                                   async_op)
        if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
            dist.all_reduce(output, group=mp_group, async_op=async_op)

        return output + output_b
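
Example #8 forwards an `async_op` flag to `dist.all_reduce`. With `torch.distributed`, `async_op=True` returns a work handle instead of blocking, and the reduced values are only guaranteed to be valid after `wait()` (or another synchronization). A hedged sketch of that pattern, assuming an initialized process group; the function name is illustrative.

# Sketch: overlapping an all_reduce with other work via async_op=True.
# Assumes torch.distributed is already initialized on a CUDA backend.
import torch
import torch.distributed as dist

def async_allreduce_then_bias(output: torch.Tensor, bias: torch.Tensor):
    work = dist.all_reduce(output, async_op=True)  # returns a work handle
    # ... independent computation could overlap with the communication here ...
    work.wait()  # the reduced result is only safe to read after wait()
    return output + bias
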
Example #9
    def has_overflow(self, params, has_moe_params=None):
        if has_moe_params is None:
            has_moe_params = self.has_moe_params
        overflow = self.has_overflow_serial(params)
        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        overflow_gpu = torch.cuda.ByteTensor([overflow])
        # deepspeeed.comm.all_reduce(overflow_gpu,
        #                             op=deepspeed.comm.ReduceOp.MAX,
        #                             group=mpu.get_model_parallel_group())
        if has_moe_params:
            # All reduce this across expert_parallel_group, so that if an expert
            # overflows, we detect it here
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=groups._get_max_expert_parallel_group())
        if self.zero_reduce_scatter:
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=dist.get_world_group())
        elif self.mpu is not None:
            if self.deepspeed is not None:
                using_pipeline = hasattr(self.deepspeed,
                                         'pipeline_enable_backward_allreduce')
                if (using_pipeline
                        and self.deepspeed.pipeline_enable_backward_allreduce
                        is False) or (
                            not using_pipeline and
                            self.deepspeed.enable_backward_allreduce is False):
                    dist.all_reduce(overflow_gpu,
                                    op=dist.ReduceOp.MAX,
                                    group=self.mpu.get_data_parallel_group())
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=self.mpu.get_model_parallel_group())
        elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=dist.get_world_group())

        overflow = overflow_gpu[0].item()
        return bool(overflow)
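
The recurring idiom in this example (and in Example #13) is to encode a per-rank boolean as a 0/1 tensor and MAX-all-reduce it, so the flag becomes 1 everywhere if it is 1 anywhere. A minimal sketch of that idiom on its own, assuming an initialized `torch.distributed` process group; the helper name is illustrative.

# Sketch: "did any rank overflow?" via a MAX all_reduce of a 0/1 flag.
# Assumes torch.distributed is already initialized; group=None means the
# default (world) group.
import torch
import torch.distributed as dist

def any_rank_true(local_flag: bool, group=None, device="cuda") -> bool:
    flag = torch.tensor([1 if local_flag else 0],
                        dtype=torch.uint8, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag.item())
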
Example #10
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(
        2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [
        chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list
    ]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat([
        server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())
    ])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
Example #11
    def test_grid_pipe_data(self):
        topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
        grid = Grid(topology=topo)

        assert grid._is_grid_valid()

        rank = dist.get_rank()

        assert grid.is_first_stage == (grid.get_stage_id() == 0)
        assert grid.is_last_stage == (
            grid.get_stage_id() == grid.get_pipe_parallel_world_size() - 1)

        # Test collectives along the pipeline parallel process groups
        rank_tensor = torch.LongTensor(data=[rank]).cuda()
        dist.all_reduce(rank_tensor, group=grid.get_pipe_parallel_group())
        pipe_group = grid.pp_group
        assert torch.all(rank_tensor == sum(pipe_group))

        # Test collectives along the data parallel process groups
        rank_tensor = torch.LongTensor(data=[rank]).cuda()
        dist.all_reduce(rank_tensor, group=grid.get_data_parallel_group())
        data_group = grid.dp_group
        assert torch.all(rank_tensor == sum(data_group))
Example #12
    def test(self, sequential_model, simple_config, batch_input):
        base_model = copy.deepcopy(sequential_model)
        base_input = batch_input.clone().detach()
        base_output = base_model(base_input)
        base_params = sum(p.numel() for p in base_model.parameters())

        pipe_model = copy.deepcopy(sequential_model)
        pipe_model = PipelineModule(layers=pipe_model, num_stages=2)

        # Ensure all parameters are accounted for.
        my_params = sum(p.numel() for p in pipe_model.parameters())
        total_pipe_params = torch.LongTensor([my_params]).to('cuda')
        dist.all_reduce(total_pipe_params)
        total_pipe_params = total_pipe_params.item()
        assert total_pipe_params == base_params

        pipe_model, _, _, _ = deepspeed.initialize(
            config=simple_config,
            model=pipe_model,
            model_parameters=[p for p in pipe_model.parameters()])

        if pipe_model.is_first_stage or pipe_model.is_last_stage:
            pipe_input = base_input.clone().detach().to('cuda')
            # label 0 is meaningless
            dataset = [(pipe_input, 0)]
            loader = RepeatingLoader(dataset)
            data_iter = iter(loader)
        else:
            data_iter = None

        pipe_output = pipe_model.eval_batch(data_iter=data_iter)

        base_output = base_output.to('cpu')
        pipe_output = pipe_output.to('cpu')

        assert torch.allclose(base_output, pipe_output, atol=1e-4)
Example #13
    def check_using_norm(self, norm_group, reduce_overflow=True):
        # TODO: I don't think reduce_overflow is needed if mpu is None
        overflow = -1 in norm_group
        overflow_gpu = torch.cuda.FloatTensor([overflow])
        if self.has_moe_params:
            # In this case, we need to do an all_reduce across
            # the expert_parallel_group, so that if there was
            # an overflow due to expert weights, we detect it

            # Only need to check groups.get_largest_expert_parallel_group()
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=groups._get_max_expert_parallel_group())
        if self.mpu is not None:
            dist.all_reduce(overflow_gpu,
                            op=dist.ReduceOp.MAX,
                            group=self.mpu.get_model_parallel_group())
        elif reduce_overflow:
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)
            dist.barrier()
        overflow = overflow_gpu[0].item()
        return bool(overflow)
Example #14
def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
    """Clips gradient norm of an iterable of parameters.

    This has been adapted from NVIDIA Megatron. We add norm averaging to
    account for MoE params when calculating the norm, since they produce
    different norms on different ranks.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with added
    functionality to handle model parallel parameters. Note that the
    gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.MAX,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0
        for p in parameters:
            if mpu is not None:
                if (mpu.get_model_parallel_rank()
                        == 0) or is_model_parallel_parameter(p):
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm.item()**norm_type
            else:
                param_norm = p.grad.data.float().norm(norm_type)
                total_norm += param_norm.item()**norm_type

        # Sum across all model parallel GPUs.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.SUM,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    # Need to average total_norm across different GPUs due to the presence of moe params
    pg = groups._get_data_parallel_group()
    scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))

    scaled_norm_tensor = torch.cuda.FloatTensor([float(scaled_norm)])
    dist.all_reduce(scaled_norm_tensor, group=pg)
    total_norm = scaled_norm_tensor.item()

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm
Example #15
    def step(self, closure=None, grads=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradient to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Have to be of same type as gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        loss = None
        if closure is not None:
            loss = closure()

        gather_time = 0
        allgather_time = 0
        all_time = 0

        if self.adam_freeze_key is False:
            v_diff_buffer = 0.0

        if grads is None:
            grads_group = [None] * len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0]) != list:
            grads_group = [grads]
        else:
            grads_group = grads

        for group, grads_this_group in zip(self.param_groups, grads_group):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad in zip(group['params'], grads_this_group):
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        '1-bit Adam does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if not self.initialize or (self.adam_freeze_key and
                                           'worker_error' not in state.keys()):
                    state['tensor_size'] = torch.numel(p.data)
                    state['corrected_tensor_size'] = state['tensor_size']

                    if state['tensor_size'] % (self.size * self.divider) != 0:
                        state['corrected_tensor_size'] += (
                            (self.size * self.divider) -
                            (state['tensor_size'] %
                             (self.size * self.divider)))
                    state['server_chunk_size'] = state[
                        'corrected_tensor_size'] // self.size
                    torch.cuda.empty_cache()
                    state['worker_error'] = torch.zeros(
                        state['corrected_tensor_size'], device=p.device)
                    state['server_error'] = torch.zeros(
                        state['server_chunk_size'], device=p.device)
                    torch.cuda.empty_cache()
                    self.adam_freeze_key = True
                    if not self.initialize and dist.get_rank() == 0:
                        print("Cupy Buffers Initialized Successfully.")

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if self.adam_freeze_key is False:
                    exp_avg.mul_(beta1).add_(1 - beta1, grad)
                    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    grad = None
                    if self.initialize:
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

                else:
                    if 'non_freeze' in group.keys(
                    ) and group['non_freeze'] is True:
                        dist.all_reduce(grad)
                        grad.mul_(1 / dist.get_world_size())
                        exp_avg.mul_(beta1).add_(1 - beta1, grad)
                        exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                        grad = None
                    else:
                        if self.initialize is True:
                            exp_avg.mul_(beta1).add_(1 - beta1, grad)
                        grad = None

                        if self.size > 1:
                            exp_avg.set_(
                                self.comm_backend_handle.compressed_allreduce(
                                    exp_avg, state['worker_error'],
                                    state['server_error'],
                                    self.deepspeed.local_rank))
                        # Because 1-bit compression cannot represent exact zero, it is required to
                        # provide a momentum mask for those params that have constant exact zeros in their
                        # momentums, otherwise the compression error would keep accumulating.
                        # For example, for BERT pre-training seq 128, bert.embeddings.position_embeddings.weight
                        # always have exact zeros in its momentum for row 129 to 512, because it only
                        # learns up to seq length 128 while the model supports up to 512 seq length.
                        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
                        if 'exp_avg_mask' in group:
                            if exp_avg.device != group['exp_avg_mask'].device:
                                group['exp_avg_mask'] = group[
                                    'exp_avg_mask'].to(device=exp_avg.device)
                            exp_avg.mul_(group['exp_avg_mask'])

                    if self.initialize:
                        update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

                if self.initialize:
                    if group['weight_decay'] > 0.0:
                        update += group['weight_decay'] * p.data
                    with torch.no_grad():
                        p.add_(-group['lr'] * update)

            if not self.initialize:
                print('Pop out errors', flush=True)
                state.pop('worker_error')
                state.pop('server_error')

        if not self.initialize:
            self.adam_freeze_key = False
            self.initialize = True
            print(
                f"Finished the initialization step at rank {dist.get_rank()}")
            return loss

        if self.adam_freeze_key is False:
            if state['step'] >= self.freeze_step:
                print('OnebitAdam - starting compressed communication')
                self.adam_freeze_key = True
                if self.using_pipeline:
                    self.deepspeed.pipeline_enable_backward_allreduce = False
                else:
                    self.deepspeed.enable_backward_allreduce = False

        return loss
Example #16
    def forward(ctx, input, input_mask, head_mask, layer_past, get_present,
                encoder_hidden_states, encoder_attention_mask,
                output_attentions, norm_w, norm_b, config, attn_qkvw,
                attn_qkvb, num_attention_heads_per_partition, norm_factor,
                hidden_size_per_partition, attn_ow, attn_ob, mp_group,
                q_scales, q_groups, merge_count, qkv_merging,
                score_context_func, alibi):
        def _transpose_for_scores(x, key=False, reshape=False):
            attention_head_size = x.shape[
                -1] // num_attention_heads_per_partition
            new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                           attention_head_size)
            x_1 = x.view(*new_x_shape)
            if key:
                x_1 = x_1.permute(0, 2, 3, 1)
            else:
                x_1 = x_1.permute(0, 2, 1, 3)
            if reshape:
                return x_1.reshape(x.shape)
            return x_1.contiguous()

        def _transpose_for_context(x):
            x = x.permute(0, 2, 1, 3).contiguous()
            new_x_layer_shape = x.size()[:-2] + \
                                      (hidden_size_per_partition,)
            return x.view(*new_x_layer_shape).contiguous()

        ########### This part is taken/modified from the HF modeling_bloom.py ################
        # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py

        def split_tensor_along_last_dim(tensor,
                                        num_partitions,
                                        contiguous_split_chunks=True):
            """Split a tensor along its last dimension.

            Args:
                tensor: ([`torch.tensor`], *required*):
                    input tensor to split
                num_partitions ([`int`], *required*):
                    number of partitions to split the tensor
                contiguous_split_chunks ([`bool`], *optional*, default=`True`):
                    If True, make each chunk contiguous in memory.
            """
            # Get the size and dimension.
            last_dim = tensor.dim() - 1
            numerator, denominator = tensor.size()[last_dim], num_partitions
            if not (numerator % denominator == 0):
                raise ValueError(
                    f"{numerator} is not divisible by {denominator}")
            last_dim_size = numerator // denominator
            # Split.
            tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
            # Note: torch.split does not create contiguous tensors by default.
            if contiguous_split_chunks:
                return tuple(chunk.contiguous() for chunk in tensor_list)

            return tensor_list

        def backup_attention(mixed_x_layer, layer_past, alibi, input_mask,
                             norm_factor):
            alibi = alibi.to(torch.cuda.current_device())
            head_dim = hidden_size_per_partition // num_attention_heads_per_partition
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                num_attention_heads_per_partition, 3 * head_dim)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            (query_layer, key_layer,
             value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

            if layer_past is not None:
                past_key, past_value = layer_past
                # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
                key_layer = torch.cat((past_key.type_as(key_layer), key_layer),
                                      dim=1)
                value_layer = torch.cat(
                    (past_value.type_as(value_layer), value_layer), dim=1)

            presents = (key_layer, value_layer)

            # [batch_size, head_dim, q_length, k_length]
            output_size = (query_layer.size(0), query_layer.size(2),
                           query_layer.size(1), key_layer.size(1))
            # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
            query_layer = query_layer.transpose(1, 0).reshape(
                output_size[2], output_size[0] * output_size[1], -1)
            # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
            key_layer = key_layer.transpose(1, 0).reshape(
                output_size[3], output_size[0] * output_size[1], -1)

            # Raw attention scores. [batch_size * num_heads, q_length, k_length]
            matmul_result = torch.matmul(
                query_layer.transpose(1, 0),
                key_layer.transpose(1, 0).transpose(1, 2))
            # change view to [batch_size, num_heads, q_length, k_length]
            attention_scores = matmul_result.view(*output_size)

            offset = dist.get_rank(
            ) * num_attention_heads_per_partition if dist.is_initialized(
            ) else 0
            attention_probs = inference_cuda_module.softmax_fp16(
                attention_scores,
                ((1 - input_mask).half() *
                 minus_inf) if input_mask.dtype == torch.int64 else input_mask,
                alibi, (config.triangular_masking and
                        (attention_scores.shape[-2] > 1)), False, False, 1,
                False, 1 / (norm_factor * norm_factor), offset, config.mp_size)
            # change view [batch_size x num_heads, q_length, k_length]
            attention_probs_reshaped = attention_probs.view(
                *matmul_result.shape)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = torch.bmm(
                attention_probs_reshaped,
                value_layer.transpose(1, 2).reshape(-1, value_layer.size(1),
                                                    value_layer.size(3)))

            # change view [batch_size, num_heads, q_length, head_dim]
            context_layer = context_layer.view(
                context_layer.size(0) // num_attention_heads_per_partition,
                num_attention_heads_per_partition, context_layer.size(1),
                context_layer.shape[-1])

            context_layer = _transpose_for_context(context_layer)

            return context_layer, presents

        ###################### End of HF modeling_bloom addition ########################

        def compute_attention(qkv_out, input_mask):
            no_masking = input_mask is None

            head_size = (qkv_out.shape[-1] // 3 //
                         num_attention_heads_per_partition)
            if no_masking:
                input_mask = torch.empty(1)
            if merge_count > 0 and config.q_int8:
                split_dim = (qkv_out.dim() - 1)
                qkv_split = torch.split(qkv_out, (qkv_out.shape[-1] //
                                                  (2**merge_count)),
                                        dim=split_dim)
                qkv_split = [
                    torch.split(s, (s.shape[-1] // 3), dim=split_dim)
                    for s in qkv_split
                ]
                (mixed_query, key_layer, value_layer) = [
                    torch.cat([s[i] for s in qkv_split], axis=-1)
                    for i in range(len(qkv_split[0]))
                ]

                if config.rotary_dim > 0:
                    mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
                        mixed_query, key_layer, config.rotary_dim,
                        0 if layer_past is None else layer_past[0].shape[-2],
                        num_attention_heads_per_partition, config.rotate_half,
                        config.rotate_every_two)
                if layer_past is not None:
                    past_key, past_value = layer_past
                    key_layer = torch.cat(
                        (past_key.type_as(key_layer), key_layer), dim=-2)
                    value_layer = torch.cat(
                        (past_value.type_as(value_layer), value_layer), dim=-2)
                presents = (key_layer, value_layer)
                mixed_query = _transpose_for_scores(mixed_query, False, True)
                key_layer = _transpose_for_scores(key_layer, True, True) / (
                    norm_factor if config.scale_attention else 1.0)
                value_layer = _transpose_for_scores(value_layer, False, True)
                if layer_past is None:
                    attn_key_value = score_context_func(
                        mixed_query,
                        key_layer,
                        torch.empty(1),
                        ((1 - input_mask).half() * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        value_layer,
                        torch.empty(1),
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        (not unfused_mode),  # noqa: F821
                        config.triangular_masking,
                        config.local_attention,
                        config.window_size,
                        no_masking)
                else:
                    attn_key_value = score_context_func(
                        mixed_query,
                        (key_layer if unfused_mode else
                         past_key.type_as(key_layer)),  # noqa: F821
                        key_layer,
                        ((1 - input_mask).half() * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        (value_layer if unfused_mode else
                         past_value.type_as(value_layer)),  # noqa: F821
                        value_layer,
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        (not unfused_mode),  # noqa: F821
                        config.triangular_masking,
                        config.local_attention,
                        config.window_size,
                        no_masking)
                if unfused_mode:  # noqa: F821
                    context_layer, _, _ = attn_key_value
                else:
                    context_layer, key_layer, value_layer = attn_key_value

                # Transpose Context
                context_layer = _transpose_for_context(context_layer)

                return context_layer, presents[0], presents[
                    1]  # atten_output, key_layer, value_layer
            else:
                # Note: This modification is added for the BLOOM-176B model and will be removed later!
                if config.bigscience_bloom:
                    context_layer, presents = backup_attention(
                        qkv_out, layer_past, alibi, input_mask, norm_factor)
                    return context_layer, presents[0], presents[
                        1]  #key_layer, value_layer
                else:
                    if alibi is not None:
                        batch_heads = qkv_out.shape[
                            0] * num_attention_heads_per_partition
                        offset = dist.get_rank(
                        ) * batch_heads if dist.is_initialized() else 0
                        sliced_alibi = alibi[offset:batch_heads + offset, :, :]

                    attn_key_value = score_context_func(
                        qkv_out,
                        ((1 - input_mask).to(qkv_out.dtype) * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        config.rotary_dim, config.rotate_half,
                        config.rotate_every_two,
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        config.triangular_masking, config.local_attention,
                        config.window_size, no_masking, config.layer_id,
                        DeepSpeedTransformerInference.layer_id,
                        sliced_alibi if alibi is not None else torch.empty(1))
                    context_layer, key_layer, value_layer = attn_key_value
                    return context_layer, key_layer, value_layer

        def selfAttention_fp():
            vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
                                    inference_cuda_module.vector_matmul_fp32
            if not config.pre_layer_norm:
                linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \
                                    inference_cuda_module.linear_layer_fp32

                qkv_out = linear_func(input, attn_qkvw, attn_qkvb,
                                      DeepSpeedTransformerInference.layer_id)
            else:
                qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
                                    inference_cuda_module.qkv_gemm_fp32
                qkv_out = qkv_func(
                    input, attn_qkvw,
                    (attn_qkvb if attn_qkvb is not None else norm_b), norm_w,
                    norm_b, config.epsilon, (attn_qkvb is not None),
                    1 if config.bigscience_bloom else
                    DeepSpeedTransformerInference.layer_id)
            context_layer, key_layer, value_layer = compute_attention(
                qkv_out[0] if isinstance(qkv_out, list) else qkv_out,
                input_mask)
            output = vector_matmul_func(context_layer, attn_ow, False)

            return output, key_layer, value_layer, context_layer, qkv_out[-1]

        def selfAttention_int8():
            if not config.pre_layer_norm:
                qkv_out = inference_cuda_module.linear_layer_int8(
                    input, attn_qkvw, attn_qkvb, q_scales[0],
                    (q_groups * (3 if qkv_merging else 1) * (2**merge_count)))

            else:
                qkv_out = inference_cuda_module.qkv_gemm_int8(
                    input, attn_qkvw, attn_qkvb, norm_w, norm_b,
                    config.epsilon, q_scales[0],
                    (q_groups * (3 if qkv_merging else 1) * (2**merge_count)),
                    (attn_qkvb is not None))
            context_layer, key_layer, value_layer = compute_attention(
                qkv_out, input_mask)
            output = inference_cuda_module.vector_matmul_int8(
                context_layer, attn_ow, q_scales[1], q_groups, (merge_count))
            return output, key_layer, value_layer, context_layer

        if config.q_int8:
            output, key_layer, value_layer, context_layer = selfAttention_int8(
            )
        else:
            output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp(
            )
        if config.mlp_after_attn and mp_group is not None and dist.get_world_size(
                group=mp_group) > 1:
            dist.all_reduce(output, group=mp_group)

        return (output, key_layer, value_layer, context_layer, inp_norm)
Example #17
    def allreduce_tied_weight_gradients(self):
        '''All reduce the gradients of the tied weights between tied stages'''
        for key, comm in self.tied_comms.items():
            weight = getattr(self.tied_modules[key], comm['weight_attr'])
            dist.all_reduce(weight.grad, group=comm['group'])
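
The same kind of gradient all-reduce appears whenever a set of parameters must stay replicated across a group (tied embeddings across pipeline stages here, plain data parallelism in general): after backward, the `.grad` tensors are summed, and usually averaged, across the group before the optimizer step. A small sketch with `torch.distributed`, assuming an initialized process group and that every rank holds a replica of the module; the function name is illustrative.

# Sketch: manually averaging gradients across a process group after backward.
# Assumes torch.distributed is initialized and all ranks hold replicas of
# the module's parameters.
import torch
import torch.distributed as dist

def allreduce_gradients(module: torch.nn.Module, group=None):
    world_size = dist.get_world_size(group=group)
    for p in module.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=group)   # sum across ranks
            p.grad.div_(world_size)                # turn the sum into a mean
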
Example #18
def top1gating(
        logits: Tensor,
        capacity_factor: float,
        min_capacity: int,
        used_token: Tensor = None,
        noisy_gate_policy: Optional[str] = None,
        drop_tokens: bool = True,
        use_rts: bool = True,
        use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
        logits_w_noise = logits + gumbel_rsample(logits.shape,
                                                 device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor),
                         torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    # noisy gating
    indices1_s = torch.argmax(
        logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        dist.all_reduce(new_capacity,
                        op=dist.ReduceOp.MAX,
                        group=dist.get_world_group())
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device),
                high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[
        0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."

    top_idx = _top_idx(mask1_rand, capacity)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    if use_tutel:
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
    else:
        locations1 = torch.cumsum(mask1, dim=0) - 1

    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [
            indices1_s,
        ], [
            locations1_s,
        ], [
            gates1_s,
        ], exp_counts

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
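
One detail worth calling out in the `drop_tokens=False` branch above: the expert capacity becomes data-dependent, so every rank MAX-all-reduces its local maximum token count and adopts the global value, which keeps the expert buffer shapes identical on every rank. A stripped-down sketch of that synchronization, assuming an initialized `torch.distributed` process group; the helper name is illustrative.

# Sketch: agreeing on a data-dependent buffer size across ranks via MAX.
# Assumes torch.distributed is already initialized; exp_counts is the
# per-expert token count computed locally (as in top1gating above).
import torch
import torch.distributed as dist

def global_capacity(exp_counts: torch.Tensor, device="cuda") -> int:
    capacity = exp_counts.max().to(device)           # local max tokens/expert
    dist.all_reduce(capacity, op=dist.ReduceOp.MAX)  # global max across ranks
    return int(capacity.item())
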