def replace_fn(child):
    transformer_config = deepspeed.DeepSpeedTransformerConfig(
        batch_size=micro_batch_size,
        hidden_size=bert_config.hidden_size,
        heads=bert_config.num_attention_heads,
        attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
        hidden_dropout_ratio=bert_config.hidden_dropout_prob,
        num_hidden_layers=bert_config.num_hidden_layers,
        initializer_range=bert_config.initializer_range,
        layer_norm_eps=bert_config.layer_norm_eps,
        seed=seed,
        fp16=fp16,
        pre_layer_norm=preln,
        huggingface=huggingface,
        local_rank=local_rank,
        training=training)
    new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

    # copy relevant state from child -> new module
    qw = child.attention.self.query.weight
    qb = child.attention.self.query.bias
    kw = child.attention.self.key.weight
    kb = child.attention.self.key.bias
    vw = child.attention.self.value.weight
    vb = child.attention.self.value.bias

    qkvw = torch.cat((qw, kw, vw), 0)
    qkvb = torch.cat((qb, kb, vb), 0)

    #qw.data, kw.data, vw.data = torch.chunk(qkvw, 3, axis=0)
    #qb.data, kb.data, vb.data = torch.chunk(qkvb, 3, axis=0)

    new_module.attn_qkvw.data = qkvw
    new_module.attn_qkvb.data = qkvb
    new_module.attn_ow.data = child.attention.output.dense.weight
    new_module.attn_ob.data = child.attention.output.dense.bias

    if preln:
        attention_layernorm = child.PostAttentionLayerNorm
    else:
        attention_layernorm = child.attention.output.LayerNorm
    new_module.attn_nw.data = attention_layernorm.weight
    new_module.attn_nb.data = attention_layernorm.bias

    if preln:
        intermediate_ff = child.intermediate.dense_act
    else:
        intermediate_ff = child.intermediate.dense
    new_module.inter_w.data = intermediate_ff.weight
    new_module.inter_b.data = intermediate_ff.bias

    new_module.output_w.data = child.output.dense.weight
    new_module.output_b.data = child.output.dense.bias

    if preln:
        transformer_layernorm = child.PreAttentionLayerNorm
    else:
        transformer_layernorm = child.output.LayerNorm
    new_module.norm_w.data = transformer_layernorm.weight
    new_module.norm_b.data = transformer_layernorm.bias

    return new_module
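# Tiny shape check, added for illustration (assumes only torch): three [H, H]
# nn.Linear weight matrices concatenated along dim 0 give a single [3H, H] matrix,
# which is the fused layout assigned to attn_qkvw above.
import torch

_H = 4
_qw, _kw, _vw = (torch.randn(_H, _H) for _ in range(3))
assert torch.cat((_qw, _kw, _vw), 0).shape == (3 * _H, _H)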
def replace_fn(child, layer_id=0):
    if training:
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=encoder_decoder,
            local_rank=local_rank,
            training=training)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # copy relevant state from child -> new module
        replace_with_policy(new_module, child, policy, preln=preln)
    else:
        transformer_config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=num_attention_heads,
            fp16=fp16,
            pre_layer_norm=preln,
            mp_size=mp_size,
            q_int8=quantize,
            encoder_decoder=encoder_decoder)

        if quantize and quantize_settings is not None:
            (quantization_scales,
             merge_count,
             mlp_extra_grouping,
             quantize_groups) = quantize_settings
            new_module = transformer_inference.DeepSpeedTransformerInference(
                transformer_config,
                mp_group=mp_group,
                quantize_scales=quantization_scales[layer_id],
                quantize_groups=quantize_groups,
                merge_count=merge_count,
                mlp_extra_grouping=mlp_extra_grouping)
        else:
            new_module = transformer_inference.DeepSpeedTransformerInference(
                transformer_config,
                mp_group=mp_group)

        # copy relevant state from child -> new module
        replace_with_policy(new_module, child, policy, inference=True, preln=preln)

    return new_module
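# A minimal sketch, added for context, of how a replace_fn like the ones above is
# normally driven: recursively walk the model and swap every module of the target
# layer class for the module replace_fn returns. The helper name _recursive_replace
# and the orig_layer_cls argument are hypothetical, not part of the listings.
def _recursive_replace(model, orig_layer_cls, replace_fn):
    for name, child in model.named_children():
        if isinstance(child, orig_layer_cls):
            # swap the child in place so the parent now calls the fused layer
            setattr(model, name, replace_fn(child))
        else:
            _recursive_replace(child, orig_layer_cls, replace_fn)
    return model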
def replace_with_policy(child, policy_cls, inference=False, preln=True, layer_id=0):
    preln = False if policy_cls is HFBertLayerPolicy else preln
    if policy_cls is HFBertLayerPolicy:
        policy = policy_cls(child, inference=inference, preln=preln)
    else:
        policy = policy_cls(child, inference=inference)

    if inference:
        hidden_size, num_attention_heads = policy.get_hidden_heads()
        assert num_attention_heads % mp_size == 0, \
            "To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size!" + \
            "This is because the attention computation is partitioned evenly among the parallel GPUs."

    # pull the attention, MLP and layer-norm parameters out of the original module
    attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention()
    mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
    attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()

    if quantize:
        if policy_cls is not HFBertLayerPolicy:
            qkvw = qkvw.to(torch.int8)
        dense_w = dense_w.to(torch.int8)
        _h4h_w = _h4h_w.to(torch.int8)
        _4hh_w = _4hh_w.to(torch.int8)
    elif fp16:
        qkvw = qkvw.half()
        dense_w = dense_w.half()
        _h4h_w = _h4h_w.half()
        _4hh_w = _4hh_w.half()
    if quantize or fp16:
        dense_b = dense_b.half()
        _h4h_b = _h4h_b.half()
        _4hh_b = _4hh_b.half()
        attn_nw = attn_nw.half()
        attn_nb = attn_nb.half()
        input_nw = input_nw.half()
        input_nb = input_nb.half()

    mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)

    if inference:
        transformer_config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=num_attention_heads,
            layer_norm_eps=config.layer_norm_eps if hasattr(
                config, 'layer_norm_eps') else 1e-12,
            fp16=fp16,
            pre_layer_norm=preln,
            mp_size=mp_size,
            q_int8=quantize,
            encoder_decoder=(True if policy_cls is HFBertLayerPolicy else False),
            triangular_masking=(policy_cls is not HFBertLayerPolicy),
            local_attention=((config.attention_layers[layer_id] == "local")
                             if hasattr(config, 'attention_layers') else False),
            window_size=(config.window_size if hasattr(config, 'window_size') else 1))

        if quantize and quantize_settings is not None:
            (quantization_scales,
             merge_count,
             mlp_extra_grouping,
             quantize_groups) = quantize_settings
            new_module = transformer_inference.DeepSpeedTransformerInference(
                transformer_config,
                mp_group=mp_group,
                quantize_scales=quantization_scales[layer_id],
                quantize_groups=quantize_groups,
                merge_count=merge_count,
                mlp_extra_grouping=mlp_extra_grouping,
                qkv_merging=(policy_cls is HFBertLayerPolicy))

            if quantize and qkvw.dtype != torch.int8:
                quantize_bits = 8
                quantizer = WeightQuantization()
                if policy_cls is HFBertLayerPolicy:
                    data_quantized, _ = quantizer.quantize_data(
                        qkvw, quantize_bits, quantize_groups * 3)
                else:
                    data_quantized, _ = quantizer.quantize_data(
                        qkvw, quantize_bits, quantize_groups)
                qkvw.copy_(data_quantized)
                qkvw = qkvw.to(torch.int8)
        else:
            new_module = transformer_inference.DeepSpeedTransformerInference(
                transformer_config,
                mp_group=mp_group)

        new_module.config.scale_attention = scale_attention

        # the linear layer weights come in [output, input] shape, but the inference
        # kernels want them in [input, output] shape: transpose them here, once,
        # to reduce inference cost!
        def transpose(data):
            data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            return data

        if attn_linear_layer:
            qkvw = transpose(qkvw.data)
            dense_w = transpose(dense_w)
        if mlp_linear_layer:
            _h4h_w = transpose(_h4h_w)
            _4hh_w = transpose(_4hh_w)

        attn_block = new_module.attention
        attn_block.attn_qkvw.data = mp_replace.qkv_copy(attn_block.attn_qkvw.data, qkvw)

        if qkvb is not None:
            if fp16:
                qkvb = qkvb.half()
            attn_block.attn_qkvb.data = mp_replace.qkv_copy(attn_block.attn_qkvb.data, qkvb)
        else:
            attn_block.attn_qkvb = qkvb

        attn_block.attn_ow.data = mp_replace.copy(attn_block.attn_ow.data, dense_w)
        attn_block.attn_ob.data = mp_replace.copy(attn_block.attn_ob.data, dense_b)

        mpl_block = new_module.mlp
        mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w.data, _h4h_w)
        mpl_block.inter_b.data = mp_replace.copy(mpl_block.inter_b.data, _h4h_b)
        mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w.data, _4hh_w)
        mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b.data, _4hh_b)

        new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device())
        new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device())
        new_module.norm_w.data = input_nw.to(torch.cuda.current_device())
        new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
    else:
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            hidden_size=config.hidden_size,
            heads=config.num_attention_heads,
            attn_dropout_ratio=config.attention_probs_dropout_prob,
            hidden_dropout_ratio=config.hidden_dropout_prob,
            num_hidden_layers=config.num_hidden_layers,
            initializer_range=config.initializer_range,
            layer_norm_eps=config.layer_norm_eps if hasattr(
                config, 'layer_norm_eps') else 1e-12,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),
            huggingface=encoder_decoder,
            local_rank=local_rank,
            stochastic_mode=stochastic_mode,
            normalize_invertible=True,
            training=training)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = dense_w
        new_module.attn_ob.data = dense_b

        new_module.attn_nw.data = attn_nw
        new_module.attn_nb.data = attn_nb
        new_module.norm_w.data = input_nw
        new_module.norm_b.data = input_nb

        new_module.inter_w.data = _h4h_w
        new_module.inter_b.data = _h4h_b
        new_module.output_w.data = _4hh_w
        new_module.output_b.data = _4hh_b

    return new_module
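# Small self-contained check, added for illustration (assumes only torch), of the
# in-place transpose trick used by transpose() above: the tensor's own storage is
# overwritten with the transposed layout and then reinterpreted with swapped
# dimensions, so the result re-uses the parameter's existing storage (only a
# temporary copy is made by .contiguous()).
import torch

def _transpose_inplace(data):
    data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1))
    return data.reshape(data.shape[-1], data.shape[-2])

_w = torch.arange(6, dtype=torch.float32).reshape(2, 3)
assert torch.equal(_transpose_inplace(_w.clone()), _w.t())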
def replace_with_policy(child,
                        policy_cls,
                        triangular_masking,
                        inference=False,
                        preln=True,
                        layer_id=0):
    preln = False if policy_cls is HFBertLayerPolicy else preln
    if policy_cls is HFBertLayerPolicy:
        policy = policy_cls(child, inference=inference, preln=preln)
    else:
        policy = policy_cls(child, inference=inference)

    if inference:
        hidden_size, num_attention_heads = policy.get_hidden_heads()
        assert num_attention_heads % mp_size == 0, \
            "To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size!" + \
            "This is because the attention computation is partitioned evenly among the parallel GPUs."

    from deepspeed.moe.layer import MoE
    moe = False
    if isinstance(child.mlp, MoE):
        num_experts = child.mlp.num_experts
        moe = True

    attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention()
    if not moe or moe_type == 'standard':
        mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
    else:
        mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
            _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

    attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()

    if quantize:
        if policy_cls is not HFBertLayerPolicy:
            qkvw = qkvw.to(torch.int8)
        dense_w = dense_w.to(torch.int8)
        _h4h_w = [moe_w1.to(torch.int8)
                  for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
        _4hh_w = [moe_w1.to(torch.int8)
                  for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
    elif fp16:
        qkvw = qkvw.half()
        dense_w = dense_w.half()
        _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
        _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
    if quantize or fp16:
        qkvb = qkvb if qkvb is None else qkvb.half()
        dense_b = dense_b if dense_b is None else dense_b.half()
        _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
        _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
        attn_nw = attn_nw if attn_nw is None else attn_nw.half()
        attn_nb = attn_nb if attn_nb is None else attn_nb.half()
        input_nw = input_nw.half()
        input_nb = input_nb.half()

    if moe and moe_type == 'residual' and fp16:
        _res_h4h_b = _res_h4h_b.half()
        _res_4hh_b = _res_4hh_b.half()
        _res_h4h_w = _res_h4h_w.half()
        _res_4hh_w = _res_4hh_w.half()
        _res_coef = _res_coef.half()

    mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
    #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

    if inference:
        if moe:
            ep_world_size = torch.distributed.get_world_size()
            local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

            transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                hidden_size=hidden_size,
                heads=num_attention_heads,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config, 'layer_norm_eps') else 1e-12,
                fp16=fp16,
                pre_layer_norm=preln,
                mp_size=mp_size,
                q_int8=quantize,
                moe_experts=local_ep_size,
                global_experts=num_experts,
                mlp_type=moe_type)
        else:
            transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                hidden_size=hidden_size,
                heads=num_attention_heads,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config, 'layer_norm_eps') else
                (config.layer_norm_epsilon if hasattr(
                    config, 'layer_norm_epsilon') else 1e-12),
                fp16=fp16,
                pre_layer_norm=preln,
                mp_size=mp_size,
                q_int8=quantize,
                return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                triangular_masking=(policy_cls is not HFBertLayerPolicy),
                local_attention=((config.attention_layers[layer_id] == "local")
                                 if hasattr(config, 'attention_layers') else False),
                window_size=(config.window_size if hasattr(config, 'window_size') else 1),
                rotary_dim=(config.rotary_dim if hasattr(config, 'rotary_dim') else -1),
                mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy))

        if quantize and quantize_settings is not None:
            (quantization_scales,
             merge_count,
             mlp_extra_grouping,
             quantize_groups) = quantize_settings
            if moe:
                new_module = transformer_inference.DeepSpeedMoEInference(
                    transformer_config,
                    mp_group=mp_group,
                    ep_group=None if ep_group is None else ep_group[num_experts],
                    expert_mp_group=None
                    if expert_mp_group is None else expert_mp_group[num_experts],
                    quantize_scales=quantization_scales[layer_id],
                    quantize_groups=quantize_groups,
                    merge_count=merge_count,
                    mlp_extra_grouping=mlp_extra_grouping,
                    qkv_merging=(policy_cls is HFBertLayerPolicy))
            else:
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group,
                    quantize_scales=quantization_scales[layer_id],
                    quantize_groups=quantize_groups,
                    merge_count=merge_count,
                    mlp_extra_grouping=mlp_extra_grouping,
                    qkv_merging=(policy_cls is HFBertLayerPolicy))

            if quantize and qkvw.dtype != torch.int8:
                quantize_bits = 8
                quantizer = WeightQuantization()
                if policy_cls is HFBertLayerPolicy:
                    data_quantized, _ = quantizer.quantize_data(
                        qkvw.data, quantize_bits, quantize_groups * 3)
                else:
                    data_quantized, _ = quantizer.quantize_data(
                        qkvw.data, quantize_bits, quantize_groups)
                qkvw.data.copy_(data_quantized)
                qkvw.data = qkvw.data.to(torch.int8)
        else:
            if moe:
                new_module = transformer_inference.DeepSpeedMoEInference(
                    transformer_config,
                    mp_group=mp_group,
                    ep_group=None if ep_group is None else ep_group[num_experts],
                    expert_mp_group=None
                    if expert_mp_group is None else expert_mp_group[num_experts])
            else:
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group)

        new_module.config.scale_attention = scale_attention

        # the linear layer weights come in [output, input] shape, but the inference
        # kernels want them in [input, output] shape: transpose them here, once,
        # to reduce inference cost!
        def transpose(data):
            data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            return data

        if attn_linear_layer:
            qkvw.data = transpose(qkvw.data)
            dense_w.data = transpose(dense_w.data)

        if mlp_linear_layer:
            _h4h_w = [transpose(moe_w1.data)
                      for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
            _4hh_w = [transpose(moe_w1.data)
                      for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

        if moe and moe_type == 'residual':
            _res_h4h_w.data = transpose(_res_h4h_w.data)
            _res_4hh_w.data = transpose(_res_4hh_w.data)
            _res_coef.data = transpose(_res_coef.data)

        attn_block = new_module.attention
        attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
        attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
        attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
        attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

        mpl_block = new_module.mlp
        if moe:
            gpu_index = torch.distributed.get_rank()
            gpu_index = 0
            # each rank copies its local slice of the expert weights onto the GPU
            for ep_index in range(local_ep_size):
                mpl_block[ep_index].inter_w.data = _h4h_w[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                mpl_block[ep_index].inter_b.data = _h4h_b[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                mpl_block[ep_index].output_w.data = _4hh_w[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                mpl_block[ep_index].output_b.data = _4hh_b[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
            new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
            new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
            if moe_type == 'residual':
                new_module.res_mlp.inter_w.data = _res_h4h_w.to(torch.cuda.current_device())
                new_module.res_mlp.inter_b.data = _res_h4h_b.to(torch.cuda.current_device())
                new_module.res_mlp.output_w.data = _res_4hh_w.to(torch.cuda.current_device())
                new_module.res_mlp.output_b.data = _res_4hh_b.to(torch.cuda.current_device())
                new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
        else:
            mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w, _h4h_w)
            mpl_block.inter_b.data = mp_replace.copy(mpl_block.inter_b, _h4h_b)
            mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w, _4hh_w)
            mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b, _4hh_b)

            if attn_nw is None:
                new_module.mlp.attn_nw = attn_nw
            else:
                new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device())
            if attn_nb is None:
                new_module.mlp.attn_nb = attn_nb
            else:
                new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device())

        new_module.norm_w.data = input_nw.to(torch.cuda.current_device())
        new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
    else:
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            hidden_size=config.hidden_size,
            heads=config.num_attention_heads,
            attn_dropout_ratio=config.attention_probs_dropout_prob,
            hidden_dropout_ratio=config.hidden_dropout_prob,
            num_hidden_layers=config.num_hidden_layers,
            initializer_range=config.initializer_range,
            layer_norm_eps=config.layer_norm_eps if hasattr(
                config, 'layer_norm_eps') else 1e-12,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),
            return_tuple=return_tuple,
            local_rank=local_rank,
            stochastic_mode=stochastic_mode,
            normalize_invertible=True,
            training=training)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = dense_w
        new_module.attn_ob.data = dense_b

        new_module.attn_nw.data = attn_nw
        new_module.attn_nb.data = attn_nb
        new_module.norm_w.data = input_nw
        new_module.norm_b.data = input_nb

        new_module.inter_w.data = _h4h_w
        new_module.inter_b.data = _h4h_b
        new_module.output_w.data = _4hh_w
        new_module.output_b.data = _4hh_b

    return new_module
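# Illustration only (standalone snippet, no torch.distributed required) of the
# expert-partition index arithmetic used in the MoE branch above: each rank hosts
# local_ep_size experts and selects them as gpu_index * local_ep_size + ep_index.
# Note that the listing above immediately resets gpu_index to 0, i.e. it indexes
# from the start of the per-rank expert list it was handed.
_num_experts, _ep_world_size = 8, 4
_local_ep_size = 1 if _num_experts < _ep_world_size else _num_experts // _ep_world_size
for _gpu_index in range(_ep_world_size):
    _owned = [_gpu_index * _local_ep_size + _ep for _ep in range(_local_ep_size)]
    print(f"rank {_gpu_index} -> experts {_owned}")  # rank 0 -> [0, 1], rank 1 -> [2, 3], ...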
def replace_with_policy(child,
                        policy_cls,
                        triangular_masking,
                        inference=False,
                        layer_id=0):
    policy = policy_cls(child, inference=inference)

    if inference:
        hidden_size, num_attention_heads = policy.get_hidden_heads()
        assert num_attention_heads % mp_size == 0, \
            "To run the model parallel across the GPUs, the attention_heads require to be divisible by the world_size!" + \
            "This is because the attention computation is partitioned evenly among the parallel GPUs."

    from deepspeed.moe.layer import MoE
    moe = False
    if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
        num_experts = child.mlp.num_experts
        moe = True

    attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
    if not moe or moe_type == 'standard':
        mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
    else:
        mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
            _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

    attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()

    if quantize:
        if policy_cls is not HFBertLayerPolicy:
            qkvw = qkvw.to(torch.int8)
        dense_w = dense_w.to(torch.int8)
        _h4h_w = [moe_w1.to(torch.int8)
                  for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
        _4hh_w = [moe_w1.to(torch.int8)
                  for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
    elif fp16:
        qkvw = qkvw.half()
        dense_w = dense_w.half()
        _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
        _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
    if quantize or fp16:
        qkvb = qkvb if qkvb is None else qkvb.half()
        dense_b = dense_b if dense_b is None else dense_b.half()
        _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
        _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
        attn_nw = attn_nw if attn_nw is None else attn_nw.half()
        attn_nb = attn_nb if attn_nb is None else attn_nb.half()
        input_nw = input_nw.half()
        input_nb = input_nb.half()

    if moe and moe_type == 'residual' and fp16:
        _res_h4h_b = _res_h4h_b.half()
        _res_4hh_b = _res_4hh_b.half()
        _res_h4h_w = _res_h4h_w.half()
        _res_4hh_w = _res_4hh_w.half()
        _res_coef = _res_coef.half()

    #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

    if inference:
        if moe:
            ep_world_size = dist.get_world_size()
            local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

            transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                hidden_size=hidden_size,
                heads=num_attention_heads,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config, 'layer_norm_eps') else 1e-12,
                fp16=fp16,
                pre_layer_norm=policy.pre_attn_norm,
                mp_size=mp_size,
                q_int8=quantize,
                moe_experts=local_ep_size,
                global_experts=num_experts,
                mlp_type=moe_type)
        else:
            rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') \
                else child.attention.rotary_ndims if hasattr(child, 'attention') \
                and hasattr(child.attention, 'rotary_ndims') else -1
            bigscience_bloom = policy_cls is BLOOMLayerPolicy
            transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                hidden_size=hidden_size,
                heads=num_attention_heads,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config, 'layer_norm_eps') else
                (config.layer_norm_epsilon if hasattr(config, 'layer_norm_epsilon')
                 else config.layernorm_epsilon if hasattr(config, 'layernorm_epsilon')
                 else 1.0e-12),
                fp16=fp16,
                pre_layer_norm=policy.pre_attn_norm,
                mp_size=mp_size,
                q_int8=quantize,
                return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                triangular_masking=(policy_cls is not HFBertLayerPolicy),
                local_attention=((config.attention_layers[layer_id] == "local")
                                 if hasattr(config, 'attention_layers') else False),
                window_size=(config.window_size if hasattr(config, 'window_size') else 1),
                rotary_dim=rotary_dim,
                mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
                mlp_act_func_type=policy.mlp_act_func_type,
                training_mp_size=training_mp_size,
                bigscience_bloom=bigscience_bloom)

        if quantize and quantize_settings is not None:
            (quantization_scales,
             merge_count,
             mlp_extra_grouping,
             quantize_groups) = quantize_settings
            if moe:
                new_module = transformer_inference.DeepSpeedMoEInference(
                    transformer_config,
                    mp_group=mp_group,
                    ep_group=None if ep_group is None else ep_group[num_experts],
                    expert_mp_group=None
                    if expert_mp_group is None else expert_mp_group[num_experts],
                    quantize_scales=quantization_scales[layer_id],
                    quantize_groups=quantize_groups,
                    merge_count=merge_count,
                    mlp_extra_grouping=mlp_extra_grouping,
                    qkv_merging=(policy_cls is HFBertLayerPolicy))
            else:
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group,
                    quantize_scales=quantization_scales[layer_id],
                    quantize_groups=quantize_groups,
                    merge_count=merge_count,
                    mlp_extra_grouping=mlp_extra_grouping,
                    qkv_merging=(policy_cls is HFBertLayerPolicy))

            if quantize and qkvw.dtype != torch.int8:
                quantize_bits = 8
                quantizer = WeightQuantization()
                if policy_cls is HFBertLayerPolicy:
                    data_quantized, _ = quantizer.quantize_data(
                        qkvw.data, quantize_bits, quantize_groups * 3)
                else:
                    data_quantized, _ = quantizer.quantize_data(
                        qkvw.data, quantize_bits, quantize_groups)
                qkvw.data.copy_(data_quantized)
                qkvw.data = qkvw.data.to(torch.int8)
        else:
            if moe:
                new_module = transformer_inference.DeepSpeedMoEInference(
                    transformer_config,
                    mp_group=mp_group,
                    ep_group=None if ep_group is None else ep_group[num_experts],
                    expert_mp_group=None
                    if expert_mp_group is None else expert_mp_group[num_experts])
            else:
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group)

        new_module.config.scale_attention = scale_attention

        # the linear layer weights come in [output, input] shape, but the inference
        # kernels want them in [input, output] shape: transpose them here, once,
        # to reduce inference cost!
        def transpose(data):
            # temp move to cpu to avoid requiring extra GPU memory during the reshape
            data = data.to('cpu')
            data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            data = data.to(torch.cuda.current_device())
            return data

        attn_block = new_module.attention
        mpl_block = new_module.mlp

        if attn_linear_layer:
            if qkvw.numel() == 0 or qkvw.is_meta:
                if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                    pass
                else:
                    with GatheredParameters([qkvw, dense_w, qkvb, dense_b],
                                            modifier_rank=0):
                        qkvw = transpose(qkvw.data)
                        dense_w = transpose(dense_w.data)
                        qkvb = qkvb.data
                        dense_b = dense_b.data
            else:
                qkvw.data = transpose(qkvw.data)
                dense_w.data = transpose(dense_w.data)

        def _transpose(x):
            num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size
            attention_head_size = x.shape[-1] // num_attention_heads_per_partition
            new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                           attention_head_size)
            x_1 = x.view(*new_x_shape)
            (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
            if len(q.shape) > 2:
                return torch.cat((q.reshape(q.shape[0], -1),
                                  k.reshape(q.shape[0], -1),
                                  v.reshape(q.shape[0], -1)),
                                 dim=-1).reshape(x.shape)
            else:
                return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)),
                                 dim=-1).reshape(x.shape)

        if megatron_v2:
            new_module.config.rotate_half = True
            new_module.config.rotate_every_two = False

            # Note: this part needs to be added for BLOOM architecture
            qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous())
            qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous())

        # NOTE: This part caused instability in the multi-GPU inference!
        # TODO: This needs to be incorporated in the kernels.
        #dense_b = dense_b if dense_b is None else dense_b * (
        #    transformer_config.training_mp_size / transformer_config.mp_size)
        #_4hh_b = _4hh_b * (transformer_config.training_mp_size /
        #                   transformer_config.mp_size)

        if mlp_linear_layer:
            if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta):
                if _4hh_w.is_meta or _4hh_w.ds_tensor.numel() < mpl_block.inter_w.numel():
                    pass
                else:
                    with GatheredParameters([_h4h_w, _4hh_w, _4hh_b, _h4h_b],
                                            modifier_rank=0):
                        _h4h_w = transpose(_h4h_w.data)
                        _4hh_w = transpose(_4hh_w.data)
                        _h4h_b = _h4h_b.data
                        _4hh_b = _4hh_b.data
            else:
                _h4h_w = [transpose(moe_w1.data)
                          for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
                _4hh_w = [transpose(moe_w1.data)
                          for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

        if moe and moe_type == 'residual':
            _res_h4h_w.data = transpose(_res_h4h_w.data)
            _res_4hh_w.data = transpose(_res_4hh_w.data)
            _res_coef.data = transpose(_res_coef.data)

        if qkvw.is_meta or qkvw.numel() == 0:
            if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                pass
            else:
                with GatheredParameters([attn_block.attn_qkvw,
                                         attn_block.attn_qkvb,
                                         attn_block.attn_ow,
                                         attn_block.attn_ob],
                                        modifier_rank=0):
                    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
                    attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                    attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
        else:
            if bigscience_bloom:
                attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
            else:
                attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
                attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)

            attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
            attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

        if moe:
            gpu_index = dist.get_rank()
            gpu_index = 0
            for ep_index in range(local_ep_size):
                mpl_block[ep_index].inter_w.data = _h4h_w[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                mpl_block[ep_index].inter_b.data = _h4h_b[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                mpl_block[ep_index].output_w.data = _4hh_w[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
                mpl_block[ep_index].output_b.data = _4hh_b[
                    gpu_index * local_ep_size + ep_index].to(torch.cuda.current_device())
            new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
            new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
            if moe_type == 'residual':
                new_module.res_mlp.inter_w.data = _res_h4h_w.to(torch.cuda.current_device())
                new_module.res_mlp.inter_b.data = _res_h4h_b.to(torch.cuda.current_device())
                new_module.res_mlp.output_w.data = _res_4hh_w.to(torch.cuda.current_device())
                new_module.res_mlp.output_b.data = _res_4hh_b.to(torch.cuda.current_device())
                new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
        else:
            if _4hh_w.numel() == 0 or _4hh_w.is_meta:
                if _4hh_w.is_meta or _4hh_w.ds_tensor.numel() < mpl_block.inter_w.numel():
                    pass
                else:
                    with GatheredParameters([_h4h_w, _4hh_w, _h4h_b, _4hh_b],
                                            modifier_rank=0):
                        mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                        mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                        mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                        mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)
            else:
                mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)

            if attn_nw is None:
                new_module.mlp.attn_nw = attn_nw
                new_module.mlp.attn_nb = attn_nb
            else:
                if attn_nw.is_meta or attn_nw.numel() == 0:
                    if attn_nw.is_meta or attn_nw.ds_tensor.numel() < new_module.mlp.attn_nw.numel():
                        pass
                    else:
                        with GatheredParameters([attn_nw, attn_nb], modifier_rank=0):
                            new_module.mlp.attn_nw.data.copy_(
                                attn_nw.to(torch.cuda.current_device()))
                            new_module.mlp.attn_nb.data.copy_(
                                attn_nb.to(torch.cuda.current_device()))
                else:
                    new_module.mlp.attn_nw.data.copy_(
                        attn_nw.to(torch.cuda.current_device()))
                    new_module.mlp.attn_nb.data.copy_(
                        attn_nb.to(torch.cuda.current_device()))

        if input_nw.is_meta or input_nw.numel() == 0:
            if input_nw.is_meta or input_nw.ds_tensor.numel() < new_module.norm_w.numel():
                pass
            else:
                with GatheredParameters([input_nw, input_nb], modifier_rank=0):
                    new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device()))
                    new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device()))
        else:
            new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device()))
            new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device()))
    else:
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size if micro_batch_size > 0 else 1,
            hidden_size=config.hidden_size,
            heads=config.num_attention_heads,
            attn_dropout_ratio=config.attention_probs_dropout_prob,
            hidden_dropout_ratio=config.hidden_dropout_prob,
            num_hidden_layers=config.num_hidden_layers,
            initializer_range=config.initializer_range,
            layer_norm_eps=config.layer_norm_eps if hasattr(
                config, 'layer_norm_eps') else 1e-12,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=policy.pre_attn_norm,
            return_tuple=return_tuple,
            local_rank=local_rank,
            stochastic_mode=stochastic_mode,
            normalize_invertible=True,
            training=training)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = dense_w
        new_module.attn_ob.data = dense_b

        new_module.attn_nw.data = attn_nw
        new_module.attn_nb.data = attn_nb
        new_module.norm_w.data = input_nw
        new_module.norm_b.data = input_nb

        new_module.inter_w.data = _h4h_w
        new_module.inter_b.data = _h4h_b
        new_module.output_w.data = _4hh_w
        new_module.output_b.data = _4hh_b

    return new_module
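# For context, a hedged usage sketch (not part of the listings above): in this era of
# DeepSpeed, the policy-based replacement is normally triggered through
# deepspeed.init_inference rather than by calling replace_with_policy directly.
# Keyword arguments vary between releases, so treat the exact call as illustrative.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('gpt2')
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.half,
                                  replace_with_kernel_inject=True)
# engine.module now carries DeepSpeedTransformerInference layers in place of the
# original GPT-2 blocks.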