def __init__(self) -> None:
    super().__init__()
    self.modulelist = ModuleList(
        EltwiseMultiplicationModule(weight=Parameter(
            torch.empty((param_sz, ), dtype=torch.float32)))
        for _ in range(n_layers))

    # Fill each rank's partition of every layer weight with a distinct,
    # layer- and rank-dependent value.
    for layer_num, module in enumerate(self.modulelist):
        with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
            param: Parameter = module.weight
            partition_sz = math.ceil(param.numel() / dist.get_world_size())
            offset = 0
            for rank in range(dist.get_world_size()):
                with torch.no_grad():
                    param[offset:offset + partition_sz].fill_(2 * layer_num * rank)
                offset += partition_sz
def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    dist.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
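# Illustrative sketch (not part of the original source): the concatenation
# semantics of _gather above, simulated in a single process. The per-rank
# shards below are stand-in tensors; in the real code they come from
# dist.all_gather over the model parallel group.
import torch

def _gather_sim(shards):
    # Each rank holds a [batch, hidden // world_size] shard; gathering
    # concatenates the shards back along the last dimension.
    last_dim = shards[0].dim() - 1
    return torch.cat(shards, dim=last_dim).contiguous()

# Example: world_size=2, hidden dimension 4 split into 2 + 2.
full = torch.arange(8.).reshape(2, 4)
shards = list(torch.chunk(full, chunks=2, dim=-1))
assert torch.equal(_gather_sim(shards), full)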
def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    dist.all_reduce(input_, group=group)

    return input_
def _create_ep_parallel_group(self, moe_experts):
    # Call the init process
    self.ep_group = {}
    self.expert_mp_group = {}
    moe_experts = moe_experts if type(moe_experts) is list else [moe_experts]
    for e in moe_experts:
        self.ep_group.update({e: None})
        self.expert_mp_group.update({e: None})

    for moe_ep_size in self.ep_group.keys():
        num_ep_groups = dist.get_world_size() // moe_ep_size
        for i in range(num_ep_groups):
            ep_cnt = i * moe_ep_size
            size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size
            ranks = list(range(ep_cnt, ep_cnt + size))
            _ep_group = dist.new_group(ranks)
            if dist.get_rank() in ranks:
                self.ep_group.update({moe_ep_size: _ep_group})

        if dist.get_world_size() > moe_ep_size:
            num_expert_mp_groups = dist.get_world_size() // num_ep_groups
            expert_mp_size = dist.get_world_size() // moe_ep_size
            for i in range(num_expert_mp_groups):
                expert_mp_comm_ranks = [
                    i + nr * moe_ep_size for nr in range(expert_mp_size)
                ]
                _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
                if dist.get_rank() in expert_mp_comm_ranks:
                    self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
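# Illustrative sketch (not part of the original source): the rank lists that the
# loops in _create_ep_parallel_group would hand to dist.new_group, computed with
# plain Python for a hypothetical world size of 8 and an expert parallel size of 2.
def _ep_rank_lists(world_size, moe_ep_size):
    size = world_size if moe_ep_size > world_size else moe_ep_size
    num_ep_groups = world_size // moe_ep_size
    ep_groups = [list(range(i * moe_ep_size, i * moe_ep_size + size))
                 for i in range(num_ep_groups)]
    expert_mp_groups = []
    if world_size > moe_ep_size:
        num_expert_mp_groups = world_size // num_ep_groups
        expert_mp_size = world_size // moe_ep_size
        expert_mp_groups = [[i + nr * moe_ep_size for nr in range(expert_mp_size)]
                            for i in range(num_expert_mp_groups)]
    return ep_groups, expert_mp_groups

ep, emp = _ep_rank_lists(world_size=8, moe_ep_size=2)
# ep  -> [[0, 1], [2, 3], [4, 5], [6, 7]]   (contiguous expert parallel groups)
# emp -> [[0, 2, 4, 6], [1, 3, 5, 7]]       (strided expert model parallel groups)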
def _create_expert_and_data_parallel(expert_parallel_size_):
    """
        Create expert and data parallel groups.

        Note: Caller of this function is responsible for checking if the groups already exist.

        Example - E + D parallel
        world_size = 16
        expert_parallel_size = 2 # number of experts in same group
        expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
        expert_parallel_group = [0,1], [2,3], [4,5], [6,7], [8,9], [10,11], [12,13], [14,15] - no all reduce, but all to all
        data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
    """
    assert dist.is_initialized()

    log_dist(
        f'Creating expert and data parallel groups with size {expert_parallel_size_}',
        ranks=[0])
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    _ensure_divisibility(world_size, expert_parallel_size_)

    group_name = f"ep_size_{expert_parallel_size_}"

    # Build the expert data parallel groups.
    global _EXPERT_DATA_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
        for i in range(expert_parallel_size_):
            ranks = range(i, world_size, expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank % expert_parallel_size_):
                _EXPERT_DATA_PARALLEL_GROUP[group_name] = group

    # Build the expert parallel groups.
    global _EXPERT_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_PARALLEL_GROUP:
        for i in range(world_size // expert_parallel_size_):
            ranks = range(i * expert_parallel_size_,
                          (i + 1) * expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'Creating expert parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank // expert_parallel_size_):
                _EXPERT_PARALLEL_GROUP[group_name] = group
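# Illustrative sketch (not part of the original source): the rank lists from the
# docstring example above (world_size=16, expert_parallel_size=2), computed the
# same way the two loops in _create_expert_and_data_parallel build their groups.
world_size, ep_size = 16, 2
expert_data_parallel = [list(range(i, world_size, ep_size)) for i in range(ep_size)]
expert_parallel = [list(range(i * ep_size, (i + 1) * ep_size))
                   for i in range(world_size // ep_size)]
# expert_data_parallel -> [[0, 2, ..., 14], [1, 3, ..., 15]]
# expert_parallel      -> [[0, 1], [2, 3], ..., [14, 15]]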
def _clone_world_group():
    """Create a clone of the world group.

        Note: We need to clone the dist world group because the
        dist.get_global_rank() utility function is used in many places in
        DeepSpeed. As that function does not work on dist.group.WORLD, we
        need to keep a clone of it.
    """
    assert dist.is_initialized(), "dist is not initialized"
    global _WORLD_GROUP
    if _WORLD_GROUP is None:
        # If not cloned already, clone the world group
        _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
    return _WORLD_GROUP
def _test_batch_config(num_ranks, batch, micro_batch, gas, success):
    assert dist.get_world_size() == num_ranks, \
        f'The test assumes a world size of {num_ranks}'

    ds_batch_config = get_test_path('ds_batch_config.json')
    ds_config = DeepSpeedConfig(ds_batch_config)

    # test cases when all parameters are provided
    status = _run_batch_config(ds_config,
                               train_batch=batch,
                               micro_batch=micro_batch,
                               gas=gas)
    _batch_assert(status, ds_config, batch, micro_batch, gas, success)

    # test cases when two out of three parameters are provided
    status = _run_batch_config(ds_config,
                               train_batch=batch,
                               micro_batch=micro_batch)
    _batch_assert(status, ds_config, batch, micro_batch, gas, success)

    if success:
        # when gas is provided with one more parameter
        status = _run_batch_config(ds_config, train_batch=batch, gas=gas)
        _batch_assert(status, ds_config, batch, micro_batch, gas, success)

        status = _run_batch_config(ds_config, micro_batch=micro_batch, gas=gas)
        _batch_assert(status, ds_config, batch, micro_batch, gas, success)

        # test the case when only micro_batch or train_batch is provided
        if gas == 1:
            status = _run_batch_config(ds_config, micro_batch=micro_batch)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)

            status = _run_batch_config(ds_config, train_batch=batch)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)
    else:
        # when only gas is provided
        status = _run_batch_config(ds_config, gas=gas)
        _batch_assert(status, ds_config, batch, micro_batch, gas, success)

        # when gas is provided with something else and gas does not divide batch
        if gas != 1:
            status = _run_batch_config(ds_config, train_batch=batch, gas=gas)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)
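# Illustrative sketch (not part of the original source): the consistency rule the
# batch-config tests above exercise. DeepSpeed requires
#     train_batch_size == micro_batch_per_gpu * gradient_accumulation_steps * world_size,
# and a single missing value can be inferred when the other two (plus world_size)
# are given. _infer_batch is a hypothetical helper, not part of DeepSpeed.
def _infer_batch(world_size, train_batch=None, micro_batch=None, gas=None):
    if train_batch is None:
        train_batch = micro_batch * gas * world_size
    elif micro_batch is None:
        micro_batch = train_batch // (gas * world_size)
    elif gas is None:
        gas = train_batch // (micro_batch * world_size)
    assert train_batch == micro_batch * gas * world_size
    return train_batch, micro_batch, gas

assert _infer_batch(world_size=2, micro_batch=4, gas=2) == (16, 4, 2)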
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [
        chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list
    ]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat([
        server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())
    ])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
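# Illustrative sketch (not part of the original source): the per-worker 1-bit
# compression step that torch_sim simulates, on a single CPU tensor and without
# any collectives. The scale is chosen so the compressed tensor preserves the
# L2 norm of the original.
import math
import torch

a = torch.randn(16)
a_sign = a.sign()
a_sign[a_sign == 0] = 1.0              # zeros map to +1, as sign().add_(1).bool() does above
scale = a.norm() / math.sqrt(a.numel())
a_compressed = scale * a_sign          # what gets all-reduced in torch_sim
worker_error = a - a_compressed        # local error feedback kept for the next step
assert torch.isclose(a_compressed.norm(), a.norm())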
def _create_model_parallel(model_parallel_size_):
    """
    Initialize model data parallel groups.

    Arguments:
        model_parallel_size: number of GPUs used to parallelize model.

    Returns:
        Tuple of data parallel group and model parallel group

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    log_dist(f'Creating model parallel group with size {model_parallel_size_}',
             ranks=[0])
    # Get world size and rank. Ensure some consistencies.
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    model_parallel_size = min(model_parallel_size_, world_size)
    _ensure_divisibility(world_size, model_parallel_size)
    rank = dist.get_rank()

    _DATA_PARALLEL_GROUP = None
    _MODEL_PARALLEL_GROUP = None
    # Build the data parallel groups.
    for i in range(model_parallel_size):
        ranks = range(i, world_size, model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank % model_parallel_size):
            _DATA_PARALLEL_GROUP = group

    # Build the model parallel groups.
    for i in range(world_size // model_parallel_size):
        ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank // model_parallel_size):
            _MODEL_PARALLEL_GROUP = group

    return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
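# Illustrative sketch (not part of the original source): the rank lists from the
# 8-GPU / model_parallel_size=2 example in the docstring above, computed with the
# same arithmetic the two loops in _create_model_parallel use.
world_size, mp_size = 8, 2
data_parallel = [list(range(i, world_size, mp_size)) for i in range(mp_size)]
model_parallel = [list(range(i * mp_size, (i + 1) * mp_size))
                  for i in range(world_size // mp_size)]
# data_parallel  -> [[0, 2, 4, 6], [1, 3, 5, 7]]
# model_parallel -> [[0, 1], [2, 3], [4, 5], [6, 7]]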
def forward(ctx,
            input,
            residual,
            residual_norm,
            bias,
            inter_w,
            inter_b,
            attn_nw,
            attn_nb,
            config,
            mp_group,
            output_b,
            output_w,
            q_scales,
            q_groups,
            merge_count,
            mlp_gemm_func,
            fused_gemm_gelu,
            vector_matmul_func,
            bias_residual_func,
            activation_func_type=ActivationFuncType.GELU):
    if config.q_int8:
        (intermediate,
         residual_add) = inference_cuda_module.mlp_gemm_int8(
             input, residual, bias, inter_w, inter_b, attn_nw, attn_nb,
             config.epsilon, q_scales[2], (q_groups * (2**merge_count)),
             config.pre_layer_norm)
        output = inference_cuda_module.vector_matmul_int8(
            intermediate, output_w, q_scales[3], q_groups, (merge_count))
    else:
        if attn_nw is None:
            output = fused_gemm_gelu(residual_norm, inter_w, inter_b, output_w,
                                     config.epsilon, config.pre_layer_norm,
                                     False)
        else:
            intermediate, residual_add = mlp_gemm_func(
                input, residual, bias, inter_w, inter_b, attn_nw, attn_nb,
                config.epsilon, config.pre_layer_norm, config.mlp_after_attn,
                config.mlp_act_func_type)
            output = vector_matmul_func(intermediate, output_w, False)

    inference_cuda_module.residual_add(
        output, residual if config.pre_layer_norm else residual_add, input,
        output_b, bias if bias is not None else output_b, config.mp_size,
        config.mlp_after_attn, bias is not None, config.pre_layer_norm)

    if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
        dist.all_reduce(output, group=mp_group)

    return output
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
                data_parallel_size,
                parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size,
                      (i + 1) * parameter_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
def test_partitioned_tensor_meta():
    world = dist.get_world_size()
    rank = dist.get_rank()

    group = dist.new_group(ranks=list(range(world)))

    rows = world * 7
    cols = 3

    full = torch.rand(rows, cols).cuda()
    dist.broadcast(full, src=0, group=group)
    part = PartitionedTensor(full, group=group)

    my_meta = PartitionedTensor.from_meta(part.to_meta(), part.local_data, group)
    assert torch.equal(full, my_meta.full())
def _get_norm_with_moe_layers(self, all_groups_norm):
    #all_groups_norm_old = all_groups_norm
    # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
    if self.using_pipeline:
        pg = self.deepspeed.mpu.get_data_parallel_group()
    else:
        pg = groups._get_data_parallel_group()
    scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = torch.tensor(scaled_norm,
                                      device=self.fp32_groups_flat[0].device,
                                      dtype=torch.float)
    dist.all_reduce(scaled_norm_tensor, group=pg)
    all_groups_norm = scaled_norm_tensor.item()
    #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
    return all_groups_norm
def test_partitioned_tensor():
    world = dist.get_world_size()
    rank = dist.get_rank()

    group = dist.new_group(ranks=list(range(world)))

    rows = world * 4
    cols = 3

    full = torch.rand(rows, cols).cuda()
    dist.broadcast(full, src=0, group=group)
    part = PartitionedTensor(full, group=group)

    assert len(part.local_size()) == 1
    assert part.local_size()[0] * world == full.numel()

    reconstructed = part.full()
    assert torch.equal(full, reconstructed)
def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales,
            q_groups, merge_count, mp_group, async_op):
    if config.q_int8:
        intermediate = inference_cuda_module.fused_gemm_gelu_int8(
            input, inter_w, inter_b, config.epsilon, q_scales[2],
            (q_groups * (2**merge_count)), config.pre_layer_norm)
        output = inference_cuda_module.vector_matmul_int8(
            intermediate, output_w, q_scales[3], q_groups, (merge_count))
    else:
        mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
            inference_cuda_module.fused_gemm_gelu_fp32

        output = mlp_gemm_func(input, inter_w, inter_b, output_w,
                               config.epsilon, config.pre_layer_norm, async_op)
    if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
        dist.all_reduce(output, group=mp_group, async_op=async_op)

    return output + output_b
def test(self):
    param1 = torch.nn.Parameter(torch.Tensor([0]))
    param1.grad = torch.Tensor([1])
    param2 = torch.nn.Parameter(torch.Tensor([0]))
    param2.grad = torch.Tensor([dist.get_rank() + 1])
    # param2 is now MoE parameter
    param2.allreduce = False

    parameters = [param1, param2]

    groups._create_expert_and_data_parallel(2)
    norm = ds_utils.clip_grad_norm_(parameters, max_norm=0.1)
    norm = torch.Tensor([norm]).to(dist.get_rank())
    world_size = dist.get_world_size()
    gathered_norm = [torch.zeros(1).cuda() for i in range(world_size)]

    dist.all_gather(gathered_norm, norm)

    assert gathered_norm[0] == gathered_norm[1], \
        "norm at rank 0 does not match the norm at rank 1"
def from_meta(cls, meta, local_part, group, device='cuda'):
    assert meta.dtype == torch.long
    dummy = torch.ones(dist.get_world_size(group=group))
    part_obj = cls(tensor=dummy, group=group)

    meta = meta.tolist()

    # [N, list0, ..., listN-1]
    part_obj.orig_size = meta[1:(1 + meta[0])]
    meta = meta[1 + meta[0]:]

    part_obj.orig_device = device
    part_obj.local_data = local_part.detach()

    part_obj.group = group

    # Partition is encoded like the rowptr of a CSR matrix:
    # [num_parts, rank, 0, part_1, ..., part_num_parts]
    # TODO: support shuffle between different partition granularities
    assert part_obj.num_parts == meta[0]
    assert part_obj.rank == meta[1]
    part_obj.partition = meta[2:]  # length num_parts+1

    return part_obj
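# Illustrative sketch (not part of the original source): the flat meta layout that
# from_meta decodes, built by hand for a hypothetical 2-way partition of a 6x3
# tensor. The leading block is [ndims, *shape]; the trailing block is
# [num_parts, rank, rowptr...], where the rowptr has num_parts + 1 entries,
# exactly like a CSR row pointer.
import torch

orig_shape = (6, 3)                   # 18 elements, split 9 + 9 across 2 ranks
meta = torch.tensor(
    [len(orig_shape), *orig_shape,    # [2, 6, 3]
     2, 0,                            # num_parts, rank that produced this meta
     0, 9, 18],                       # rowptr: part i owns elements [rowptr[i], rowptr[i+1])
    dtype=torch.long)

ndims = meta[0].item()
assert meta[1:1 + ndims].tolist() == [6, 3]
assert meta[1 + ndims:].tolist() == [2, 0, 0, 9, 18]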
def __init__(self, topology=None, process_group=None):
    # TODO use process_group if provided
    self.global_rank = dist.get_rank()
    self.world_size = dist.get_world_size()

    if topology is not None:
        if self.global_rank == 0:
            print('Using topology:', topology)
        self._topo = topology
    else:
        num_pp = 1
        num_dp = 1
        for idx, prime in enumerate(_prime_factors(self.world_size)):
            if idx % 2 == 0:
                num_pp *= prime
            else:
                num_dp *= prime
        self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp)

    self.data_parallel_size = max(self._topo.get_dim('data'), 1)
    self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1)
    self.model_parallel_size = max(self._topo.get_dim('model'), 1)
    self.slice_parallel_size = self.model_parallel_size
    assert self._is_grid_valid(), "Invalid Grid"

    self.stage_id = self.get_stage_id()
    self.data_parallel_id = self.get_data_parallel_id()

    # Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these
    # to detect overflow, etc.
    self.ds_model_proc_group = None
    self.ds_model_rank = -1
    for dp in range(self.data_parallel_size):
        ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))
        if self.global_rank == 0:
            #print(f'RANK={self.global_rank} building DeepSpeed model group: {ranks}')
            pass
        proc_group = dist.new_group(ranks=ranks)
        if self.global_rank in ranks:
            self.ds_model_proc_group = proc_group
            self.ds_model_world_size = len(ranks)
            self.ds_model_rank = ranks.index(self.global_rank)
    assert self.ds_model_rank > -1
    assert self.ds_model_proc_group is not None

    # Create new ProcessGroup for gradient all-reduces - these are the data parallel groups
    self.dp_group = []
    self.dp_groups = self._topo.get_axis_comm_lists('data')
    for g in self.dp_groups:
        proc_group = dist.new_group(ranks=g)
        if self.global_rank in g:
            self.dp_group = g
            self.dp_proc_group = proc_group

    self.is_first_stage = (self.stage_id == 0)
    self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))

    self.p2p_groups = self._build_p2p_groups()

    # Create new ProcessGroup for pipeline collectives - these are pipe parallel groups
    self.pp_group = []
    self.pp_proc_group = None
    self.pipe_groups = self._topo.get_axis_comm_lists('pipe')
    for ranks in self.pipe_groups:
        if self.global_rank == 0:
            #print(f'RANK={self.global_rank} building pipeline group: {ranks}')
            pass
        proc_group = dist.new_group(ranks=ranks)
        if self.global_rank in ranks:
            self.pp_group = ranks
            self.pp_proc_group = proc_group
    assert self.pp_proc_group is not None

    # Create new ProcessGroup for model (tensor-slicing) collectives

    # Short circuit case without model parallelism.
    # TODO: it would be nice if topology had bcast semantics to avoid this branching
    # case?
    if self.model_parallel_size == 1:
        for group_rank in range(self.world_size):
            group_rank = [group_rank]
            group = dist.new_group(ranks=group_rank)
            if group_rank[0] == self.global_rank:
                self.slice_group = group_rank
                self.slice_proc_group = group
        return
    else:
        self.mp_group = []
        self.model_groups = self._topo.get_axis_comm_lists('model')
        for g in self.model_groups:
            proc_group = dist.new_group(ranks=g)
            if self.global_rank in g:
                self.slice_group = g
                self.slice_proc_group = proc_group
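# Illustrative sketch (not part of the original source): how the fallback branch
# above splits world_size into a pipe dimension and a data dimension by handing
# out prime factors alternately. The local _prime_factors here is a stand-in for
# the helper used by the real code and is assumed to return factors smallest-first.
def _prime_factors(n):
    factors, candidate = [], 2
    while n > 1:
        while n % candidate == 0:
            factors.append(candidate)
            n //= candidate
        candidate += 1
    return factors

def _default_pp_dp(world_size):
    num_pp = num_dp = 1
    for idx, prime in enumerate(_prime_factors(world_size)):
        if idx % 2 == 0:
            num_pp *= prime
        else:
            num_dp *= prime
    return num_pp, num_dp

assert _default_pp_dp(8) == (4, 2)    # factors [2, 2, 2] -> pp gets 2*2, dp gets 2
assert _default_pp_dp(12) == (6, 2)   # factors [2, 2, 3] -> pp gets 2*3, dp gets 2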
def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
    """Clips gradient norm of an iterable of parameters.

    This has been adapted from Nvidia Megatron. We add norm averaging
    to consider MoE params when calculating the norm, as they will result
    in different norms across different ranks.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with added
    functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.MAX,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0
        for p in parameters:
            if mpu is not None:
                if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p):
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm.item()**norm_type
            else:
                param_norm = p.grad.data.float().norm(norm_type)
                total_norm += param_norm.item()**norm_type

        # Sum across all model parallel GPUs.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.SUM,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    # Need to average total_norm across different GPUs due to the presence of moe params
    pg = groups._get_data_parallel_group()
    scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))
    scaled_norm_tensor = torch.cuda.FloatTensor([float(scaled_norm)])
    dist.all_reduce(scaled_norm_tensor, group=pg)
    total_norm = scaled_norm_tensor.item()

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm
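# Illustrative sketch (not part of the original source): the MoE norm-averaging
# step at the end of clip_grad_norm_, simulated for two hypothetical ranks whose
# expert (non-replicated) gradients give them different local total norms. Each
# rank divides its local norm by the group size and the all-reduce sums the
# scaled values, so the final total_norm is the mean over the data parallel group.
per_rank_total_norm = [3.0, 5.0]             # rank 0 and rank 1 local norms
world = len(per_rank_total_norm)
scaled = [n / world for n in per_rank_total_norm]
total_norm = sum(scaled)                     # what dist.all_reduce would produce on every rank
assert total_norm == 4.0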
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce scatter calls

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = dist.get_rank(group)
    world_sz = dist.get_world_size(group)

    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
            for rank in range(0, world_sz)
        ]

    padded_partition_sz_for_each_tensor = tuple(
        math.ceil(t.numel() / world_sz) for t in tensors)

    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced partitions of each tensor
        # end up at each rank
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[
                    tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz,
                                    dtype=tensor_chunk.dtype,
                                    device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(
            torch.cat)(tensor_partitions_lst_with_padding)

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
        tensor_partition_flat_buffer, world_sz)

    # batched reduce-scatter call
    _torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                             tensor_partition_buffer_for_each_rank[this_rank],
                             group=group)

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[
            this_rank].narrow(
                0,
                offset,
                partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
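# Illustrative sketch (not part of the original source): the interleaved, padded
# flat buffer that reduce_scatter_coalesced builds before its single batched
# reduce-scatter, shown on CPU for world_sz=2 and two small tensors (no
# collectives, no pre-division).
import math
import torch

world_sz = 2
tensors = [torch.arange(5.), torch.arange(4.)]

padded_sz = [math.ceil(t.numel() / world_sz) for t in tensors]   # [3, 2]
pieces = []
for rank in range(world_sz):
    for t, pad_sz in zip(tensors, padded_sz):
        chunk = t.view(-1)[rank * pad_sz:rank * pad_sz + pad_sz]
        pieces.append(chunk)
        if pad_sz > chunk.numel():                               # pad short trailing chunks
            pieces.append(torch.zeros(pad_sz - chunk.numel()))

flat = torch.cat(pieces)                     # per-rank stride is sum(padded_sz) == 5
per_rank = torch.chunk(flat, world_sz)       # rank r receives per_rank[r] after reduce-scatter
assert flat.numel() == world_sz * sum(padded_sz)
assert per_rank[0].tolist() == [0., 1., 2., 0., 1.]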
def replace_transformer_layer(orig_layer_impl, model, policy=None, micro_batch_size=-1, config=None, seed=-1, hidden_size=-1, num_attention_heads=-1, mp_size=1, training_mp_size=1, mp_group=None, ep_group=None, expert_mp_group=None, fp16=True, local_rank=-1, stochastic_mode=True, training=True, quantize=False, quantize_settings=None, triangular_masking=False, return_tuple=True, replace_with_kernel_inject=False, linear_layer_setting=None, moe=False, moe_experts=1, moe_type='standard', checkpoint_dict=None, save_mp_checkpoint_path=None): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, e.g., transformers.modeling_bert.BertLayer. model (torch.nn.Module): user's nn.module representing their model policy: shows the policy for mapping from the orig_layer_impl to transformer parameters when replace_with_kernel_inject is set, otherwise, it provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection) micro_batch_size (int): micro batch size per gpu used during training/eval config (dict): model config containing hidden size, attention heads, etc. seed (int): random seed value max_seq_length (int): max sequence length for training hidden_size (int): hidden dimension num_attention_heads (int): number of attention heads mp_size (int): model_parallelism degree mp_group : model_parallel group initialized on the modeling side preln (bool): does the original layer implementation do pre or post layer norm? fp16 (bool): fp16 or fp32 local_rank (int): GPU rank (optional), stochastic_mode (bool): whether to use stochastic mode training (bool): specifying whether kernel-injection is done for training/inference (set to false for inference-mode injection) quantize_settings (tuple): this setting shows how we can quantize a model for running it through the inference kernels. It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups). return_tuple (bool): if set, transformer layer returns a tuple as the output. Note: this flag needs to be set for huggingface models. replace_with_kernel_inject (bool): injection_mode, if true, kernels will be add along with configuring Tensor-Parallelism linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers and embedding layers attention_params: (list of strings) [Optional]: shows the parameters in the attention part that needs to be adjusted based on the model-parallelism Returns: Updated nn.module with replaced transformer layers """ mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, mp_size=mp_size) #, out_dim=0, in_dim=1) def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0): policy = policy_cls(child, inference=inference) if inference: hidden_size, num_attention_heads = policy.get_hidden_heads() assert num_attention_heads % mp_size == 0,\ "To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size!" +\ "This is because the attention computation is partitioned evenly among the parallel GPUs." 
from deepspeed.moe.layer import MoE moe = False if hasattr(child, 'mlp') and isinstance(child.mlp, MoE): num_experts = child.mlp.num_experts moe = True attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention() if not moe or moe_type == 'standard': mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp() else: mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \ _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type) attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm() if quantize: if policy_cls is not HFBertLayerPolicy: qkvw = qkvw.to(torch.int8) dense_w = dense_w.to(torch.int8) _h4h_w = [moe_w1.to(torch.int8) for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8) _4hh_w = [moe_w1.to(torch.int8) for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8) elif fp16: qkvw = qkvw.half() dense_w = dense_w.half() _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half() _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half() if quantize or fp16: qkvb = qkvb if qkvb is None else qkvb.half() dense_b = dense_b if dense_b is None else dense_b.half() _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half() _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half() attn_nw = attn_nw if attn_nw is None else attn_nw.half() attn_nb = attn_nb if attn_nb is None else attn_nb.half() input_nw = input_nw.half() input_nb = input_nb.half() if moe and moe_type == 'residual' and fp16: _res_h4h_b = _res_h4h_b.half() _res_4hh_b = _res_4hh_b.half() _res_h4h_w = _res_h4h_w.half() _res_4hh_w = _res_4hh_w.half() _res_coef = _res_coef.half() #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group) if inference: if moe: ep_world_size = dist.get_world_size() local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig( hidden_size=hidden_size, heads=num_attention_heads, layer_norm_eps=config.layer_norm_eps if hasattr( config, 'layer_norm_eps') else 1e-12, fp16=fp16, pre_layer_norm=policy.pre_attn_norm, mp_size=mp_size, q_int8=quantize, moe_experts=local_ep_size, global_experts=num_experts, mlp_type=moe_type) else: rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') else child.attention.rotary_ndims \ if hasattr(child, 'attention') and hasattr(child.attention,'rotary_ndims') else -1 bigscience_bloom = policy_cls is BLOOMLayerPolicy transformer_config = transformer_inference.DeepSpeedInferenceConfig( hidden_size=hidden_size, heads=num_attention_heads, layer_norm_eps=config.layer_norm_eps if hasattr( config, 'layer_norm_eps') else (config.layer_norm_epsilon if hasattr(config, 'layer_norm_epsilon') else config.layernorm_epsilon if hasattr(config, 'layernorm_epsilon') else 1.0e-12), fp16=fp16, pre_layer_norm=policy.pre_attn_norm, mp_size=mp_size, q_int8=quantize, return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)), triangular_masking=(policy_cls is not HFBertLayerPolicy), local_attention=((config.attention_layers[layer_id] == "local") if hasattr(config, 'attention_layers') else False), window_size=(config.window_size if hasattr(config, 'window_size') else 1), rotary_dim=rotary_dim, mlp_after_attn=(rotary_dim is None or rotary_dim < 0), mlp_act_func_type=policy.mlp_act_func_type, training_mp_size=training_mp_size, bigscience_bloom=bigscience_bloom) if quantize and quantize_settings is not None: (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups) = 
quantize_settings if moe: new_module = transformer_inference.DeepSpeedMoEInference( transformer_config, mp_group=mp_group, ep_group=None if ep_group is None else ep_group[num_experts], expert_mp_group=None if expert_mp_group is None else expert_mp_group[num_experts], quantize_scales=quantization_scales[layer_id], quantize_groups=quantize_groups, merge_count=merge_count, mlp_extra_grouping=mlp_extra_grouping, qkv_merging=(policy_cls is HFBertLayerPolicy)) else: new_module = transformer_inference.DeepSpeedTransformerInference( transformer_config, mp_group=mp_group, quantize_scales=quantization_scales[layer_id], quantize_groups=quantize_groups, merge_count=merge_count, mlp_extra_grouping=mlp_extra_grouping, qkv_merging=(policy_cls is HFBertLayerPolicy)) if quantize and qkvw.dtype != torch.int8: quantize_bits = 8 quantizer = WeightQuantization() if policy_cls is HFBertLayerPolicy: data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3) else: data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups) qkvw.data.copy_(data_quantized) qkvw.data = qkvw.data.to(torch.int8) else: if moe: new_module = transformer_inference.DeepSpeedMoEInference( transformer_config, mp_group=mp_group, ep_group=None if ep_group is None else ep_group[num_experts], expert_mp_group=None if expert_mp_group is None else expert_mp_group[num_experts], ) else: new_module = transformer_inference.DeepSpeedTransformerInference( transformer_config, mp_group=mp_group, ) new_module.config.scale_attention = scale_attention # we want the weights in [input, output] shape # linear layer is created with [input, output] shape # transpose it here to reduce inference cost! def transpose(data): # temp move to cpu to avoid requiring extra GPU memory during the reshape data = data.to('cpu') data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) data = data.reshape(data.shape[-1], data.shape[-2]) data.to(torch.cuda.current_device()) return data attn_block = new_module.attention mpl_block = new_module.mlp if attn_linear_layer: if qkvw.numel() == 0 or qkvw.is_meta: if qkvw.is_meta or qkvw.ds_tensor.numel( ) < attn_block.attn_qkvw.numel(): pass else: with GatheredParameters([qkvw, dense_w, qkvb, dense_b], modifier_rank=0): qkvw = transpose(qkvw.data) dense_w = transpose(dense_w.data) qkvb = qkvb.data dense_b = dense_b.data else: qkvw.data = transpose(qkvw.data) dense_w.data = transpose(dense_w.data) def _transpose(x): num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size attention_head_size = x.shape[-1] // num_attention_heads_per_partition new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, attention_head_size) x_1 = x.view(*new_x_shape) (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1)) if len(q.shape) > 2: return torch.cat((q.reshape(q.shape[0], -1), k.reshape(q.shape[0], -1), v.reshape(q.shape[0], -1)), dim=-1).reshape(x.shape) else: return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape) if megatron_v2: new_module.config.rotate_half = True new_module.config.rotate_every_two = False # Note: this part needs to be added for BLOOM architecture qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous()) qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous()) # NOTE: This part caused instability in the multi-GPU inference! # TODO: This needs to be incorporated in the kernels. 
#dense_b = dense_b if dense_b is None else dense_b * ( # transformer_config.training_mp_size / transformer_config.mp_size) #_4hh_b = _4hh_b * (transformer_config.training_mp_size / # transformer_config.mp_size) if mlp_linear_layer: if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta): if _4hh_w.is_meta or _4hh_w.ds_tensor.numel( ) < mpl_block.inter_w.numel(): pass else: with GatheredParameters([_h4h_w, _4hh_w, _4hh_b, _h4h_b], modifier_rank=0): _h4h_w = transpose(_h4h_w.data) _4hh_w = transpose(_4hh_w.data) _h4h_b = _h4h_b.data _4hh_b = _4hh_b.data else: _h4h_w = [transpose(moe_w1.data) for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data) _4hh_w = [transpose(moe_w1.data) for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data) if moe and moe_type == 'residual': _res_h4h_w.data = transpose(_res_h4h_w.data) _res_4hh_w.data = transpose(_res_4hh_w.data) _res_coef.data = transpose(_res_coef.data) if qkvw.is_meta or qkvw.numel() == 0 or qkvw.is_meta: if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel(): pass else: with GatheredParameters([ attn_block.attn_qkvw, attn_block.attn_qkvb, attn_block.attn_ow, attn_block.attn_ob ], modifier_rank=0): attn_block.attn_qkvw = mp_replace.copy( attn_block.attn_qkvw, qkvw) attn_block.attn_qkvb = mp_replace.copy( attn_block.attn_qkvb, qkvb) attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) else: if bigscience_bloom: attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw) attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb) else: attn_block.attn_qkvw = mp_replace.qkv_copy( attn_block.attn_qkvw, qkvw) attn_block.attn_qkvb = mp_replace.qkv_copy( attn_block.attn_qkvb, qkvb) attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) if moe: gpu_index = dist.get_rank() gpu_index = 0 for ep_index in range(local_ep_size): mpl_block[ep_index].inter_w.data = _h4h_w[ gpu_index * local_ep_size + ep_index].to( torch.cuda.current_device()) mpl_block[ep_index].inter_b.data = _h4h_b[ gpu_index * local_ep_size + ep_index].to( torch.cuda.current_device()) mpl_block[ep_index].output_w.data = _4hh_w[ gpu_index * local_ep_size + ep_index].to( torch.cuda.current_device()) mpl_block[ep_index].output_b.data = _4hh_b[ gpu_index * local_ep_size + ep_index].to( torch.cuda.current_device()) new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device()) new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device()) if moe_type == 'residual': new_module.res_mlp.inter_w.data = _res_h4h_w.to( torch.cuda.current_device()) new_module.res_mlp.inter_b.data = _res_h4h_b.to( torch.cuda.current_device()) new_module.res_mlp.output_w.data = _res_4hh_w.to( torch.cuda.current_device()) new_module.res_mlp.output_b.data = _res_4hh_b.to( torch.cuda.current_device()) new_module.res_coef.data = _res_coef.to(torch.cuda.current_device()) else: if _4hh_w.numel() == 0 or _4hh_w.is_meta: if _4hh_w.is_meta or _4hh_w.ds_tensor.numel( ) < mpl_block.inter_w.numel(): pass else: with GatheredParameters([_h4h_w, _4hh_w, _4hh_w, _4hh_b], modifier_rank=0): mpl_block.inter_w = mp_replace.copy( mpl_block.inter_w, _h4h_w) mpl_block.inter_b = mp_replace.copy( mpl_block.inter_b, _h4h_b) mpl_block.output_w = mp_replace.copy( mpl_block.output_w, _4hh_w) mpl_block.output_b = mp_replace.copy( mpl_block.output_b, _4hh_b) else: mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w) mpl_block.inter_b = 
mp_replace.copy(mpl_block.inter_b, _h4h_b) mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w) mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b) if attn_nw is None: new_module.mlp.attn_nw = attn_nw new_module.mlp.attn_nb = attn_nb else: if attn_nw.is_meta or attn_nw.numel() == 0: if attn_nw.is_meta or attn_nw.ds_tensor.numel( ) < new_module.mlp.attn_nw.numel(): pass else: with GatheredParameters([attn_nw, attn_nb], modifier_rank=0): new_module.mlp.attn_nw.data.copy_( attn_nw.to(torch.cuda.current_device())) new_module.mlp.attn_nb.data.copy_( attn_nb.to(torch.cuda.current_device())) else: new_module.mlp.attn_nw.data.copy_( attn_nw.to(torch.cuda.current_device())) new_module.mlp.attn_nb.data.copy_( attn_nb.to(torch.cuda.current_device())) if input_nw.is_meta or input_nw.numel() == 0: if input_nw.is_meta or input_nw.ds_tensor.numel( ) < new_module.norm_w.numel(): pass else: with GatheredParameters([input_nw, input_nb], modifier_rank=0): new_module.norm_w.data.copy_( input_nw.to(torch.cuda.current_device())) new_module.norm_b.data.copy_( input_nb.to(torch.cuda.current_device())) else: new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device())) new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device())) else: transformer_config = deepspeed.DeepSpeedTransformerConfig( batch_size=micro_batch_size if micro_batch_size > 0 else 1, hidden_size=config.hidden_size, heads=config.num_attention_heads, attn_dropout_ratio=config.attention_probs_dropout_prob, hidden_dropout_ratio=config.hidden_dropout_prob, num_hidden_layers=config.num_hidden_layers, initializer_range=config.initializer_range, layer_norm_eps=config.layer_norm_eps if hasattr( config, 'layer_norm_eps') else 1e-12, seed=seed, fp16=fp16, pre_layer_norm=policy.pre_attn_norm, return_tuple=return_tuple, local_rank=local_rank, stochastic_mode=stochastic_mode, normalize_invertible=True, training=training) new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config) new_module.attn_qkvw.data = qkvw new_module.attn_qkvb.data = qkvb new_module.attn_ow.data = dense_w new_module.attn_ob.data = dense_b new_module.attn_nw.data = attn_nw new_module.attn_nb.data = attn_nb new_module.norm_w.data = input_nw new_module.norm_b.data = input_nb new_module.inter_w.data = _h4h_w new_module.inter_b.data = _h4h_b new_module.output_w.data = _4hh_w new_module.output_b.data = _4hh_b return new_module def replace_wo_policy(module, all_reduce_linears): def _replace(child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) z_inference = (len(list(child.parameters())) > 0) and (list( child.parameters())[0].numel() == 0) if z_inference: weight_shape = child.weight.ds_shape else: weight_shape = child.weight.shape if name in all_reduce_linears: new_weight = torch.empty(( weight_shape[1] if conv_linear_layer else weight_shape[0], (weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size, ), device=child.weight.device, dtype=child.weight.dtype) if z_inference: with deepspeed.zero.GatheredParameters(child.weight, modifier_rank=0): data = child.weight.data.to(new_weight.device) if conv_linear_layer: data = data.transpose(-1, -2).contiguous() data = mp_replace.copy(new_weight, data) child.weight.ds_tensor = torch.empty(1) else: if conv_linear_layer: child.weight.data = child.weight.data.transpose(-1, -2).contiguous() data = mp_replace.copy(new_weight, child.weight.data) new_bias = torch.empty((weight_shape[0]), device=child.weight.device, dtype=child.weight.dtype) if z_inference: 
with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0): new_bias.data.copy_(child.bias.data) elif child.bias is not None: new_bias.data.copy_(child.bias.data) return LinearAllreduce(data, child.bias if child.bias is None else \ torch.nn.parameter.Parameter(new_bias.to(torch.cuda.current_device())), mp_group) else: new_weight = torch.empty(( (weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size, weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1], ), device=child.weight.device, dtype=child.weight.dtype) if z_inference: with deepspeed.zero.GatheredParameters(child.weight, modifier_rank=0): data = child.weight.data.to(new_weight.device) if conv_linear_layer: data = data.transpose(-1, -2).contiguous() data = mp_replace.copy(new_weight, data) child.weight.ds_tensor = torch.empty(1) else: if conv_linear_layer: child.weight.data = child.weight.data.transpose(-1, -2).contiguous() data = mp_replace.copy(new_weight, child.weight.data) new_bias = torch.empty((weight_shape[0] // mp_size), device=child.weight.device, dtype=child.weight.dtype) if z_inference: with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0): bias_data = None if child.bias is None else mp_replace.copy( new_bias, child.bias.data).to(torch.cuda.current_device()) else: bias_data = None if child.bias is None else mp_replace.copy( new_bias, child.bias.data).to(torch.cuda.current_device()) return LinearLayer(weight=data.to(torch.cuda.current_device()), bias=bias_data) def _slice_embedding(child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size), device=child.weight.device, dtype=child.weight.dtype) data = mp_replace.copy(new_weight, child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \ child.weight.data) new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size) new_embedding.weight.data.copy_(data) return new_embedding def update_mp_params(child): if hasattr(child, 'n_heads'): child.n_heads = child.n_heads // mp_size if hasattr(child, 'inner_dim'): child.inner_dim = child.inner_dim // mp_size if hasattr(child, 'num_heads'): child.num_heads = child.num_heads // mp_size if hasattr(child, 'num_attention_heads'): child.num_attention_heads = child.num_attention_heads // mp_size if hasattr(child, 'all_head_size'): child.all_head_size = child.all_head_size // mp_size if hasattr(child, 'embed_dim'): child.embed_dim = child.embed_dim // mp_size if hasattr(child, 'hidden_size'): child.hidden_size = child.hidden_size // mp_size conv_linear_layer = False if linear_layer_setting is not None: linear_policies = {linear_layer_setting[0]: _replace} if len(linear_layer_setting) == 2: linear_policies.update({linear_layer_setting[1]: _slice_embedding}) else: if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class: try: import transformers conv_linear_layer = True linear_policies = {transformers.model_utils.Conv1D: _replace} except ImportError: linear_policies = {nn.Linear: _replace} else: linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding} def _replace_module(r_module, prev_name=''): for name, child in r_module.named_children(): if child.__class__ in linear_policies: setattr( r_module, name, linear_policies[child.__class__](child, prev_name + '.' 
                                                         + name,
                                                         conv_linear_layer))
                else:
                    update_mp_params(child)
                    _replace_module(child, name)
            return r_module

        return _replace_module(module)

    def replace_fn(child, _policy, layer_id=0):
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, triangular_masking)
        else:
            # copy relevant state from child -> new module
            if replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy)
        return new_module

    replaced_module = replace_module(model=model,
                                     orig_class=orig_layer_impl,
                                     replace_fn=replace_fn,
                                     _replace_policy=policy)

    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None:
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
        base_dir = checkpoint_dict.get('base_dir', '')
        if ckpt_type == 'pp':
            pbar = tqdm.tqdm(total=len(checkpoint),
                             desc=f"Loading {len(checkpoint)} checkpoint shards")
            for i in range(len(checkpoint)):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)
                sd = torch.load(checkpoint[i], map_location='cpu')
                load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type)
        else:
            num_checkpoints = len(checkpoint) // ckpt_mp_size
            assert world_size >= ckpt_mp_size, \
                "Merging checkpoints is currently not supported when world_size is smaller than the number of checkpoint shards!"
            checkpoint_stride = world_size // ckpt_mp_size
            if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                pbar = tqdm.tqdm(total=num_checkpoints,
                                 desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)
                ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride)
                ckpt_file = os.path.join(base_dir, checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index]
                sd = torch.load(ckpt_file, map_location='cpu')
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           rank % (world_size // ckpt_mp_size))
        print(f"checkpoint loading time at rank {rank}: {time.time() - start_time} sec")

    if save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
        ckpt_files = [non_tp_ckpt_name] * world_size
        os.makedirs(save_mp_checkpoint_path, exist_ok=True)
        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({
                    k: v
                    for k, v in dict(replaced_module.state_dict()).items()
                    if transformer_name not in k
                }),
                f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}')
            ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)]
            config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{save_mp_checkpoint_path}',
                'checkpoints': ckpt_files,
                'version': 1.0,
                'parallelization': 'tp',
                'mp_size': world_size
            })
            with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json", "w") as cfg:
                cfg.write(config)

        torch.save(
            OrderedDict({
                k: v
                for k, v in dict(replaced_module.state_dict()).items()
                if transformer_name in k
            }),
            f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt')

    return replaced_module
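# For orientation only: with the ckpt_files logic above, a hypothetical 4-way
# tensor-parallel "bloom" export would produce a *_ds-inference_config.json
# roughly equivalent to the dict below (paths and file names are illustrative,
# not taken from a real run).
example_inference_config = {
    'type': 'bloom',
    'base_dir': '/path/to/save_mp_checkpoint_path',
    'checkpoints': ['bloom-non-tp.pt'] * 4 + [f'bloom-tp_{r:0>2d}.pt' for r in range(4)],
    'version': 1.0,
    'parallelization': 'tp',
    'mp_size': 4
}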
def test_world_size_1(self):
    assert dist.get_world_size() == 1
def test_two(self, number, color="purple"):
    assert dist.get_world_size() == 2
    assert number == 1138
    assert color == "purple"
def test_one(self, number):
    assert dist.get_world_size() == 2
    assert number == 1138
def test(self):
    assert dist.is_initialized()
    assert dist.get_world_size() == 3
    assert dist.get_rank() < 3
import argparse
import os

import torch
import deepspeed
import deepspeed.comm as dist

from deepspeed.runtime.comm.nccl import NcclBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean

timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend='nccl')
args.local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = NcclBackend()
local_rank = args.local_rank

# Setting tensor_size (BERT-Large)
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
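# A small, self-contained check (hypothetical sizes; the helper name is illustrative
# and not part of the benchmark above) of the padding rule used for right_tensor_size:
# the length is rounded up to the next multiple of 8 * size so every rank owns an
# equal, 8-aligned chunk.
def _pad_to_8x_world_size(tensor_size, size):
    remainder = tensor_size % (8 * size)
    return tensor_size if remainder == 0 else tensor_size + (8 * size - remainder)

assert _pad_to_8x_world_size(100, 7) == 112   # 2 * (8 * 7), smallest aligned length >= 100
assert _pad_to_8x_world_size(112, 7) == 112   # already aligned, unchanged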
def _is_grid_valid(self):
    ranks = 1
    for ax in self._topo.get_axis_names():
        ranks *= self._topo.get_dim(ax)
    return ranks == dist.get_world_size()
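# Minimal sketch (hypothetical axis sizes) of the invariant _is_grid_valid checks:
# the product of the topology's axis dimensions must equal the world size, e.g. a
# pipe=4 x data=2 grid is only valid when exactly 8 ranks are launched.
example_axis_dims = {'pipe': 4, 'data': 2}
product = 1
for dim in example_axis_dims.values():
    product *= dim
assert product == 8  # dist.get_world_size() would need to return 8 for this grid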
def test(self, tmpdir):
    from deepspeed.runtime.comm.nccl import NcclBackend

    size = dist.get_world_size()
    rank = dist.get_rank()
    backend = NcclBackend()
    local_rank = dist.get_rank()
    device = torch.device("cuda", dist.get_rank())

    # A simulated compression function using deepspeed.comm
    def torch_sim(a):
        a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
        scale = a.norm() / np.sqrt(a.numel())
        a_compressed = scale * a_sign
        a_sign = None
        worker_error = a - a_compressed
        dist.all_reduce(a_compressed)
        a_compressed.mul_(1 / dist.get_world_size())
        a_server_sign = (a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
        a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
        server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
        a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
        a_server_compressed = torch.cat([
            server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())
        ])
        rank = dist.get_rank()
        server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
        torch.cuda.synchronize()
        dist.barrier()
        return a_server_compressed, worker_error, server_error

    tensor_size = 300 * 2**20
    server_size = int(tensor_size / size)
    if tensor_size % (8 * size) != 0:
        right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
    else:
        right_tensor_size = tensor_size
    right_server_size = right_tensor_size // size

    # Adding bias to the initialization of the gradient we are communicating
    # in order to get rid of the case where some elements in the gradient are too small
    a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

    worker_error = torch.zeros(right_tensor_size, device=device)
    server_error = torch.zeros(right_server_size, device=device)

    a_torch, worker_error_torch, server_error_torch = torch_sim(a)
    torch.cuda.empty_cache()

    a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank)

    threshold = 1e-6
    magnitude_threshold = 1e-6
    diff_mask = (a_after - a_torch) > threshold
    diff_server_mask = torch.chunk(diff_mask, size)[rank]
    mpi_server = torch.chunk(a_after, size)[rank] + server_error
    torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

    # If the number in the compensated_server_m is too small (e.g. 1e-8), then calling sign() might be problematic.
    # The test skips those numbers that are too small in compensated_server_m.
    check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
    if torch.sum(check_mag_mask) != 0:
        print("Fails at {} positions".format(torch.sum(check_mag_mask)))
    assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0
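# A single-process sketch of the worker-side error-feedback step that torch_sim
# models above (illustrative only, no communication; the helper name is hypothetical):
# the compensated gradient is replaced by its sign scaled to preserve the L2 norm,
# and the residual is kept locally to be folded into the next step. Unlike torch_sim,
# this sketch uses plain sign() and does not remap zero entries to +1.
import numpy as np
import torch

def one_bit_compress(grad, error):
    compensated = grad + error                               # apply error feedback
    scale = compensated.norm() / np.sqrt(compensated.numel())
    compressed = scale * compensated.sign()                   # 1 bit per element + one scale
    return compressed, compensated - compressed               # new local error

# Example: g = torch.randn(1024); e = torch.zeros(1024); c, e = one_bit_compress(g, e)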
def forward(ctx, input, input_mask, head_mask, layer_past, get_present,
            encoder_hidden_states, encoder_attention_mask, output_attentions, norm_w,
            norm_b, config, attn_qkvw, attn_qkvb, num_attention_heads_per_partition,
            norm_factor, hidden_size_per_partition, attn_ow, attn_ob, mp_group,
            q_scales, q_groups, merge_count, qkv_merging, score_context_func, alibi):
    def _transpose_for_scores(x, key=False, reshape=False):
        attention_head_size = x.shape[-1] // num_attention_heads_per_partition
        new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                       attention_head_size)
        x_1 = x.view(*new_x_shape)
        if key:
            x_1 = x_1.permute(0, 2, 3, 1)
        else:
            x_1 = x_1.permute(0, 2, 1, 3)
        if reshape:
            return x_1.reshape(x.shape)
        return x_1.contiguous()

    def _transpose_for_context(x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_layer_shape = x.size()[:-2] + (hidden_size_per_partition,)
        return x.view(*new_x_layer_shape).contiguous()

    ########### This part is taken/modified from the HF modeling_bloom.py ################
    # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py

    def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=True):
        """Split a tensor along its last dimension.

        Args:
            tensor ([`torch.tensor`], *required*): input tensor to split
            num_partitions ([`int`], *required*): number of partitions to split the tensor into
            contiguous_split_chunks ([`bool`], *optional*, default=`True`): if True, make each chunk contiguous in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        numerator, denominator = tensor.size()[last_dim], num_partitions
        if not (numerator % denominator == 0):
            raise ValueError(f"{numerator} is not divisible by {denominator}")
        last_dim_size = numerator // denominator
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)
        return tensor_list

    def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor):
        alibi = alibi.to(torch.cuda.current_device())
        head_dim = hidden_size_per_partition // num_attention_heads_per_partition
        new_tensor_shape = mixed_x_layer.size()[:-1] + (num_attention_heads_per_partition,
                                                        3 * head_dim)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)

        presents = (key_layer, value_layer)
        # [batch_size, num_heads, q_length, k_length]
        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1),
                       key_layer.size(1))
        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
        query_layer = query_layer.transpose(1, 0).reshape(
            output_size[2], output_size[0] * output_size[1], -1)
        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
        key_layer = key_layer.transpose(1, 0).reshape(
            output_size[3], output_size[0] * output_size[1], -1)

        # Raw attention scores.
        # [batch_size * num_heads, q_length, k_length]
        matmul_result = torch.matmul(query_layer.transpose(1, 0),
                                     key_layer.transpose(1, 0).transpose(1, 2))
        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(*output_size)

        offset = dist.get_rank() * num_attention_heads_per_partition if dist.is_initialized() else 0
        attention_probs = inference_cuda_module.softmax_fp16(
            attention_scores,
            ((1 - input_mask).half() * minus_inf) if input_mask.dtype == torch.int64 else input_mask,
            alibi,
            (config.triangular_masking and (attention_scores.shape[-2] > 1)),
            False,
            False,
            1,
            False,
            1 / (norm_factor * norm_factor),
            offset,
            config.mp_size)
        # change view [batch_size x num_heads, q_length, k_length]
        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(
            attention_probs_reshaped,
            value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3)))

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = context_layer.view(
            context_layer.size(0) // num_attention_heads_per_partition,
            num_attention_heads_per_partition,
            context_layer.size(1),
            context_layer.shape[-1])

        context_layer = _transpose_for_context(context_layer)

        return context_layer, presents
    ###################### End of HF modeling_bloom addition ########################

    def compute_attention(qkv_out, input_mask):
        no_masking = input_mask is None
        head_size = (qkv_out.shape[-1] // 3 // num_attention_heads_per_partition)

        if no_masking:
            input_mask = torch.empty(1)
        if merge_count > 0 and config.q_int8:
            split_dim = (qkv_out.dim() - 1)
            qkv_split = torch.split(qkv_out,
                                    (qkv_out.shape[-1] // (2**merge_count)),
                                    dim=split_dim)
            qkv_split = [
                torch.split(s, (s.shape[-1] // 3), dim=split_dim) for s in qkv_split
            ]
            (mixed_query, key_layer, value_layer) = [
                torch.cat([s[i] for s in qkv_split], axis=-1)
                for i in range(len(qkv_split[0]))
            ]

            if config.rotary_dim > 0:
                mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
                    mixed_query,
                    key_layer,
                    config.rotary_dim,
                    0 if layer_past is None else layer_past[0].shape[-2],
                    num_attention_heads_per_partition,
                    config.rotate_half,
                    config.rotate_every_two)
            if layer_past is not None:
                past_key, past_value = layer_past
                key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
                value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
            presents = (key_layer, value_layer)

            mixed_query = _transpose_for_scores(mixed_query, False, True)
            key_layer = _transpose_for_scores(key_layer, True, True) / (
                norm_factor if config.scale_attention else 1.0)
            value_layer = _transpose_for_scores(value_layer, False, True)

            if layer_past is None:
                attn_key_value = score_context_func(
                    mixed_query,
                    key_layer,
                    torch.empty(1),
                    ((1 - input_mask).half() * minus_inf) if input_mask.dtype == torch.int64 else input_mask,
                    value_layer,
                    torch.empty(1),
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),
                    (not unfused_mode),  # noqa: F821
                    config.triangular_masking,
                    config.local_attention,
                    config.window_size,
                    no_masking)
            else:
                attn_key_value = score_context_func(
                    mixed_query,
                    (key_layer if unfused_mode else past_key.type_as(key_layer)),  # noqa: F821
                    key_layer,
                    ((1 - input_mask).half() * minus_inf) if input_mask.dtype == torch.int64 else input_mask,
                    (value_layer if unfused_mode else past_value.type_as(value_layer)),  # noqa: F821
                    value_layer,
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),
                    (not unfused_mode),  # noqa: F821
                    config.triangular_masking,
                    config.local_attention,
                    config.window_size,
                    no_masking)
            if unfused_mode:  # noqa: F821
                context_layer, _, _ = attn_key_value
            else:
                context_layer, key_layer, value_layer = attn_key_value

            # Transpose Context
            context_layer = _transpose_for_context(context_layer)

            return context_layer, presents[0], presents[1]  # attention_output, key_layer, value_layer
        else:
            # Note: This modification is added for the BLOOM-176B model and will be removed later!
            if config.bigscience_bloom:
                context_layer, presents = backup_attention(qkv_out, layer_past, alibi,
                                                           input_mask, norm_factor)
                return context_layer, presents[0], presents[1]  # key_layer, value_layer
            else:
                if alibi is not None:
                    batch_heads = qkv_out.shape[0] * num_attention_heads_per_partition
                    offset = dist.get_rank() * batch_heads if dist.is_initialized() else 0
                    sliced_alibi = alibi[offset:batch_heads + offset, :, :]

                attn_key_value = score_context_func(
                    qkv_out,
                    ((1 - input_mask).to(qkv_out.dtype) * minus_inf) if input_mask.dtype == torch.int64 else input_mask,
                    config.rotary_dim,
                    config.rotate_half,
                    config.rotate_every_two,
                    num_attention_heads_per_partition,
                    (1 / norm_factor if config.scale_attention else 1.0),
                    config.triangular_masking,
                    config.local_attention,
                    config.window_size,
                    no_masking,
                    config.layer_id,
                    DeepSpeedTransformerInference.layer_id,
                    sliced_alibi if alibi is not None else torch.empty(1))
                context_layer, key_layer, value_layer = attn_key_value
                return context_layer, key_layer, value_layer

    def selfAttention_fp():
        vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
            inference_cuda_module.vector_matmul_fp32
        if not config.pre_layer_norm:
            linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \
                inference_cuda_module.linear_layer_fp32
            qkv_out = linear_func(input, attn_qkvw, attn_qkvb,
                                  DeepSpeedTransformerInference.layer_id)
        else:
            qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
                inference_cuda_module.qkv_gemm_fp32
            qkv_out = qkv_func(input,
                               attn_qkvw,
                               (attn_qkvb if attn_qkvb is not None else norm_b),
                               norm_w,
                               norm_b,
                               config.epsilon,
                               (attn_qkvb is not None),
                               1 if config.bigscience_bloom else DeepSpeedTransformerInference.layer_id)
        context_layer, key_layer, value_layer = compute_attention(
            qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
        output = vector_matmul_func(context_layer, attn_ow, False)
        return output, key_layer, value_layer, context_layer, qkv_out[-1]

    def selfAttention_int8():
        if not config.pre_layer_norm:
            qkv_out = inference_cuda_module.linear_layer_int8(
                input,
                attn_qkvw,
                attn_qkvb,
                q_scales[0],
                (q_groups * (3 if qkv_merging else 1) * (2**merge_count)))
        else:
            qkv_out = inference_cuda_module.qkv_gemm_int8(
                input,
                attn_qkvw,
                attn_qkvb,
                norm_w,
                norm_b,
                config.epsilon,
                q_scales[0],
                (q_groups * (3 if qkv_merging else 1) * (2**merge_count)),
                (attn_qkvb is not None))
        context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask)
        output = inference_cuda_module.vector_matmul_int8(context_layer,
                                                          attn_ow,
                                                          q_scales[1],
                                                          q_groups,
                                                          (merge_count))
        return output, key_layer, value_layer, context_layer

    if config.q_int8:
        output, key_layer, value_layer, context_layer = selfAttention_int8()
    else:
        output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()

    if config.mlp_after_attn and mp_group is not None and dist.get_world_size(group=mp_group) > 1:
        dist.all_reduce(output, group=mp_group)

    return (output, key_layer, value_layer, context_layer, inp_norm)
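# Illustrative arithmetic (hypothetical sizes) behind the sliced_alibi indexing in
# compute_attention above: each tensor-parallel rank owns
# num_attention_heads_per_partition heads, so it reads a contiguous block of
# batch_size * heads_per_partition alibi rows starting at rank * that block size.
batch_size = 4
num_heads = 32
mp_size_example = 8
heads_per_partition = num_heads // mp_size_example       # 4 heads per rank
batch_heads_example = batch_size * heads_per_partition   # 16 alibi rows per rank
offsets = [r * batch_heads_example for r in range(mp_size_example)]
assert offsets == [0, 16, 32, 48, 64, 80, 96, 112]       # alibi[offset:offset + 16] per rank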