def test(self):
    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
    sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
    result = torch.ones(1, 3).cuda() * sum_of_ranks
    dist.all_reduce(x)
    assert torch.all(x == result)
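A quick single-process sketch of the arithmetic this test relies on: each rank contributes (rank + 1), so the all-reduce sum equals 1 + 2 + ... + world_size = world_size * (world_size + 1) // 2 (the world sizes below are hypothetical, chosen only for illustration).

# Minimal sketch, no torch.distributed needed: verify the sum-of-ranks identity.
for world_size in (1, 2, 4, 8):  # hypothetical world sizes
    contributions = [rank + 1 for rank in range(world_size)]
    assert sum(contributions) == world_size * (world_size + 1) // 2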
def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        config['local_rank'] = dist.get_rank()

        engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            model_parameters=[p for p in model.parameters()],
            training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

    if average_dp_losses:
        loss_tensor = torch.tensor(losses).cuda()
        dist.all_reduce(loss_tensor)
        loss_tensor /= dist.get_world_size()
        losses = loss_tensor.tolist()

    return losses
def forward(self, input):
    output = torch.matmul(input, self.weight.transpose(-1, -2))
    if self.mp_group is not None:
        dist.all_reduce(output, group=self.mp_group)
    if self.bias is not None:
        output += self.bias
    return output
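A single-process sketch (an illustration, not the module's actual sharding code) of why the partial matmul outputs above can simply be summed by all_reduce: when the weight's input dimension is sharded across ranks, each rank's partial product adds up to the unsharded result. The shapes and the two-way split below are assumptions made only for this example.

import torch

torch.manual_seed(0)
x = torch.randn(2, 8)   # [batch, in_features]
w = torch.randn(4, 8)   # [out_features, in_features]
full = torch.matmul(x, w.transpose(-1, -2))

# emulate two "ranks", each holding half of the input features
partials = [
    torch.matmul(x[:, 0:4], w[:, 0:4].transpose(-1, -2)),
    torch.matmul(x[:, 4:8], w[:, 4:8].transpose(-1, -2)),
]
# summing the partial products (what all_reduce does collectively) recovers the full output
assert torch.allclose(full, sum(partials), atol=1e-6)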
def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    dist.all_reduce(input_, group=group)

    return input_
def get_full_hp_param(self, optim_state_key=None):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start,
                                       lp_frag_address.numel)
        if optim_state_key is None:
            hp_fragment = self._hp_mapping.hp_fragment
        else:
            hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)

        reduce_fragment.data.copy_(hp_fragment.data)
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)
def forward(ctx, input, residual, residual_norm, bias, inter_w, inter_b, attn_nw, attn_nb,
            config, mp_group, output_b, output_w, q_scales, q_groups, merge_count,
            mlp_gemm_func, fused_gemm_gelu, vector_matmul_func, bias_residual_func,
            activation_func_type=ActivationFuncType.GELU):
    if config.q_int8:
        (intermediate, residual_add) = inference_cuda_module.mlp_gemm_int8(
            input, residual, bias, inter_w, inter_b, attn_nw, attn_nb, config.epsilon,
            q_scales[2], (q_groups * (2**merge_count)), config.pre_layer_norm)
        output = inference_cuda_module.vector_matmul_int8(intermediate, output_w,
                                                          q_scales[3], q_groups,
                                                          (merge_count))
    else:
        if attn_nw is None:
            output = fused_gemm_gelu(residual_norm, inter_w, inter_b, output_w,
                                     config.epsilon, config.pre_layer_norm, False)
        else:
            intermediate, residual_add = mlp_gemm_func(input, residual, bias, inter_w,
                                                       inter_b, attn_nw, attn_nb,
                                                       config.epsilon,
                                                       config.pre_layer_norm,
                                                       config.mlp_after_attn,
                                                       config.mlp_act_func_type)
            output = vector_matmul_func(intermediate, output_w, False)

    inference_cuda_module.residual_add(output,
                                       residual if config.pre_layer_norm else residual_add,
                                       input, output_b,
                                       bias if bias is not None else output_b,
                                       config.mp_size, config.mlp_after_attn,
                                       bias is not None, config.pre_layer_norm)

    if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
        dist.all_reduce(output, group=mp_group)

    return output
def _get_norm_with_moe_layers(self, all_groups_norm):
    #all_groups_norm_old = all_groups_norm
    # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
    if self.using_pipeline:
        pg = self.deepspeed.mpu.get_data_parallel_group()
    else:
        pg = groups._get_data_parallel_group()
    scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = torch.tensor(scaled_norm,
                                      device=self.fp32_groups_flat[0].device,
                                      dtype=torch.float)
    dist.all_reduce(scaled_norm_tensor, group=pg)
    all_groups_norm = scaled_norm_tensor.item()
    #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
    return all_groups_norm
def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales, q_groups,
            merge_count, mp_group, async_op):
    if config.q_int8:
        intermediate = inference_cuda_module.fused_gemm_gelu_int8(
            input, inter_w, inter_b, config.epsilon, q_scales[2],
            (q_groups * (2**merge_count)), config.pre_layer_norm)
        output = inference_cuda_module.vector_matmul_int8(intermediate, output_w,
                                                          q_scales[3], q_groups,
                                                          (merge_count))
    else:
        mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
            inference_cuda_module.fused_gemm_gelu_fp32

        output = mlp_gemm_func(input, inter_w, inter_b, output_w, config.epsilon,
                               config.pre_layer_norm, async_op)
    if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
        dist.all_reduce(output, group=mp_group, async_op=async_op)

    return output + output_b
def has_overflow(self, params, has_moe_params=None):
    if has_moe_params is None:
        has_moe_params = self.has_moe_params
    overflow = self.has_overflow_serial(params)
    # Since each model parallel GPU carries only part of the model,
    # make sure overflow flag is synced across all the model parallel GPUs
    overflow_gpu = torch.cuda.ByteTensor([overflow])
    # deepspeed.comm.all_reduce(overflow_gpu,
    #                           op=deepspeed.comm.ReduceOp.MAX,
    #                           group=mpu.get_model_parallel_group())
    if has_moe_params:
        # All reduce this across expert_parallel_group, so that if an expert
        # overflows, we detect it here
        dist.all_reduce(overflow_gpu,
                        op=dist.ReduceOp.MAX,
                        group=groups._get_max_expert_parallel_group())
    if self.zero_reduce_scatter:
        dist.all_reduce(overflow_gpu,
                        op=dist.ReduceOp.MAX,
                        group=dist.get_world_group())
    elif self.mpu is not None:
        if self.deepspeed is not None:
            using_pipeline = hasattr(self.deepspeed,
                                     'pipeline_enable_backward_allreduce')
            if (using_pipeline
                    and self.deepspeed.pipeline_enable_backward_allreduce is False) or (
                        not using_pipeline
                        and self.deepspeed.enable_backward_allreduce is False):
                dist.all_reduce(overflow_gpu,
                                op=dist.ReduceOp.MAX,
                                group=self.mpu.get_data_parallel_group())
        dist.all_reduce(overflow_gpu,
                        op=dist.ReduceOp.MAX,
                        group=self.mpu.get_model_parallel_group())
    elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
        dist.all_reduce(overflow_gpu,
                        op=dist.ReduceOp.MAX,
                        group=dist.get_world_group())
    overflow = overflow_gpu[0].item()
    return bool(overflow)
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat(
        [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
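A hedged, single-process sketch (no communication) of the sign-compression step torch_sim emulates: a tensor is approximated by scale * sign(a) with scale = ||a||_2 / sqrt(numel), and the residual becomes the worker error that error feedback folds back in on the next step. The tensor size and seed are arbitrary illustration values.

import torch

torch.manual_seed(0)
a = torch.randn(1024)
a_sign = a.sign()
a_sign[a_sign == 0] = 1.0            # like the code above, exact zeros map to +1
scale = a.norm() / a.numel() ** 0.5  # one scalar scale for the whole tensor
a_compressed = scale * a_sign
worker_error = a - a_compressed      # residual kept locally for error feedback
assert torch.allclose(a, a_compressed + worker_error)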
def test_grid_pipe_data(self):
    topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
    grid = Grid(topology=topo)

    assert grid._is_grid_valid()

    rank = dist.get_rank()

    assert grid.is_first_stage == (grid.get_stage_id() == 0)
    assert grid.is_last_stage == (
        grid.get_stage_id() == grid.get_pipe_parallel_world_size() - 1)

    # Test collectives along the pipeline parallel process groups
    rank_tensor = torch.LongTensor(data=[rank]).cuda()
    dist.all_reduce(rank_tensor, group=grid.get_pipe_parallel_group())
    pipe_group = grid.pp_group
    assert torch.all(rank_tensor == sum(pipe_group))

    # Test collectives along the data parallel process groups
    rank_tensor = torch.LongTensor(data=[rank]).cuda()
    dist.all_reduce(rank_tensor, group=grid.get_data_parallel_group())
    data_group = grid.dp_group
    assert torch.all(rank_tensor == sum(data_group))
def test(self, sequential_model, simple_config, batch_input):
    base_model = copy.deepcopy(sequential_model)
    base_input = batch_input.clone().detach()
    base_output = base_model(base_input)
    base_params = sum(p.numel() for p in base_model.parameters())

    pipe_model = copy.deepcopy(sequential_model)
    pipe_model = PipelineModule(layers=pipe_model, num_stages=2)

    # Ensure all parameters are accounted for.
    my_params = sum(p.numel() for p in pipe_model.parameters())
    total_pipe_params = torch.LongTensor([my_params]).to('cuda')
    dist.all_reduce(total_pipe_params)
    total_pipe_params = total_pipe_params.item()
    assert total_pipe_params == base_params

    pipe_model, _, _, _ = deepspeed.initialize(
        config=simple_config,
        model=pipe_model,
        model_parameters=[p for p in pipe_model.parameters()])
    if pipe_model.is_first_stage or pipe_model.is_last_stage:
        pipe_input = base_input.clone().detach().to('cuda')
        # label 0 is meaningless
        dataset = [(pipe_input, 0)]
        loader = RepeatingLoader(dataset)
        data_iter = iter(loader)
    else:
        data_iter = None

    pipe_output = pipe_model.eval_batch(data_iter=data_iter)

    base_output = base_output.to('cpu')
    pipe_output = pipe_output.to('cpu')

    assert torch.allclose(base_output, pipe_output, atol=1e-4)
def check_using_norm(self, norm_group, reduce_overflow=True):
    # TODO: I don't think reduce_overflow is needed if mpu is None
    overflow = -1 in norm_group
    overflow_gpu = torch.cuda.FloatTensor([overflow])
    if self.has_moe_params:
        # In this case, we need to do an all_reduce across
        # the expert_parallel_group, so that if there was
        # an overflow due to expert weights, we detect it
        # Only need to check groups.get_largest_expert_parallel_group()
        dist.all_reduce(overflow_gpu,
                        op=dist.ReduceOp.MAX,
                        group=groups._get_max_expert_parallel_group())
    if self.mpu is not None:
        dist.all_reduce(overflow_gpu,
                        op=dist.ReduceOp.MAX,
                        group=self.mpu.get_model_parallel_group())
    elif reduce_overflow:
        dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)
        dist.barrier()
    overflow = overflow_gpu[0].item()
    return bool(overflow)
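A pure-Python sketch of the flag semantics behind these MAX reductions: every rank contributes a 0/1 overflow flag, and reducing with MAX is equivalent to a logical OR, so after the collective every rank knows whether any rank overflowed. The flag values below are hypothetical.

# Sketch only: emulate what ReduceOp.MAX computes over per-rank overflow flags.
rank_flags = [0, 0, 1, 0]   # hypothetical per-rank overflow flags
reduced = max(rank_flags)   # the value every rank would see after the all-reduce
assert bool(reduced) == any(rank_flags)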
def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
    """Clips gradient norm of an iterable of parameters.

    This has been adapted from Nvidia megatron. We add norm averaging
    to consider MoE params when calculating norm, as they will result
    in different norms across different ranks.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    adds functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.MAX,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0
        for p in parameters:
            if mpu is not None:
                if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p):
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm.item()**norm_type
            else:
                param_norm = p.grad.data.float().norm(norm_type)
                total_norm += param_norm.item()**norm_type

        # Sum across all model parallel GPUs.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.SUM,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    # Need to average total_norm across different GPUs due to the presence of moe params
    pg = groups._get_data_parallel_group()
    scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = torch.cuda.FloatTensor([float(scaled_norm)])
    dist.all_reduce(scaled_norm_tensor, group=pg)
    total_norm = scaled_norm_tensor.item()

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm
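A numeric sketch of the MoE norm-averaging step above: each rank pre-divides its norm by the data-parallel world size, so the SUM all-reduce yields the mean of the per-rank norms. The per-rank values are hypothetical.

# Sketch only: what the scale-then-SUM-all-reduce pattern computes.
per_rank_norms = [3.0, 5.0, 4.0, 8.0]            # hypothetical per-rank total norms
world_size = len(per_rank_norms)
scaled = [n / world_size for n in per_rank_norms]
reduced = sum(scaled)                            # what dist.all_reduce(SUM) produces
assert abs(reduced - sum(per_rank_norms) / world_size) < 1e-12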
def step(self, closure=None, grads=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
        grads (list of tensors, optional): weight gradient to use for the
            optimizer update. If gradients have type torch.half, parameters
            are expected to be in type torch.float. (default: None)
        output params (list of tensors, optional): A reduced precision copy
            of the updated weights written out in addition to the regular
            updated weights. Have to be of same type as gradients. (default: None)
        scale (float, optional): factor to divide gradient tensor values
            by before applying to weights. (default: 1)
    """
    loss = None
    if closure is not None:
        loss = closure()

    gather_time = 0
    allgather_time = 0
    all_time = 0

    if self.adam_freeze_key is False:
        v_diff_buffer = 0.0

    if grads is None:
        grads_group = [None] * len(self.param_groups)
    # backward compatibility
    # assuming a list/generator of parameter means single group
    elif isinstance(grads, types.GeneratorType):
        grads_group = [grads]
    elif type(grads[0]) != list:
        grads_group = [grads]
    else:
        grads_group = grads

    for group, grads_this_group in zip(self.param_groups, grads_group):
        if grads_this_group is None:
            grads_this_group = [None] * len(group['params'])

        bias_correction = 1 if group['bias_correction'] else 0

        for p, grad in zip(group['params'], grads_this_group):
            if p.grad is None and grad is None:
                continue
            if grad is None:
                grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError('1-bit Adam does not support sparse gradients')

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(p.data)

            if not self.initialize or (self.adam_freeze_key
                                       and 'worker_error' not in state.keys()):
                state['tensor_size'] = torch.numel(p.data)
                state['corrected_tensor_size'] = state['tensor_size']

                if state['tensor_size'] % (self.size * self.divider) != 0:
                    state['corrected_tensor_size'] += (
                        (self.size * self.divider) -
                        (state['tensor_size'] % (self.size * self.divider)))
                state['server_chunk_size'] = state['corrected_tensor_size'] // self.size
                torch.cuda.empty_cache()
                state['worker_error'] = torch.zeros(state['corrected_tensor_size'],
                                                    device=p.device)
                state['server_error'] = torch.zeros(state['server_chunk_size'],
                                                    device=p.device)
                torch.cuda.empty_cache()
                self.adam_freeze_key = True
                if not self.initialize and dist.get_rank() == 0:
                    print("Cupy Buffers Initialized Successfully.")

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            beta1, beta2 = group['betas']

            state['step'] += 1

            if self.adam_freeze_key is False:
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                grad = None
                if self.initialize:
                    update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])
            else:
                if 'non_freeze' in group.keys() and group['non_freeze'] is True:
                    dist.all_reduce(grad)
                    grad.mul_(1 / dist.get_world_size())
                    exp_avg.mul_(beta1).add_(1 - beta1, grad)
                    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                    grad = None
                else:
                    if self.initialize is True:
                        exp_avg.mul_(beta1).add_(1 - beta1, grad)
                    grad = None

                    if self.size > 1:
                        exp_avg.set_(
                            self.comm_backend_handle.compressed_allreduce(
                                exp_avg,
                                state['worker_error'],
                                state['server_error'],
                                self.deepspeed.local_rank))
                        # Because 1-bit compression cannot represent exact zero, it is
                        # required to provide a momentum mask for those params that have
                        # constant exact zeros in their momentums, otherwise the
                        # compression error would keep accumulating.
                        # For example, for BERT pre-training seq 128,
                        # bert.embeddings.position_embeddings.weight always have exact
                        # zeros in its momentum for row 129 to 512, because it only
                        # learns up to seq length 128 while the model supports up to
                        # 512 seq length.
                        # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
                        if 'exp_avg_mask' in group:
                            if exp_avg.device != group['exp_avg_mask'].device:
                                group['exp_avg_mask'] = group['exp_avg_mask'].to(
                                    device=exp_avg.device)
                            exp_avg.mul_(group['exp_avg_mask'])

                if self.initialize:
                    update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

            if self.initialize:
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data
                with torch.no_grad():
                    p.add_(-group['lr'] * update)

            if not self.initialize:
                print('Pop out errors', flush=True)
                state.pop('worker_error')
                state.pop('server_error')

    if not self.initialize:
        self.adam_freeze_key = False
        self.initialize = True
        print(f"Finished the initialization step at rank {dist.get_rank()}")
        return loss

    if self.adam_freeze_key is False:
        if state['step'] >= self.freeze_step:
            print('OnebitAdam - starting compressed communication')
            self.adam_freeze_key = True
            if self.using_pipeline:
                self.deepspeed.pipeline_enable_backward_allreduce = False
            else:
                self.deepspeed.enable_backward_allreduce = False

    return loss
def forward(ctx, input, input_mask, head_mask, layer_past, get_present,
            encoder_hidden_states, encoder_attention_mask, output_attentions, norm_w,
            norm_b, config, attn_qkvw, attn_qkvb, num_attention_heads_per_partition,
            norm_factor, hidden_size_per_partition, attn_ow, attn_ob, mp_group, q_scales,
            q_groups, merge_count, qkv_merging, score_context_func, alibi):
    def _transpose_for_scores(x, key=False, reshape=False):
        attention_head_size = x.shape[-1] // num_attention_heads_per_partition
        new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                       attention_head_size)
        x_1 = x.view(*new_x_shape)
        if key:
            x_1 = x_1.permute(0, 2, 3, 1)
        else:
            x_1 = x_1.permute(0, 2, 1, 3)
        if reshape:
            return x_1.reshape(x.shape)
        return x_1.contiguous()

    def _transpose_for_context(x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_layer_shape = x.size()[:-2] + \
                            (hidden_size_per_partition,)
        return x.view(*new_x_layer_shape).contiguous()

    ########### This part is taken/modified from the HF modeling_bloom.py ################
    # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py

    def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=True):
        """Split a tensor along its last dimension.

        Args:
            tensor: ([`torch.tensor`], *required*):
                input tensor to split
            num_partitions ([`int`], *required*):
                number of partitions to split the tensor
            contiguous_split_chunks ([`bool`], *optional*, default=`False`):
                If True, make each chunk contiguous in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        numerator, denominator = tensor.size()[last_dim], num_partitions
        if not (numerator % denominator == 0):
            raise ValueError(f"{numerator} is not divisible by {denominator}")
        last_dim_size = numerator // denominator
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list

    def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor):
        alibi = alibi.to(torch.cuda.current_device())
        head_dim = hidden_size_per_partition // num_attention_heads_per_partition
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            num_attention_heads_per_partition, 3 * head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)

        presents = (key_layer, value_layer)
        # [batch_size, head_dim, q_length, k_length]
        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1),
                       key_layer.size(1))
        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
        query_layer = query_layer.transpose(1, 0).reshape(
            output_size[2], output_size[0] * output_size[1], -1)
        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
        key_layer = key_layer.transpose(1, 0).reshape(
            output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.matmul(query_layer.transpose(1, 0),
                                     key_layer.transpose(1, 0).transpose(1, 2))
        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(*output_size)

        offset = dist.get_rank() * num_attention_heads_per_partition \
            if dist.is_initialized() else 0
        attention_probs = inference_cuda_module.softmax_fp16(
            attention_scores,
            ((1 - input_mask).half() * minus_inf)
            if input_mask.dtype == torch.int64 else input_mask,
            alibi,
            (config.triangular_masking and (attention_scores.shape[-2] > 1)),
            False,
            False,
            1,
            False,
            1 / (norm_factor * norm_factor),
            offset,
            config.mp_size)
        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(
            attention_probs_reshaped,
            value_layer.transpose(1, 2).reshape(-1, value_layer.size(1),
                                                value_layer.size(3)))

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(
            context_layer.size(0) // num_attention_heads_per_partition,
            num_attention_heads_per_partition,
            context_layer.size(1),
            context_layer.shape[-1])

        context_layer = _transpose_for_context(context_layer)

        return context_layer, presents

    ###################### End of HF modeling_bloom addition ########################

    def compute_attention(qkv_out, input_mask):
        no_masking = input_mask is None

        head_size = (qkv_out.shape[-1] // 3 // num_attention_heads_per_partition)
        if no_masking:
            input_mask = torch.empty(1)
        if merge_count > 0 and config.q_int8:
            split_dim = (qkv_out.dim() - 1)
            qkv_split = torch.split(qkv_out,
                                    (qkv_out.shape[-1] // (2**merge_count)),
                                    dim=split_dim)
            qkv_split = [
                torch.split(s, (s.shape[-1] // 3), dim=split_dim) for s in qkv_split
            ]
            (mixed_query, key_layer, value_layer) = [
                torch.cat([s[i] for s in qkv_split], axis=-1)
                for i in range(len(qkv_split[0]))
            ]

            if config.rotary_dim > 0:
                mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
                    mixed_query,
                    key_layer,
                    config.rotary_dim,
                    0 if layer_past is None else layer_past[0].shape[-2],
                    num_attention_heads_per_partition,
                    config.rotate_half,
                    config.rotate_every_two)
            if layer_past is not None:
                past_key, past_value = layer_past
                key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
                value_layer = torch.cat((past_value.type_as(value_layer), value_layer),
                                        dim=-2)
            presents = (key_layer, value_layer)
            mixed_query = _transpose_for_scores(mixed_query, False, True)
            key_layer = _transpose_for_scores(key_layer, True, True) / (
                norm_factor if config.scale_attention else 1.0)
            value_layer = _transpose_for_scores(value_layer, False, True)

            if layer_past is None:
                attn_key_value = score_context_func(
                    mixed_query,
                    key_layer,
                    torch.empty(1),
                    ((1 - input_mask).half() * minus_inf)
                    if input_mask.dtype == torch.int64 else input_mask,
                    value_layer,
                    torch.empty(1),
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),
                    (not unfused_mode),  # noqa: F821
                    config.triangular_masking,
                    config.local_attention,
                    config.window_size,
                    no_masking)
            else:
                attn_key_value = score_context_func(
                    mixed_query,
                    (key_layer if unfused_mode else past_key.type_as(key_layer)),  # noqa: F821
                    key_layer,
                    ((1 - input_mask).half() * minus_inf)
                    if input_mask.dtype == torch.int64 else input_mask,
                    (value_layer if unfused_mode else past_value.type_as(value_layer)),  # noqa: F821
                    value_layer,
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),
                    (not unfused_mode),  # noqa: F821
                    config.triangular_masking,
                    config.local_attention,
                    config.window_size,
                    no_masking)
            if unfused_mode:  # noqa: F821
                context_layer, _, _ = attn_key_value
            else:
                context_layer, key_layer, value_layer = attn_key_value

            # Transpose Context
            context_layer = _transpose_for_context(context_layer)

            return context_layer, presents[0], presents[1]  # atten_output, key_layer, value_layer
        else:
            # Note: This modification is added for the BLOOM-176B model and will be removed later!
            if config.bigscience_bloom:
                context_layer, presents = backup_attention(qkv_out, layer_past, alibi,
                                                           input_mask, norm_factor)
                return context_layer, presents[0], presents[1]  # key_layer, value_layer
            else:
                if alibi is not None:
                    batch_heads = qkv_out.shape[0] * num_attention_heads_per_partition
                    offset = dist.get_rank() * batch_heads if dist.is_initialized() else 0
                    sliced_alibi = alibi[offset:batch_heads + offset, :, :]

                attn_key_value = score_context_func(
                    qkv_out,
                    ((1 - input_mask).to(qkv_out.dtype) * minus_inf)
                    if input_mask.dtype == torch.int64 else input_mask,
                    config.rotary_dim,
                    config.rotate_half,
                    config.rotate_every_two,
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),
                    config.triangular_masking,
                    config.local_attention,
                    config.window_size,
                    no_masking,
                    config.layer_id,
                    DeepSpeedTransformerInference.layer_id,
                    sliced_alibi if alibi is not None else torch.empty(1))
                context_layer, key_layer, value_layer = attn_key_value
                return context_layer, key_layer, value_layer

    def selfAttention_fp():
        vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
            inference_cuda_module.vector_matmul_fp32
        if not config.pre_layer_norm:
            linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \
                inference_cuda_module.linear_layer_fp32

            qkv_out = linear_func(input, attn_qkvw, attn_qkvb,
                                  DeepSpeedTransformerInference.layer_id)
        else:
            qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
                inference_cuda_module.qkv_gemm_fp32
            qkv_out = qkv_func(
                input,
                attn_qkvw,
                (attn_qkvb if attn_qkvb is not None else norm_b),
                norm_w,
                norm_b,
                config.epsilon,
                (attn_qkvb is not None),
                1 if config.bigscience_bloom else DeepSpeedTransformerInference.layer_id)
        context_layer, key_layer, value_layer = compute_attention(
            qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
        output = vector_matmul_func(context_layer, attn_ow, False)

        return output, key_layer, value_layer, context_layer, qkv_out[-1]

    def selfAttention_int8():
        if not config.pre_layer_norm:
            qkv_out = inference_cuda_module.linear_layer_int8(
                input, attn_qkvw, attn_qkvb, q_scales[0],
                (q_groups * (3 if qkv_merging else 1) * (2**merge_count)))
        else:
            qkv_out = inference_cuda_module.qkv_gemm_int8(
                input, attn_qkvw, attn_qkvb, norm_w, norm_b, config.epsilon, q_scales[0],
                (q_groups * (3 if qkv_merging else 1) * (2**merge_count)),
                (attn_qkvb is not None))
        context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask)
        output = inference_cuda_module.vector_matmul_int8(context_layer, attn_ow,
                                                          q_scales[1], q_groups,
                                                          (merge_count))
        return output, key_layer, value_layer, context_layer

    if config.q_int8:
        output, key_layer, value_layer, context_layer = selfAttention_int8()
    else:
        output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()

    if config.mlp_after_attn and mp_group is not None and dist.get_world_size(
            group=mp_group) > 1:
        dist.all_reduce(output, group=mp_group)

    return (output, key_layer, value_layer, context_layer, inp_norm)
def allreduce_tied_weight_gradients(self):
    '''All reduce the gradients of the tied weights between tied stages'''
    for key, comm in self.tied_comms.items():
        weight = getattr(self.tied_modules[key], comm['weight_attr'])
        dist.all_reduce(weight.grad, group=comm['group'])
def top1gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
               used_token: Tensor = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == 'RSample':
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for the 1st expert per token
    # noisy gating
    indices1_s = torch.argmax(
        logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device),
                high=torch.tensor(1.0, device=logits.device)).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, \
        "No. of tokens (batch-size) should be greater than min_capacity. " \
        "Either set min_capacity to 0 or increase your batch size."

    top_idx = _top_idx(mask1_rand, capacity)

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    if use_tutel:
        # Tutel doesn't support index values masked with zero
        # so we need to replace masked indices with -1
        indices_mask = mask1.sum(dim=1) * num_experts - 1
        indices1_s = torch.min(indices1_s, indices_mask)

    # Compute locations in capacity buffer
    if use_tutel:
        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
    else:
        locations1 = torch.cumsum(mask1, dim=0) - 1

    if use_tutel:
        gates1_s = (gates * mask1).sum(dim=1)
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
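The _capacity helper is not shown above; a common formulation of top-1 expert capacity (an assumption for illustration, not necessarily this codebase's exact implementation) is max(min_capacity, ceil(num_tokens / num_experts * capacity_factor)). The sketch below uses hypothetical values; note that the drop_tokens=False branch above instead raises capacity to the largest per-expert token count seen across ranks via the MAX all-reduce.

import math

# Hypothetical sketch of a standard top-1 capacity formula (assumption, not the real _capacity).
def capacity_sketch(num_tokens: int, num_experts: int,
                    capacity_factor: float, min_capacity: int) -> int:
    capacity = math.ceil(num_tokens / num_experts * capacity_factor)
    return max(capacity, min_capacity)

# e.g. 512 tokens routed to 8 experts with factor 1.0 -> 64 slots per expert
assert capacity_sketch(512, 8, 1.0, 4) == 64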