def print_json_dist(message, ranks=None, path=None):
    """Print message to a JSON file when one of the following conditions is met

    + not dist.is_initialized()
    + dist.get_rank() in ranks, if ranks is not None or ranks == [-1]

    Args:
        message (dict)
        ranks (list)
        path (str)
    """
    from deepspeed import comm as dist
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        message['rank'] = my_rank
        import json
        import os
        with open(path, 'w') as outfile:
            json.dump(message, outfile)
            os.fsync(outfile)

def see_memory_usage(message, force=False):
    if not force:
        return
    if dist.is_initialized() and not dist.get_rank() == 0:
        return

    # Python doesn't do real-time garbage collection, so do it explicitly to get accurate RAM reports
    gc.collect()

    # Report the message along with the GPU memory stats (only rank 0 gets here when distributed)
    logger.info(message)
    logger.info(f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        CA {round(torch_memory_reserved() / (1024 * 1024 * 1024), 2)} GB \
        Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024), 2)} GB ")

    vm_stats = psutil.virtual_memory()
    used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
    logger.info(f'CPU Virtual Memory:  used = {used_GB} GB, percent = {vm_stats.percent}%')

    # Reset the peak memory counter so the next call reports a fresh peak
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()

def _dist_init(self, local_rank, num_procs, skip_msg):
    """Initialize deepspeed.comm and execute the user function."""
    if self.set_dist_env:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = get_master_port()
        os.environ['LOCAL_RANK'] = str(local_rank)
        # NOTE: unit tests don't support multi-node, so local_rank == global rank
        os.environ['RANK'] = str(local_rank)
        os.environ['WORLD_SIZE'] = str(num_procs)

    # turn off NCCL logging if set
    os.environ.pop('NCCL_DEBUG', None)

    set_cuda_visibile()

    if self.init_distributed:
        deepspeed.init_distributed(dist_backend=self.backend)
        dist.barrier()

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    try:
        self.current_test(**self.test_kwargs)
    except BaseException as e:
        if isinstance(e, Skipped):
            skip_msg.put(e.msg)
        else:
            # re-raise with the original traceback
            raise

    if self.init_distributed or dist.is_initialized():
        # make sure all ranks finish at the same time
        dist.barrier()
        # tear down after test completes
        dist.destroy_process_group()

def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
    is_pipe_parallel = isinstance(self.module, PipelineModule)
    if is_pipe_parallel:
        raise RuntimeError('pipeline parallelism is currently not supported in inference.')
    if os.path.isdir(load_dir):
        if tag is None:
            latest_path = os.path.join(load_dir, "latest")
            if os.path.isfile(latest_path):
                with open(latest_path, "r") as fd:
                    tag = fd.read().strip()
        ckpt_list = self._get_all_ckpt_names(load_dir, tag)
        sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
    else:
        sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)

    if type(sd_loader) is list:
        self.sd = torch.load(sd_loader[0], map_location='cpu')
        self.key_list = list(self.sd.keys())

        self.load_model_with_checkpoint(self.module)

        for i in range(1, len(sd_loader)):
            if not dist.is_initialized() or dist.get_rank() == 0:
                print(f"loading checkpoint ({i})")
            self.sd = torch.load(sd_loader[i], map_location='cuda')
            self.key_list = list(self.sd.keys())
            self.load_model_with_checkpoint(self.module)
    else:
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()

        load_path, checkpoint, quantize_config = sd_loader.load(
            self.mp_world_size,
            mp_rank,
            is_pipe_parallel=is_pipe_parallel,
            quantize=(self.dtype is torch.int8),
            quantize_groups=self.quantize_groups,
            mlp_extra_grouping=self.mlp_extra_grouping)

        self.quantization_scales, self.quantize_merge_count = quantize_config

        moe, _ = has_moe_layers(self.module)
        if moe:
            from deepspeed.runtime.engine import DeepSpeedEngine
            old_moe_load = False
            if not isinstance(checkpoint['num_experts'], list):
                old_moe_load = True
            DeepSpeedEngine.load_moe_state_dict(
                load_dir,
                tag,
                state_dict=checkpoint[self._choose_module_key(checkpoint)],
                old_moe_load=old_moe_load,
                model=self.module,
                mpu=self.mpu,
                checkpoint_engine=self.checkpoint_engine)

        self.module.load_state_dict(
            state_dict=checkpoint[self._choose_module_key(checkpoint)],
            checkpoint_engine=self.checkpoint_engine,
            strict=load_module_strict)

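# Illustrative sketch (not part of the module above) of the "latest" file
# convention that _load_checkpoint resolves: a checkpoint directory holds a
# plain-text file named "latest" whose content names the tag to load. The
# layout below is hypothetical.
#
#   ckpts/latest              -> contains "global_step1000"
#   ckpts/global_step1000/    -> the shard files for that tag
def _sketch_resolve_latest_tag(load_dir='ckpts'):
    import os
    latest_path = os.path.join(load_dir, "latest")
    with open(latest_path, "r") as fd:
        return fd.read().strip()
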
def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    if dist.is_initialized():
        # We assume up to one full node executing unit tests
        assert dist.get_world_size() <= torch.cuda.device_count()
        args.local_rank = dist.get_rank()
    return args

def _get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert dist.is_initialized(), 'dist is not initialized'
    global mpu
    if mpu is not None:
        return mpu.get_data_parallel_group()
    # Return the clone of dist world group
    return _clone_world_group()

def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu):
    """
    Create expert and data parallel groups based on MPU (model parallel) group.

    Note: Caller of this function is responsible for checking if the groups already exist.

    Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4  # number of experts in same group
        mp_group = [0, 1], [2, 3], [4, 5] ...
        data_parallel_group = [0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]
        expert_parallel_group = [0, 2, 4, 6], [8, 10, 12, 14], [1, 3, 5, 7], [9, 11, 13, 15]
        expert_data_parallel_group = [0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], [7, 15]
    """
    assert dist.is_initialized(), "dist is not initialized"
    model_parallel_size_ = mpu.get_model_parallel_world_size()

    global expert_tensor_parallel_world_size
    expert_tensor_parallel_world_size = model_parallel_size_

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_world_size = mpu.get_data_parallel_world_size()
    dp_rank = mpu.get_data_parallel_rank()

    _ensure_divisibility(world_size, model_parallel_size_)
    _ensure_divisibility(dp_world_size, expert_parallel_size_)

    log_dist(
        f"Creating deepspeed groups with model parallel size {model_parallel_size_}, "
        f"expert parallel size {expert_parallel_size_}, world size {world_size}, "
        f"dp world size {dp_world_size}",
        [0])

    global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP

    # Get world size and rank. Ensure some consistencies.
    _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
    _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()

    group_name = f"ep_size_{expert_parallel_size_}"

    # Only create groups if they don't already exist.
    # Need to check conditions outside the group creation loop because of the way
    # torch.dist group creation works.
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
        expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
            world_size, model_parallel_size_, expert_parallel_size_)
        for ranks in expert_parallel_groups:
            group = dist.new_group(ranks)
            if rank in list(ranks):
                _EXPERT_PARALLEL_GROUP[group_name] = group
        for ranks in expert_data_parallel_groups:
            group = dist.new_group(ranks)
            if rank in list(ranks):
                _EXPERT_DATA_PARALLEL_GROUP[group_name] = group

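# Illustrative sketch (not part of the module above) of the E + M + D rank
# layout described in the docstring; _sketch_expert_parallel_ranks is a
# hypothetical stand-in for the internal _get_expert_parallel_ranks helper.
def _sketch_expert_parallel_ranks(world_size, model_parallel_size, expert_parallel_size):
    # Each data parallel group holds the ranks sharing one model-parallel slice.
    dp_groups = [list(range(mp_rank, world_size, model_parallel_size))
                 for mp_rank in range(model_parallel_size)]
    expert_parallel_groups, expert_data_parallel_groups = [], []
    for dp_ranks in dp_groups:
        # Contiguous chunks of a DP group exchange expert activations (all-to-all).
        for i in range(0, len(dp_ranks), expert_parallel_size):
            expert_parallel_groups.append(dp_ranks[i:i + expert_parallel_size])
        # Strided slices of a DP group all-reduce their replicated expert weights.
        for i in range(expert_parallel_size):
            expert_data_parallel_groups.append(dp_ranks[i::expert_parallel_size])
    return expert_parallel_groups, expert_data_parallel_groups

# With the docstring example (world_size=16, model_degree=2, expert_degree=4), this
# yields expert_parallel groups like [0, 2, 4, 6] and expert_data_parallel groups like [1, 9].
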
def __init__(self, typename, *module_args, **module_kwargs):
    self.typename = typename
    self.module_args = module_args
    self.module_kwargs = module_kwargs

    if not issubclass(typename, nn.Module):
        raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

    if dist.is_initialized():
        self.global_rank = dist.get_rank()
    else:
        self.global_rank = -1

def _create_expert_and_data_parallel(expert_parallel_size_):
    """
    Create expert and data parallel groups.

    Note: Caller of this function is responsible for checking if the groups already exist.

    Example - E + D parallel
        world_size = 16
        expert_parallel_size = 2  # number of experts in same group
        expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
        expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all
        data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
    """
    assert dist.is_initialized()

    log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}',
             ranks=[0])

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    _ensure_divisibility(world_size, expert_parallel_size_)

    group_name = f"ep_size_{expert_parallel_size_}"

    # Build the expert data parallel groups.
    global _EXPERT_DATA_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
        for i in range(expert_parallel_size_):
            ranks = range(i, world_size, expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank % expert_parallel_size_):
                _EXPERT_DATA_PARALLEL_GROUP[group_name] = group

    # Build the expert parallel groups.
    global _EXPERT_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_PARALLEL_GROUP:
        for i in range(world_size // expert_parallel_size_):
            ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank // expert_parallel_size_):
                _EXPERT_PARALLEL_GROUP[group_name] = group

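# Illustrative sketch (not part of the module above) of the E + D layout built
# by _create_expert_and_data_parallel, using the docstring's example values.
def _sketch_e_d_groups(world_size=16, ep_size=2):
    # Strided groups: partners for the MoE-parameter all-reduce.
    expert_data_parallel = [list(range(i, world_size, ep_size)) for i in range(ep_size)]
    # Contiguous groups: all-to-all partners that hold different experts.
    expert_parallel = [list(range(i * ep_size, (i + 1) * ep_size))
                       for i in range(world_size // ep_size)]
    assert expert_data_parallel[0] == [0, 2, 4, 6, 8, 10, 12, 14]
    assert expert_parallel[:2] == [[0, 1], [2, 3]]
    return expert_parallel, expert_data_parallel
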
def _clone_world_group():
    """Create a clone of the world group.

    Note: We need to clone the dist world group because we use the
    dist.get_global_rank() utility function in DeepSpeed at many places.
    As that function does not work on dist.group.WORLD, we need to keep a
    clone of it.
    """
    assert dist.is_initialized(), "dist is not initialized"
    global _WORLD_GROUP
    if _WORLD_GROUP is None:
        # If not cloned already, clone the world group
        _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
    return _WORLD_GROUP

def log_dist(message, ranks=None, level=logging.INFO):
    """Log message when one of the following conditions is met

    + not dist.is_initialized()
    + dist.get_rank() in ranks, if ranks is not None or ranks == [-1]

    Args:
        message (str)
        ranks (list)
        level (int)
    """
    from deepspeed import comm as dist
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        final_message = "[Rank {}] {}".format(my_rank, message)
        logger.log(level, final_message)

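# Usage sketch for log_dist (assumes logging and deepspeed.comm are set up):
# only rank 0 emits the first message, while ranks=[-1] acts as a wildcard so
# every rank emits the second, each tagged with its own "[Rank N]" prefix.
#
#   log_dist("reducing loss scale", ranks=[0])
#   log_dist("per-rank gradient norms", ranks=[-1], level=logging.DEBUG)
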
def save_exp_results_to_database(self, message, ranks=None, path=None):
    """Append message to a results file when one of the following conditions is met

    + not dist.is_initialized()
    + dist.get_rank() in ranks, if ranks is not None or ranks == [-1]

    Args:
        message (dict)
        ranks (list)
        path (str)
    """
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    logger.debug(f"*** Should log: {should_log}")
    if should_log:
        message['rank'] = my_rank
        with open(path, 'a') as outfile:
            json.dump(message, outfile)
            outfile.write('\n')

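# Because each call appends one JSON object plus a newline, the results file is
# in JSON Lines format. A sketch of reading it back (the path is hypothetical):
def _sketch_read_exp_results(path='autotuning_exps/results.jsonl'):
    import json
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
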
def _create_model_parallel(model_parallel_size_):
    """
    Initialize model parallel and data parallel groups.

    Arguments:
        model_parallel_size: number of GPUs used to parallelize model.

    Returns:
        Tuple of data parallel group and model parallel group

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    log_dist(f'Creating model parallel group with size {model_parallel_size_}',
             ranks=[0])

    # Get world size and rank. Ensure some consistencies.
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    model_parallel_size = min(model_parallel_size_, world_size)
    _ensure_divisibility(world_size, model_parallel_size)
    rank = dist.get_rank()

    _DATA_PARALLEL_GROUP = None
    _MODEL_PARALLEL_GROUP = None

    # Build the data parallel groups.
    for i in range(model_parallel_size):
        ranks = range(i, world_size, model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank % model_parallel_size):
            _DATA_PARALLEL_GROUP = group

    # Build the model parallel groups.
    for i in range(world_size // model_parallel_size):
        ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank // model_parallel_size):
            _MODEL_PARALLEL_GROUP = group

    return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP

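# Illustrative check (not part of the module above) of how each rank selects
# its groups in the loops of _create_model_parallel, using the docstring's
# 8-GPU, 2-way model parallel example: rank r joins data parallel group
# r % mp and model parallel group r // mp.
def _sketch_model_parallel_membership(world_size=8, mp=2):
    for rank in range(world_size):
        dp_ranks = list(range(rank % mp, world_size, mp))                  # i == rank % mp
        mp_ranks = list(range((rank // mp) * mp, (rank // mp + 1) * mp))   # i == rank // mp
        assert rank in dp_ranks and rank in mp_ranks
    # e.g. rank 3 -> data parallel group [1, 3, 5, 7], model parallel group [2, 3]
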
def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor):
    alibi = alibi.to(torch.cuda.current_device())
    head_dim = hidden_size_per_partition // num_attention_heads_per_partition
    new_tensor_shape = mixed_x_layer.size()[:-1] + (num_attention_heads_per_partition,
                                                    3 * head_dim)
    mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
    (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
        key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
        value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)

    presents = (key_layer, value_layer)

    # [batch_size, num_heads, q_length, k_length]
    output_size = (query_layer.size(0),
                   query_layer.size(2),
                   query_layer.size(1),
                   key_layer.size(1))
    # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
    query_layer = query_layer.transpose(1, 0).reshape(
        output_size[2], output_size[0] * output_size[1], -1)
    # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
    key_layer = key_layer.transpose(1, 0).reshape(
        output_size[3], output_size[0] * output_size[1], -1)

    # Raw attention scores. [batch_size * num_heads, q_length, k_length]
    matmul_result = torch.matmul(query_layer.transpose(1, 0),
                                 key_layer.transpose(1, 0).transpose(1, 2))
    # change view to [batch_size, num_heads, q_length, k_length]
    attention_scores = matmul_result.view(*output_size)

    offset = dist.get_rank() * num_attention_heads_per_partition if dist.is_initialized() else 0
    attention_probs = inference_cuda_module.softmax_fp16(
        attention_scores,
        ((1 - input_mask).half() * minus_inf) if input_mask.dtype == torch.int64 else input_mask,
        alibi,
        (config.triangular_masking and (attention_scores.shape[-2] > 1)),
        False,
        False,
        1,
        False,
        1 / (norm_factor * norm_factor),
        offset,
        config.mp_size)
    # change view to [batch_size * num_heads, q_length, k_length]
    attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

    # matmul: [batch_size * num_heads, q_length, head_dim]
    context_layer = torch.bmm(
        attention_probs_reshaped,
        value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3)))

    # change view to [batch_size, num_heads, q_length, head_dim]
    context_layer = context_layer.view(
        context_layer.size(0) // num_attention_heads_per_partition,
        num_attention_heads_per_partition,
        context_layer.size(1),
        context_layer.shape[-1])

    context_layer = _transpose_for_context(context_layer)

    return context_layer, presents

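# Plain-torch sanity check (illustrative sizes, CPU only, not part of the
# module above) for the reshapes in backup_attention: a fused QKV tensor of
# shape [b, s, h * 3d] is split and flattened to [s, b*h, d] so that the
# batched matmul yields scores of shape [b*h, q_length, k_length].
def _sketch_attention_shapes():
    import torch
    b, s, h, d = 2, 5, 4, 8
    mixed = torch.randn(b, s, h * 3 * d)
    q, k, _v = mixed.view(b, s, h, 3 * d).chunk(3, dim=-1)
    q = q.transpose(1, 0).reshape(s, b * h, d)
    k = k.transpose(1, 0).reshape(s, b * h, d)
    scores = torch.matmul(q.transpose(1, 0), k.transpose(1, 0).transpose(1, 2))
    assert scores.shape == (b * h, s, s)
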
def get_ma_status():
    if dist.is_initialized() and not dist.get_rank() == 0:
        return 0
    return torch.cuda.memory_allocated()

def replace_transformer_layer(orig_layer_impl,
                              model,
                              policy=None,
                              micro_batch_size=-1,
                              config=None,
                              seed=-1,
                              hidden_size=-1,
                              num_attention_heads=-1,
                              mp_size=1,
                              training_mp_size=1,
                              mp_group=None,
                              ep_group=None,
                              expert_mp_group=None,
                              fp16=True,
                              local_rank=-1,
                              stochastic_mode=True,
                              training=True,
                              quantize=False,
                              quantize_settings=None,
                              triangular_masking=False,
                              return_tuple=True,
                              replace_with_kernel_inject=False,
                              linear_layer_setting=None,
                              moe=False,
                              moe_experts=1,
                              moe_type='standard',
                              checkpoint_dict=None,
                              save_mp_checkpoint_path=None):
    """Replace bert-style transformer layers with DeepSpeed's transformer layer

    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters when
            replace_with_kernel_inject is set; otherwise, it provides the names of two linear layers as
            a tuple: (attention_output projection, transformer output projection)
        micro_batch_size (int): micro batch size per gpu used during training/eval
        config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        hidden_size (int): hidden dimension
        num_attention_heads (int): number of attention heads
        mp_size (int): model_parallelism degree
        mp_group : model_parallel group initialized on the modeling side
        preln (bool): does the original layer implementation do pre or post layer norm?
        fp16 (bool): fp16 or fp32
        local_rank (int): GPU rank (optional)
        stochastic_mode (bool): whether to use stochastic mode
        training (bool): specifies whether kernel injection is done for training/inference
            (set to False for inference-mode injection)
        quantize_settings (tuple): this setting shows how we can quantize a model for running it
            through the inference kernels. It includes
            (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
        return_tuple (bool): if set, the transformer layer returns a tuple as the output.
            Note: this flag needs to be set for huggingface models.
        replace_with_kernel_inject (bool): injection mode; if True, kernels will be added along
            with configuring Tensor-Parallelism
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for
            linear layers and embedding layers
        attention_params (list of strings) [Optional]: shows the parameters in the attention part
            that need to be adjusted based on the model-parallelism
    Returns:
        Updated nn.module with replaced transformer layers
    """
    mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group,
                                          mp_size=mp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child,
                            policy_cls,
                            triangular_masking,
                            inference=False,
                            layer_id=0):
        policy = policy_cls(child, inference=inference)

        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0, \
                "To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size!" + \
                "This is because the attention computation is partitioned evenly among the parallel GPUs."

        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
        if not moe or moe_type == 'standard':
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
        else:
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
                _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

        attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()

        if quantize:
            if policy_cls is not HFBertLayerPolicy:
                qkvw = qkvw.to(torch.int8)
            dense_w = dense_w.to(torch.int8)
            _h4h_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
            _4hh_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
        elif fp16:
            qkvw = qkvw.half()
            dense_w = dense_w.half()
            _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
            _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()

        if quantize or fp16:
            qkvb = qkvb if qkvb is None else qkvb.half()
            dense_b = dense_b if dense_b is None else dense_b.half()
            _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
            _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
            attn_nw = attn_nw if attn_nw is None else attn_nw.half()
            attn_nb = attn_nb if attn_nb is None else attn_nb.half()
            input_nw = input_nw.half()
            input_nb = input_nb.half()

        if moe and moe_type == 'residual' and fp16:
            _res_h4h_b = _res_h4h_b.half()
            _res_4hh_b = _res_4hh_b.half()
            _res_h4h_w = _res_h4h_w.half()
            _res_4hh_w = _res_4hh_w.half()
            _res_coef = _res_coef.half()

        #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

        if inference:
            if moe:
                ep_world_size = dist.get_world_size()
                local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

                transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps
                    if hasattr(config, 'layer_norm_eps') else 1e-12,
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    moe_experts=local_ep_size,
                    global_experts=num_experts,
                    mlp_type=moe_type)
            else:
                rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') else child.attention.rotary_ndims \
                    if hasattr(child, 'attention') and hasattr(child.attention, 'rotary_ndims') else -1
                bigscience_bloom = policy_cls is BLOOMLayerPolicy
                transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps
                    if hasattr(config, 'layer_norm_eps') else
                    (config.layer_norm_epsilon
                     if hasattr(config, 'layer_norm_epsilon') else
                     config.layernorm_epsilon
                     if hasattr(config, 'layernorm_epsilon') else 1.0e-12),
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                    triangular_masking=(policy_cls is not HFBertLayerPolicy),
                    local_attention=((config.attention_layers[layer_id] == "local")
                                     if hasattr(config, 'attention_layers') else False),
                    window_size=(config.window_size
                                 if hasattr(config, 'window_size') else 1),
                    rotary_dim=rotary_dim,
                    mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
                    mlp_act_func_type=policy.mlp_act_func_type,
                    training_mp_size=training_mp_size,
                    bigscience_bloom=bigscience_bloom)

            if quantize and quantize_settings is not None:
                (quantization_scales,
                 merge_count,
                 mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))
                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                if quantize and qkvw.dtype != torch.int8:
                    quantize_bits = 8
                    quantizer = WeightQuantization()
                    if policy_cls is HFBertLayerPolicy:
                        data_quantized, _ = quantizer.quantize_data(
                            qkvw.data, quantize_bits, quantize_groups * 3)
                    else:
                        data_quantized, _ = quantizer.quantize_data(
                            qkvw.data, quantize_bits, quantize_groups)
                    qkvw.data.copy_(data_quantized)
                    qkvw.data = qkvw.data.to(torch.int8)
            else:
                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                    )
                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                    )
            new_module.config.scale_attention = scale_attention

            # we want the weights in [input, output] shape
            # linear layer is created with [input, output] shape
            # transpose it here to reduce inference cost!
            def transpose(data):
                # temporarily move to cpu to avoid requiring extra GPU memory during the reshape
                data = data.to('cpu')
                data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
                data = data.reshape(data.shape[-1], data.shape[-2])
                data = data.to(torch.cuda.current_device())
                return data

            attn_block = new_module.attention
            mpl_block = new_module.mlp

            if attn_linear_layer:
                if qkvw.numel() == 0 or qkvw.is_meta:
                    if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                        pass
                    else:
                        with GatheredParameters([qkvw, dense_w, qkvb, dense_b],
                                                modifier_rank=0):
                            qkvw = transpose(qkvw.data)
                            dense_w = transpose(dense_w.data)
                            qkvb = qkvb.data
                            dense_b = dense_b.data
                else:
                    qkvw.data = transpose(qkvw.data)
                    dense_w.data = transpose(dense_w.data)

            def _transpose(x):
                num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size
                attention_head_size = x.shape[-1] // num_attention_heads_per_partition
                new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                               attention_head_size)
                x_1 = x.view(*new_x_shape)
                (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
                if len(q.shape) > 2:
                    return torch.cat((q.reshape(q.shape[0], -1),
                                      k.reshape(q.shape[0], -1),
                                      v.reshape(q.shape[0], -1)),
                                     dim=-1).reshape(x.shape)
                else:
                    return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)),
                                     dim=-1).reshape(x.shape)

            if megatron_v2:
                new_module.config.rotate_half = True
                new_module.config.rotate_every_two = False

                # Note: this part needs to be added for BLOOM architecture
                qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous())
                qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous())

            # NOTE: This part caused instability in the multi-GPU inference!
            # TODO: This needs to be incorporated in the kernels.
            #dense_b = dense_b if dense_b is None else dense_b * (
            #    transformer_config.training_mp_size / transformer_config.mp_size)
            #_4hh_b = _4hh_b * (transformer_config.training_mp_size /
            #                   transformer_config.mp_size)

            if mlp_linear_layer:
                if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta):
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel() < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w, _4hh_w, _4hh_b, _h4h_b],
                                                modifier_rank=0):
                            _h4h_w = transpose(_h4h_w.data)
                            _4hh_w = transpose(_4hh_w.data)
                            _h4h_b = _h4h_b.data
                            _4hh_b = _4hh_b.data
                else:
                    _h4h_w = [transpose(moe_w1.data)
                              for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
                    _4hh_w = [transpose(moe_w1.data)
                              for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

            if moe and moe_type == 'residual':
                _res_h4h_w.data = transpose(_res_h4h_w.data)
                _res_4hh_w.data = transpose(_res_4hh_w.data)
                _res_coef.data = transpose(_res_coef.data)

            if qkvw.is_meta or qkvw.numel() == 0:
                if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                    pass
                else:
                    with GatheredParameters([
                            attn_block.attn_qkvw,
                            attn_block.attn_qkvb,
                            attn_block.attn_ow,
                            attn_block.attn_ob
                    ],
                                            modifier_rank=0):
                        attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                        attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
                        attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                        attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
            else:
                if bigscience_bloom:
                    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
                else:
                    attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
                    attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
                attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

            if moe:
                # every rank keeps only its local slice of experts, so index locally
                gpu_index = 0
                for ep_index in range(local_ep_size):
                    mpl_block[ep_index].inter_w.data = _h4h_w[
                        gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                    mpl_block[ep_index].inter_b.data = _h4h_b[
                        gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                    mpl_block[ep_index].output_w.data = _4hh_w[
                        gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                    mpl_block[ep_index].output_b.data = _4hh_b[
                        gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
                if moe_type == 'residual':
                    new_module.res_mlp.inter_w.data = _res_h4h_w.to(torch.cuda.current_device())
                    new_module.res_mlp.inter_b.data = _res_h4h_b.to(torch.cuda.current_device())
                    new_module.res_mlp.output_w.data = _res_4hh_w.to(torch.cuda.current_device())
                    new_module.res_mlp.output_b.data = _res_4hh_b.to(torch.cuda.current_device())
                    new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
            else:
                if _4hh_w.numel() == 0 or _4hh_w.is_meta:
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel() < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w, _4hh_w, _h4h_b, _4hh_b],
                                                modifier_rank=0):
                            mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                            mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                            mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                            mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)
                else:
                    mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                    mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                    mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                    mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)

                if attn_nw is None:
                    new_module.mlp.attn_nw = attn_nw
                    new_module.mlp.attn_nb = attn_nb
                else:
                    if attn_nw.is_meta or attn_nw.numel() == 0:
                        if attn_nw.is_meta or attn_nw.ds_tensor.numel() < new_module.mlp.attn_nw.numel():
                            pass
                        else:
                            with GatheredParameters([attn_nw, attn_nb], modifier_rank=0):
                                new_module.mlp.attn_nw.data.copy_(
                                    attn_nw.to(torch.cuda.current_device()))
                                new_module.mlp.attn_nb.data.copy_(
                                    attn_nb.to(torch.cuda.current_device()))
                    else:
                        new_module.mlp.attn_nw.data.copy_(
                            attn_nw.to(torch.cuda.current_device()))
                        new_module.mlp.attn_nb.data.copy_(
                            attn_nb.to(torch.cuda.current_device()))

            if input_nw.is_meta or input_nw.numel() == 0:
                if input_nw.is_meta or input_nw.ds_tensor.numel() < new_module.norm_w.numel():
                    pass
                else:
                    with GatheredParameters([input_nw, input_nb], modifier_rank=0):
                        new_module.norm_w.data.copy_(
                            input_nw.to(torch.cuda.current_device()))
                        new_module.norm_b.data.copy_(
                            input_nb.to(torch.cuda.current_device()))
            else:
                new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device()))
                new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device()))
        else:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size if micro_batch_size > 0 else 1,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                layer_norm_eps=config.layer_norm_eps
                if hasattr(config, 'layer_norm_eps') else 1e-12,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=policy.pre_attn_norm,
                return_tuple=return_tuple,
                local_rank=local_rank,
                stochastic_mode=stochastic_mode,
                normalize_invertible=True,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = dense_w
            new_module.attn_ob.data = dense_b

            new_module.attn_nw.data = attn_nw
            new_module.attn_nb.data = attn_nb
            new_module.norm_w.data = input_nw
            new_module.norm_b.data = input_nb

            new_module.inter_w.data = _h4h_w
            new_module.inter_b.data = _h4h_b
            new_module.output_w.data = _4hh_w
            new_module.output_b.data = _4hh_b

        return new_module

    def replace_wo_policy(module, all_reduce_linears):
        def _replace(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            z_inference = (len(list(child.parameters())) > 0) and (list(
                child.parameters())[0].numel() == 0)
            if z_inference:
                weight_shape = child.weight.ds_shape
            else:
                weight_shape = child.weight.shape
            if name in all_reduce_linears:
                new_weight = torch.empty(
                    (weight_shape[1] if conv_linear_layer else weight_shape[0],
                     (weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size),
                    device=child.weight.device,
                    dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.weight, modifier_rank=0):
                        data = child.weight.data.to(new_weight.device)
                        if conv_linear_layer:
                            data = data.transpose(-1, -2).contiguous()
                        data = mp_replace.copy(new_weight, data)
                    child.weight.ds_tensor = torch.empty(1)
                else:
                    if conv_linear_layer:
                        child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
                    data = mp_replace.copy(new_weight, child.weight.data)

                new_bias = torch.empty((weight_shape[0]),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0):
                        new_bias.data.copy_(child.bias.data)
                elif child.bias is not None:
                    new_bias.data.copy_(child.bias.data)
                return LinearAllreduce(data, child.bias if child.bias is None else \
                    torch.nn.parameter.Parameter(new_bias.to(torch.cuda.current_device())), mp_group)
            else:
                new_weight = torch.empty(
                    ((weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size,
                     weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1]),
                    device=child.weight.device,
                    dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.weight, modifier_rank=0):
                        data = child.weight.data.to(new_weight.device)
                        if conv_linear_layer:
                            data = data.transpose(-1, -2).contiguous()
                        data = mp_replace.copy(new_weight, data)
                    child.weight.ds_tensor = torch.empty(1)
                else:
                    if conv_linear_layer:
                        child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
                    data = mp_replace.copy(new_weight, child.weight.data)

                new_bias = torch.empty((weight_shape[0] // mp_size),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0):
                        bias_data = None if child.bias is None else mp_replace.copy(
                            new_bias, child.bias.data).to(torch.cuda.current_device())
                else:
                    bias_data = None if child.bias is None else mp_replace.copy(
                        new_bias, child.bias.data).to(torch.cuda.current_device())
                return LinearLayer(weight=data.to(torch.cuda.current_device()),
                                   bias=bias_data)

        def _slice_embedding(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            new_weight = torch.empty((child.weight.shape[0],
                                      child.weight.shape[1] // mp_size),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            data = mp_replace.copy(new_weight, child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
                child.weight.data)
            new_embedding = nn.Embedding(child.weight.shape[0],
                                         child.weight.shape[1] // mp_size)
            new_embedding.weight.data.copy_(data)
            return new_embedding

        def update_mp_params(child):
            if hasattr(child, 'n_heads'):
                child.n_heads = child.n_heads // mp_size
            if hasattr(child, 'inner_dim'):
                child.inner_dim = child.inner_dim // mp_size
            if hasattr(child, 'num_heads'):
                child.num_heads = child.num_heads // mp_size
            if hasattr(child, 'num_attention_heads'):
                child.num_attention_heads = child.num_attention_heads // mp_size
            if hasattr(child, 'all_head_size'):
                child.all_head_size = child.all_head_size // mp_size
            if hasattr(child, 'embed_dim'):
                child.embed_dim = child.embed_dim // mp_size
            if hasattr(child, 'hidden_size'):
                child.hidden_size = child.hidden_size // mp_size

        conv_linear_layer = False
        if linear_layer_setting is not None:
            linear_policies = {linear_layer_setting[0]: _replace}
            if len(linear_layer_setting) == 2:
                linear_policies.update({linear_layer_setting[1]: _slice_embedding})
        else:
            if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
                try:
                    import transformers
                    conv_linear_layer = True
                    linear_policies = {transformers.model_utils.Conv1D: _replace}
                except ImportError:
                    linear_policies = {nn.Linear: _replace}
            else:
                linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}

        def _replace_module(r_module, prev_name=''):
            for name, child in r_module.named_children():
                if child.__class__ in linear_policies:
                    setattr(
                        r_module,
                        name,
                        linear_policies[child.__class__](child,
                                                         prev_name + '.' + name,
                                                         conv_linear_layer))
                else:
                    update_mp_params(child)
                    _replace_module(child, name)
            return r_module

        return _replace_module(module)

    def replace_fn(child, _policy, layer_id=0):
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, triangular_masking)
        else:
            # copy relevant state from child -> new module
            if replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy)

        return new_module

    replaced_module = replace_module(model=model,
                                     orig_class=orig_layer_impl,
                                     replace_fn=replace_fn,
                                     _replace_policy=policy)

    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None:
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
        base_dir = checkpoint_dict.get('base_dir', '')

        if ckpt_type == 'pp':
            pbar = tqdm.tqdm(total=len(checkpoint),
                             desc=f"Loading {len(checkpoint)} checkpoint shards")
            for i in range(len(checkpoint)):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)
                sd = torch.load(checkpoint[i], map_location='cpu')
                load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type)
        else:
            num_checkpoints = len(checkpoint) // ckpt_mp_size
            assert world_size >= ckpt_mp_size, \
                "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!"
            checkpoint_stride = world_size // ckpt_mp_size
            if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                pbar = tqdm.tqdm(total=num_checkpoints,
                                 desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)
                ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride)
                ckpt_file = os.path.join(base_dir, checkpoint[ckpt_index]) \
                    if base_dir else checkpoint[ckpt_index]
                sd = torch.load(ckpt_file, map_location='cpu')
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           rank % (world_size // ckpt_mp_size))
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json
        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
        ckpt_files = [non_tp_ckpt_name] * world_size
        os.makedirs(save_mp_checkpoint_path, exist_ok=True)
        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({
                    k: v
                    for k, v in dict(replaced_module.state_dict()).items()
                    if transformer_name not in k
                }),
                f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}')
            ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)]
            config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{save_mp_checkpoint_path}',
                'checkpoints': ckpt_files,
                'version': 1.0,
                'parallelization': 'tp',
                'mp_size': world_size
            })
            with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json",
                      "w") as cfg:
                cfg.write(config)
        torch.save(
            OrderedDict({
                k: v
                for k, v in dict(replaced_module.state_dict()).items()
                if transformer_name in k
            }),
            f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt')

    return replaced_module

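# Illustrative sketch (not part of the module above): in user code this
# replacement path is normally reached through the deepspeed.init_inference
# entry point rather than by calling replace_transformer_layer directly.
# `model` and `input_ids` are assumed to be supplied by the caller.
def _sketch_init_inference(model, input_ids):
    import torch
    import deepspeed

    engine = deepspeed.init_inference(model,
                                      mp_size=1,
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True)
    return engine(input_ids)
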
def __init__(self,
             params,
             deepspeed=None,
             lr=1e-3,
             freeze_step=100000,
             bias_correction=True,
             betas=(0.9, 0.999),
             eps=1e-8,
             eps_inside_sqrt=False,
             weight_decay=0.,
             max_grad_norm=0.,
             max_coeff=10.0,
             min_coeff=0.01,
             amsgrad=False,
             cuda_aware=False,
             comm_backend_name='nccl',
             coeff_beta=0.9,
             factor_max=4.0,
             factor_min=0.5,
             factor_threshold=0.1):
    if amsgrad:
        raise RuntimeError('1-bit Lamb does not support the AMSGrad variant.')

    defaults = dict(lr=lr,
                    bias_correction=bias_correction,
                    betas=betas,
                    eps=eps,
                    weight_decay=weight_decay,
                    max_grad_norm=max_grad_norm,
                    max_coeff=max_coeff,
                    min_coeff=min_coeff)

    super(OnebitLamb, self).__init__(params, defaults)
    self.eps_mode = 0 if eps_inside_sqrt else 1
    assert dist.is_initialized()

    self.deepspeed = deepspeed
    self.lamb_freeze_key = False
    self.initialize = False
    self.freeze_step = freeze_step
    self.cuda_aware = cuda_aware
    self.coeff_beta = coeff_beta
    self.factor_max = factor_max
    self.factor_min = factor_min
    self.factor_threshold = factor_threshold
    self.using_pipeline = False

    self.comm_backend_name = comm_backend_name

    # Empty initializer. Set handle based on the comm backend as follows.
    self.comm_backend_handle = None

    if self.comm_backend_name == 'nccl':
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        assert TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 8), \
            "Please use torch 1.8 or greater to enable the NCCL backend in 1-bit Lamb. " \
            "Alternatively, please specify 'mpi' as the 'comm_backend_name' in the config file to proceed with the MPI backend"
        assert dist.is_initialized(), "Please initialize the torch distributed backend."
        from deepspeed.runtime.comm.nccl import NcclBackend
        self.using_pipeline = hasattr(self.deepspeed, 'pipeline_enable_backward_allreduce')
        self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
    elif self.comm_backend_name == 'mpi':
        from deepspeed.runtime.comm.mpi import MpiBackend
        self.comm_backend_handle = MpiBackend(cuda_aware)

    self.size = self.comm_backend_handle.size
    self.divider = int(self.size * 8 / np.gcd(self.size, 8))

    self.exp_avg_flat = []
    self.dummy_exp_avg = {}
    self.corrected_tensor_sizes = []
    self.server_chunk_sizes = []
    self.worker_errors = []
    self.server_errors = []

    self.lamb_coeffs = []

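# Quick numeric check (illustrative, not part of the optimizer above): the
# divider computed in __init__, size * 8 / gcd(size, 8), equals lcm(size, 8),
# so tensors padded to a multiple of it split evenly both across the `size`
# ranks and into groups of 8.
def _sketch_divider():
    import numpy as np
    for size in (2, 3, 4, 6, 8, 16):
        divider = int(size * 8 / np.gcd(size, 8))
        assert divider == np.lcm(size, 8)
        assert divider % size == 0 and divider % 8 == 0
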
def test_scattered_init_dist():
    setup_serial_env()
    assert not dist.is_initialized()
    with deepspeed.zero.Init():
        assert dist.is_initialized()

def compute_attention(qkv_out, input_mask):
    no_masking = input_mask is None

    head_size = (qkv_out.shape[-1] // 3 // num_attention_heads_per_partition)
    if no_masking:
        input_mask = torch.empty(1)
    if merge_count > 0 and config.q_int8:
        split_dim = (qkv_out.dim() - 1)
        qkv_split = torch.split(qkv_out,
                                (qkv_out.shape[-1] // (2**merge_count)),
                                dim=split_dim)
        qkv_split = [torch.split(s, (s.shape[-1] // 3), dim=split_dim) for s in qkv_split]
        (mixed_query, key_layer, value_layer) = [
            torch.cat([s[i] for s in qkv_split], axis=-1)
            for i in range(len(qkv_split[0]))
        ]

        if config.rotary_dim > 0:
            mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
                mixed_query,
                key_layer,
                config.rotary_dim,
                0 if layer_past is None else layer_past[0].shape[-2],
                num_attention_heads_per_partition,
                config.rotate_half,
                config.rotate_every_two)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)

        presents = (key_layer, value_layer)
        mixed_query = _transpose_for_scores(mixed_query, False, True)
        key_layer = _transpose_for_scores(key_layer, True, True) / (
            norm_factor if config.scale_attention else 1.0)
        value_layer = _transpose_for_scores(value_layer, False, True)

        if layer_past is None:
            attn_key_value = score_context_func(
                mixed_query,
                key_layer,
                torch.empty(1),
                ((1 - input_mask).half() * minus_inf)
                if input_mask.dtype == torch.int64 else input_mask,
                value_layer,
                torch.empty(1),
                num_attention_heads_per_partition,
                (1 / norm_factor if config.scale_attention else 1.0),
                (not unfused_mode),  # noqa: F821
                config.triangular_masking,
                config.local_attention,
                config.window_size,
                no_masking)
        else:
            attn_key_value = score_context_func(
                mixed_query,
                (key_layer if unfused_mode else past_key.type_as(key_layer)),  # noqa: F821
                key_layer,
                ((1 - input_mask).half() * minus_inf)
                if input_mask.dtype == torch.int64 else input_mask,
                (value_layer if unfused_mode else past_value.type_as(value_layer)),  # noqa: F821
                value_layer,
                num_attention_heads_per_partition,
                (1 / norm_factor if config.scale_attention else 1.0),
                (not unfused_mode),  # noqa: F821
                config.triangular_masking,
                config.local_attention,
                config.window_size,
                no_masking)

        if unfused_mode:  # noqa: F821
            context_layer, _, _ = attn_key_value
        else:
            context_layer, key_layer, value_layer = attn_key_value

        # Transpose Context
        context_layer = _transpose_for_context(context_layer)

        return context_layer, presents[0], presents[1]  # attn_output, key_layer, value_layer
    else:
        # Note: This modification is added for the BLOOM-176B model and will be removed later!
        if config.bigscience_bloom:
            context_layer, presents = backup_attention(qkv_out,
                                                       layer_past,
                                                       alibi,
                                                       input_mask,
                                                       norm_factor)
            return context_layer, presents[0], presents[1]  # key_layer, value_layer
        else:
            if alibi is not None:
                batch_heads = qkv_out.shape[0] * num_attention_heads_per_partition
                offset = dist.get_rank() * batch_heads if dist.is_initialized() else 0
                sliced_alibi = alibi[offset:batch_heads + offset, :, :]

            attn_key_value = score_context_func(
                qkv_out,
                ((1 - input_mask).to(qkv_out.dtype) * minus_inf)
                if input_mask.dtype == torch.int64 else input_mask,
                config.rotary_dim,
                config.rotate_half,
                config.rotate_every_two,
                num_attention_heads_per_partition,
                (1 / norm_factor if config.scale_attention else 1.0),
                config.triangular_masking,
                config.local_attention,
                config.window_size,
                no_masking,
                config.layer_id,
                DeepSpeedTransformerInference.layer_id,
                sliced_alibi if alibi is not None else torch.empty(1))
            context_layer, key_layer, value_layer = attn_key_value
            return context_layer, key_layer, value_layer

def test(self):
    assert dist.is_initialized()
    assert dist.get_world_size() == 3
    assert dist.get_rank() < 3