Example #1
    def replace_fn(child):
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            layer_norm_eps=bert_config.layer_norm_eps,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=huggingface,
            local_rank=local_rank,
            training=training)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # copy relevant state from child -> new module
        qw = child.attention.self.query.weight
        qb = child.attention.self.query.bias
        kw = child.attention.self.key.weight
        kb = child.attention.self.key.bias
        vw = child.attention.self.value.weight
        vb = child.attention.self.value.bias

        qkvw = torch.cat((qw, kw, vw), 0)
        qkvb = torch.cat((qb, kb, vb), 0)
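        # qkvw/qkvb now hold Q, K and V fused along dim 0 ([3*hidden, hidden] and
        # [3*hidden]), the packed layout DeepSpeedTransformerLayer expects for attn_qkvw/attn_qkvb.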

        #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, axis=0)
        #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, axis=0)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = child.attention.output.dense.weight
        new_module.attn_ob.data = child.attention.output.dense.bias
        if preln:
            attention_layernorm = child.PostAttentionLayerNorm
        else:
            attention_layernorm = child.attention.output.LayerNorm
        new_module.attn_nw.data = attention_layernorm.weight
        new_module.attn_nb.data = attention_layernorm.bias
        if preln:
            intermediate_ff = child.intermediate.dense_act
        else:
            intermediate_ff = child.intermediate.dense
        new_module.inter_w.data = intermediate_ff.weight
        new_module.inter_b.data = intermediate_ff.bias
        new_module.output_w.data = child.output.dense.weight
        new_module.output_b.data = child.output.dense.bias
        if preln:
            transformer_layernorm = child.PreAttentionLayerNorm
        else:
            transformer_layernorm = child.output.LayerNorm
        new_module.norm_w.data = transformer_layernorm.weight
        new_module.norm_b.data = transformer_layernorm.bias
        return new_module
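A replace_fn like the one above is meant to be applied to every BertLayer in the model. The walker below is only an illustrative sketch of that outer loop (the replace_layers helper and the layer_cls argument are assumptions for this write-up, not part of the snippet above):

    import torch.nn as nn

    def replace_layers(model: nn.Module, layer_cls, replace_fn):
        # Recursively swap every instance of layer_cls for the module built by replace_fn.
        for name, child in model.named_children():
            if isinstance(child, layer_cls):
                setattr(model, name, replace_fn(child))
            else:
                replace_layers(child, layer_cls, replace_fn)
        return model

    # e.g. replace_layers(model, BertLayer, replace_fn)  # hypothetical usage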
Example #2
    def replace_fn(child, layer_id=0):
        if training:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                hidden_size=bert_config.hidden_size,
                heads=bert_config.num_attention_heads,
                attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
                hidden_dropout_ratio=bert_config.hidden_dropout_prob,
                num_hidden_layers=bert_config.num_hidden_layers,
                initializer_range=bert_config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln,
                huggingface=encoder_decoder,
                local_rank=local_rank,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(
                transformer_config)

            # copy relevant state from child -> new module
            replace_with_policy(new_module, child, policy, preln=preln)

        else:
            transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                hidden_size=hidden_size,
                heads=num_attention_heads,
                fp16=fp16,
                pre_layer_norm=preln,
                mp_size=mp_size,
                q_int8=quantize,
                encoder_decoder=encoder_decoder)

            if quantize and quantize_settings is not None:
                (quantization_scales, merge_count, mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group,
                    quantize_scales=quantization_scales[layer_id],
                    quantize_groups=quantize_groups,
                    merge_count=merge_count,
                    mlp_extra_grouping=mlp_extra_grouping)
            else:
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group,
                )

            # copy relevant state from child -> new module
            replace_with_policy(new_module,
                                child,
                                policy,
                                inference=True,
                                preln=preln)

        return new_module
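Example #2 threads a layer_id through replace_fn so that, on the quantized inference path, quantization_scales[layer_id] picks up the right per-layer scales. One illustrative way to drive it (the model.bert.encoder.layer path is a hypothetical layout, not taken from the snippet):

    for layer_id, layer in enumerate(model.bert.encoder.layer):
        model.bert.encoder.layer[layer_id] = replace_fn(layer, layer_id=layer_id)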
Example #3
    def replace_with_policy(child,
                            policy_cls,
                            inference=False,
                            preln=True,
                            layer_id=0):
        preln = False if policy_cls is HFBertLayerPolicy else preln
        if policy_cls is HFBertLayerPolicy:
            policy = policy_cls(child, inference=inference, preln=preln)
        else:
            policy = policy_cls(child, inference=inference)

        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0,\
                "To run the model parallel across the GPUs, the number of attention heads must be divisible by the world size! " +\
                "This is because the attention computation is partitioned evenly among the parallel GPUs."

        attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention()
        mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
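        # Policy accessors hand back the original layer's raw parameters:
        # _h4h_w/_h4h_b are the MLP's hidden -> 4*hidden (intermediate) weight/bias,
        # _4hh_w/_4hh_b the 4*hidden -> hidden output projection.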
        attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()

        if quantize:
            if policy_cls is not HFBertLayerPolicy:
                qkvw = qkvw.to(torch.int8)
            dense_w = dense_w.to(torch.int8)
            _h4h_w = _h4h_w.to(torch.int8)
            _4hh_w = _4hh_w.to(torch.int8)
        elif fp16:
            qkvw = qkvw.half()
            dense_w = dense_w.half()
            _h4h_w = _h4h_w.half()
            _4hh_w = _4hh_w.half()

        if quantize or fp16:
            dense_b = dense_b.half()
            _h4h_b = _h4h_b.half()
            _4hh_b = _4hh_b.half()
            attn_nw = attn_nw.half()
            attn_nb = attn_nb.half()
            input_nw = input_nw.half()
            input_nb = input_nb.half()

        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
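        # mp_replace slices each copied tensor across the model-parallel group
        # so every rank only holds its shard of the attention and MLP weights.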

        if inference:
            transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                hidden_size=hidden_size,
                heads=num_attention_heads,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config, 'layer_norm_eps') else 1e-12,
                fp16=fp16,
                pre_layer_norm=preln,
                mp_size=mp_size,
                q_int8=quantize,
                encoder_decoder=(policy_cls is HFBertLayerPolicy),
                triangular_masking=(policy_cls is not HFBertLayerPolicy),
                local_attention=((config.attention_layers[layer_id]
                                  == "local") if hasattr(
                                      config, 'attention_layers') else False),
                window_size=(config.window_size if hasattr(
                    config, 'window_size') else 1))

            if quantize and quantize_settings is not None:
                (quantization_scales, merge_count, mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group,
                    quantize_scales=quantization_scales[layer_id],
                    quantize_groups=quantize_groups,
                    merge_count=merge_count,
                    mlp_extra_grouping=mlp_extra_grouping,
                    qkv_merging=(policy_cls is HFBertLayerPolicy))

                if quantize and qkvw.dtype != torch.int8:
                    quantize_bits = 8
                    quantizer = WeightQuantization()
                    if policy_cls is HFBertLayerPolicy:
                        data_quantized, _ = quantizer.quantize_data(
                            qkvw, quantize_bits, quantize_groups * 3)
                    else:
                        data_quantized, _ = quantizer.quantize_data(
                            qkvw, quantize_bits, quantize_groups)
                    qkvw.copy_(data_quantized)
                    qkvw = qkvw.to(torch.int8)
            else:
                new_module = transformer_inference.DeepSpeedTransformerInference(
                    transformer_config,
                    mp_group=mp_group,
                )
            new_module.config.scale_attention = scale_attention

            # we want the weights in [input, output] shape, but
            # nn.Linear stores them as [output, input];
            # transpose them here once to reduce inference cost!
            def transpose(data):
                data.view(-1).copy_(
                    data.transpose(-1, -2).contiguous().view(-1))
                data = data.reshape(data.shape[-1], data.shape[-2])
                return data

            if attn_linear_layer:
                qkvw = transpose(qkvw.data)
                dense_w = transpose(dense_w)

            if mlp_linear_layer:
                _h4h_w = transpose(_h4h_w)
                _4hh_w = transpose(_4hh_w)

            attn_block = new_module.attention
            attn_block.attn_qkvw.data = mp_replace.qkv_copy(
                attn_block.attn_qkvw.data, qkvw)

            if qkvb is not None:
                if fp16:
                    qkvb = qkvb.half()
                attn_block.attn_qkvb.data = mp_replace.qkv_copy(
                    attn_block.attn_qkvb.data, qkvb)
            else:
                attn_block.attn_qkvb = qkvb

            attn_block.attn_ow.data = mp_replace.copy(attn_block.attn_ow.data,
                                                      dense_w)
            attn_block.attn_ob.data = mp_replace.copy(attn_block.attn_ob.data,
                                                      dense_b)

            mpl_block = new_module.mlp
            mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w.data,
                                                     _h4h_w)
            mpl_block.inter_b.data = mp_replace.copy(mpl_block.inter_b.data,
                                                     _h4h_b)
            mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w.data,
                                                      _4hh_w)
            mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b.data,
                                                      _4hh_b)

            new_module.mlp.attn_nw.data = attn_nw.to(
                torch.cuda.current_device())
            new_module.mlp.attn_nb.data = attn_nb.to(
                torch.cuda.current_device())
            new_module.norm_w.data = input_nw.to(torch.cuda.current_device())
            new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
        else:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config, 'layer_norm_eps') else 1e-12,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=(False
                                if policy_cls is HFBertLayerPolicy else preln),
                huggingface=encoder_decoder,
                local_rank=local_rank,
                stochastic_mode=stochastic_mode,
                normalize_invertible=True,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(
                transformer_config)
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = dense_w
            new_module.attn_ob.data = dense_b

            new_module.attn_nw.data = attn_nw
            new_module.attn_nb.data = attn_nb
            new_module.norm_w.data = input_nw
            new_module.norm_b.data = input_nb

            new_module.inter_w.data = _h4h_w
            new_module.inter_b.data = _h4h_b
            new_module.output_w.data = _4hh_w
            new_module.output_b.data = _4hh_b
        return new_module
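The qkv_copy/copy calls above slice every tensor across the model-parallel group before writing it into the new module. A minimal sketch of the idea behind the QKV case, assuming the fused weight is stored as [3*hidden, hidden] and that rank/mp_size describe the tensor-parallel group (this is not the actual ReplaceWithTensorSlicing implementation):

    import torch

    def shard_qkv(qkvw: torch.Tensor, rank: int, mp_size: int) -> torch.Tensor:
        # Slice Q, K and V separately so that each rank keeps a contiguous
        # [3*hidden // mp_size, hidden] piece covering its own attention heads.
        q, k, v = torch.chunk(qkvw, 3, dim=0)
        keep = lambda t: torch.chunk(t, mp_size, dim=0)[rank]
        return torch.cat([keep(q), keep(k), keep(v)], dim=0)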
Example #4
    def replace_with_policy(child,
                            policy_cls,
                            triangular_masking,
                            inference=False,
                            preln=True,
                            layer_id=0):
        preln = False if policy_cls is HFBertLayerPolicy else preln
        if policy_cls is HFBertLayerPolicy:
            policy = policy_cls(child, inference=inference, preln=preln)
        else:
            policy = policy_cls(child, inference=inference)

        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0,\
                "To run the model parallel across the GPUs, the number of attention heads must be divisible by the world size! " +\
                "This is because the attention computation is partitioned evenly among the parallel GPUs."
        from deepspeed.moe.layer import MoE
        moe = False
        if isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention()
        if not moe or moe_type == 'standard':
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
        else:
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
                _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

        attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
        if quantize:
            if policy_cls is not HFBertLayerPolicy:
                qkvw = qkvw.to(torch.int8)
            dense_w = dense_w.to(torch.int8)
            _h4h_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
            _4hh_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
        elif fp16:
            qkvw = qkvw.half()
            dense_w = dense_w.half()
            _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
            _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
        if quantize or fp16:
            qkvb = qkvb if qkvb is None else qkvb.half()
            dense_b = dense_b if dense_b is None else dense_b.half()
            _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
            _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
            attn_nw = attn_nw if attn_nw is None else attn_nw.half()
            attn_nb = attn_nb if attn_nb is None else attn_nb.half()
            input_nw = input_nw.half()
            input_nb = input_nb.half()

        if moe and moe_type == 'residual' and fp16:
            _res_h4h_b = _res_h4h_b.half()
            _res_4hh_b = _res_4hh_b.half()
            _res_h4h_w = _res_h4h_w.half()
            _res_4hh_w = _res_4hh_w.half()
            _res_coef = _res_coef.half()

        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

        if inference:
            if moe:
                ep_world_size = torch.distributed.get_world_size()
                local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

                transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else 1e-12,
                    fp16=fp16,
                    pre_layer_norm=preln,
                    mp_size=mp_size,
                    q_int8=quantize,
                    moe_experts=local_ep_size,
                    global_experts=num_experts,
                    mlp_type=moe_type)
            else:
                transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else (config.layer_norm_epsilon if hasattr(
                            config,
                            'layer_norm_epsilon') else 1e-12),
                    fp16=fp16,
                    pre_layer_norm=preln,
                    mp_size=mp_size,
                    q_int8=quantize,
                    return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                    triangular_masking=(policy_cls is not HFBertLayerPolicy),
                    local_attention=((config.attention_layers[layer_id] == "local")
                                     if hasattr(config,
                                                'attention_layers') else False),
                    window_size=(config.window_size if hasattr(config,
                                                               'window_size') else 1),
                    rotary_dim=(config.rotary_dim if hasattr(config,
                                                             'rotary_dim') else -1),
                    mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy))

            if quantize and quantize_settings is not None:
                (quantization_scales,
                 merge_count,
                 mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                if quantize and qkvw.dtype != torch.int8:
                    quantize_bits = 8
                    quantizer = WeightQuantization()
                    if policy_cls is HFBertLayerPolicy:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
                    else:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
                    qkvw.data.copy_(data_quantized)
                    qkvw.data = qkvw.data.to(torch.int8)
            else:

                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                    )

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                    )
            new_module.config.scale_attention = scale_attention

            # we want the weights in [input, output] shape, but
            # nn.Linear stores them as [output, input];
            # transpose them here once to reduce inference cost!
            def transpose(data):
                data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1))
                data = data.reshape(data.shape[-1], data.shape[-2])
                return data

            if attn_linear_layer:
                qkvw.data = transpose(qkvw.data)
                dense_w.data = transpose(dense_w.data)

            if mlp_linear_layer:
                _h4h_w = [transpose(moe_w1.data)
                          for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
                _4hh_w = [transpose(moe_w1.data)
                          for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

            if moe and moe_type == 'residual':
                _res_h4h_w.data = transpose(_res_h4h_w.data)
                _res_4hh_w.data = transpose(_res_4hh_w.data)
                _res_coef.data = transpose(_res_coef.data)

            attn_block = new_module.attention
            attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
            attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)

            attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
            attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

            mpl_block = new_module.mlp
            if moe:
                gpu_index = 0
                for ep_index in range(local_ep_size):
                    mpl_block[ep_index].inter_w.data = _h4h_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].inter_b.data = _h4h_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_w.data = _4hh_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_b.data = _4hh_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
                if moe_type == 'residual':
                    new_module.res_mlp.inter_w.data = _res_h4h_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.inter_b.data = _res_h4h_b.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_w.data = _res_4hh_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_b.data = _res_4hh_b.to(
                        torch.cuda.current_device())
                    new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
            else:
                mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                mpl_block.inter_b.data = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w, _4hh_w)
                mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b, _4hh_b)
                if attn_nw is None:
                    new_module.mlp.attn_nw = attn_nw
                else:
                    new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                if attn_nb is None:
                    new_module.mlp.attn_nb = attn_nb
                else:
                    new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device())
            new_module.norm_w.data = input_nw.to(torch.cuda.current_device())
            new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
        else:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config,
                    'layer_norm_eps') else 1e-12,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=(False if policy_cls is HFBertLayerPolicy else preln),
                return_tuple=return_tuple,
                local_rank=local_rank,
                stochastic_mode=stochastic_mode,
                normalize_invertible=True,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = dense_w
            new_module.attn_ob.data = dense_b

            new_module.attn_nw.data = attn_nw
            new_module.attn_nb.data = attn_nb
            new_module.norm_w.data = input_nw
            new_module.norm_b.data = input_nb

            new_module.inter_w.data = _h4h_w
            new_module.inter_b.data = _h4h_b
            new_module.output_w.data = _4hh_w
            new_module.output_b.data = _4hh_b
        return new_module
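For the MoE path, each rank only materializes its local slice of the experts; the copy loop above indexes the global expert list as gpu_index * local_ep_size + ep_index. A small sketch of that bookkeeping (illustrative only; the helper name is an assumption):

    def local_expert_ids(num_experts: int, ep_world_size: int, rank: int):
        # Each rank hosts num_experts // ep_world_size experts (at least one),
        # taken as a contiguous block starting at rank * local_ep_size.
        local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size
        return [rank * local_ep_size + ep_index for ep_index in range(local_ep_size)]

    # e.g. local_expert_ids(num_experts=8, ep_world_size=4, rank=1) -> [2, 3]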
Example #5
    def replace_with_policy(child,
                            policy_cls,
                            triangular_masking,
                            inference=False,
                            layer_id=0):
        policy = policy_cls(child, inference=inference)

        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0,\
                "To run the model parallel across the GPUs, the number of attention heads must be divisible by the world size! " +\
                "This is because the attention computation is partitioned evenly among the parallel GPUs."
        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
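        # megatron_v2 flags checkpoints whose fused QKV uses Megatron-LM v2's
        # per-head interleaved layout; it is re-ordered by _transpose further below.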
        if not moe or moe_type == 'standard':
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
        else:
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
                _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

        attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
        if quantize:
            if policy_cls is not HFBertLayerPolicy:
                qkvw = qkvw.to(torch.int8)
            dense_w = dense_w.to(torch.int8)
            _h4h_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
            _4hh_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
        elif fp16:
            qkvw = qkvw.half()
            dense_w = dense_w.half()
            _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
            _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
        if quantize or fp16:
            qkvb = qkvb if qkvb is None else qkvb.half()
            dense_b = dense_b if dense_b is None else dense_b.half()
            _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
            _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
            attn_nw = attn_nw if attn_nw is None else attn_nw.half()
            attn_nb = attn_nb if attn_nb is None else attn_nb.half()
            input_nw = input_nw.half()
            input_nb = input_nb.half()

        if moe and moe_type == 'residual' and fp16:
            _res_h4h_b = _res_h4h_b.half()
            _res_4hh_b = _res_4hh_b.half()
            _res_h4h_w = _res_h4h_w.half()
            _res_4hh_w = _res_4hh_w.half()
            _res_coef = _res_coef.half()

        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

        if inference:
            bigscience_bloom = policy_cls is BLOOMLayerPolicy
            if moe:
                ep_world_size = dist.get_world_size()
                local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

                transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else 1e-12,
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    moe_experts=local_ep_size,
                    global_experts=num_experts,
                    mlp_type=moe_type)
            else:
                rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') \
                    else (child.attention.rotary_ndims if hasattr(child, 'attention')
                          and hasattr(child.attention, 'rotary_ndims') else -1)
                transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else
                    (config.layer_norm_epsilon
                     if hasattr(config,
                                'layer_norm_epsilon') else config.layernorm_epsilon
                     if hasattr(config,
                                'layernorm_epsilon') else 1.0e-12),
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                    triangular_masking=(policy_cls is not HFBertLayerPolicy),
                    local_attention=((config.attention_layers[layer_id] == "local")
                                     if hasattr(config,
                                                'attention_layers') else False),
                    window_size=(config.window_size if hasattr(config,
                                                               'window_size') else 1),
                    rotary_dim=rotary_dim,
                    mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
                    mlp_act_func_type=policy.mlp_act_func_type,
                    training_mp_size=training_mp_size,
                    bigscience_bloom=bigscience_bloom)

            if quantize and quantize_settings is not None:
                (quantization_scales,
                 merge_count,
                 mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                if quantize and qkvw.dtype != torch.int8:
                    quantize_bits = 8
                    quantizer = WeightQuantization()
                    if policy_cls is HFBertLayerPolicy:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
                    else:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
                    qkvw.data.copy_(data_quantized)
                    qkvw.data = qkvw.data.to(torch.int8)
            else:

                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                    )

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                    )
            new_module.config.scale_attention = scale_attention

            # we want the weights in [input, output] shape, but
            # nn.Linear stores them as [output, input];
            # transpose them here once to reduce inference cost!
            def transpose(data):
                # temp move to cpu to avoid requiring extra GPU memory during the reshape
                data = data.to('cpu')
                data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
                data = data.reshape(data.shape[-1], data.shape[-2])
                data = data.to(torch.cuda.current_device())
                return data

            attn_block = new_module.attention
            mpl_block = new_module.mlp

            if attn_linear_layer:
                if qkvw.numel() == 0 or qkvw.is_meta:
                    if qkvw.is_meta or qkvw.ds_tensor.numel(
                    ) < attn_block.attn_qkvw.numel():
                        pass
                    else:
                        with GatheredParameters([qkvw,
                                                 dense_w,
                                                 qkvb,
                                                 dense_b],
                                                modifier_rank=0):
                            qkvw = transpose(qkvw.data)
                            dense_w = transpose(dense_w.data)
                            qkvb = qkvb.data
                            dense_b = dense_b.data
                else:
                    qkvw.data = transpose(qkvw.data)
                    dense_w.data = transpose(dense_w.data)

            def _transpose(x):
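                # Re-order a fused QKV tensor from the per-head interleaved
                # [q, k, v] layout into contiguous [Q | K | V] blocks.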
                num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size
                attention_head_size = x.shape[-1] // num_attention_heads_per_partition
                new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                               attention_head_size)
                x_1 = x.view(*new_x_shape)
                (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
                if len(q.shape) > 2:
                    return torch.cat((q.reshape(q.shape[0],
                                                -1),
                                      k.reshape(q.shape[0],
                                                -1),
                                      v.reshape(q.shape[0],
                                                -1)),
                                     dim=-1).reshape(x.shape)
                else:
                    return torch.cat((q.reshape(-1),
                                      k.reshape(-1),
                                      v.reshape(-1)),
                                     dim=-1).reshape(x.shape)

            if megatron_v2:
                new_module.config.rotate_half = True
                new_module.config.rotate_every_two = False

                # Note: this part needs to be added for BLOOM architecture
                qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous())
                qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous())

            # NOTE: This part caused instability in the multi-GPU inference!
            # TODO: This needs to be incorporated in the kernels.
            #dense_b = dense_b if dense_b is None else dense_b * (
            #    transformer_config.training_mp_size / transformer_config.mp_size)
            #_4hh_b = _4hh_b * (transformer_config.training_mp_size /
            #                   transformer_config.mp_size)

            if mlp_linear_layer:
                if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta):
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _4hh_b,
                                                 _h4h_b],
                                                modifier_rank=0):
                            _h4h_w = transpose(_h4h_w.data)
                            _4hh_w = transpose(_4hh_w.data)
                            _h4h_b = _h4h_b.data
                            _4hh_b = _4hh_b.data
                else:
                    _h4h_w = [transpose(moe_w1.data)
                              for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
                    _4hh_w = [transpose(moe_w1.data)
                              for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

            if moe and moe_type == 'residual':
                _res_h4h_w.data = transpose(_res_h4h_w.data)
                _res_4hh_w.data = transpose(_res_4hh_w.data)
                _res_coef.data = transpose(_res_coef.data)

            if qkvw.numel() == 0 or qkvw.is_meta:
                if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                    pass
                else:
                    with GatheredParameters([
                            attn_block.attn_qkvw,
                            attn_block.attn_qkvb,
                            attn_block.attn_ow,
                            attn_block.attn_ob
                    ],
                                            modifier_rank=0):
                        attn_block.attn_qkvw = mp_replace.copy(
                            attn_block.attn_qkvw,
                            qkvw)
                        attn_block.attn_qkvb = mp_replace.copy(
                            attn_block.attn_qkvb,
                            qkvb)

                        attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                        attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
            else:
                if bigscience_bloom:
                    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
                else:
                    attn_block.attn_qkvw = mp_replace.qkv_copy(
                        attn_block.attn_qkvw,
                        qkvw)
                    attn_block.attn_qkvb = mp_replace.qkv_copy(
                        attn_block.attn_qkvb,
                        qkvb)

                attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

            if moe:
                gpu_index = 0
                for ep_index in range(local_ep_size):
                    mpl_block[ep_index].inter_w.data = _h4h_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].inter_b.data = _h4h_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_w.data = _4hh_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_b.data = _4hh_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
                if moe_type == 'residual':
                    new_module.res_mlp.inter_w.data = _res_h4h_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.inter_b.data = _res_h4h_b.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_w.data = _res_4hh_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_b.data = _res_4hh_b.to(
                        torch.cuda.current_device())
                    new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
            else:

                if _4hh_w.numel() == 0 or _4hh_w.is_meta:
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _h4h_b,
                                                 _4hh_b],
                                                modifier_rank=0):
                            mpl_block.inter_w = mp_replace.copy(
                                mpl_block.inter_w,
                                _h4h_w)
                            mpl_block.inter_b = mp_replace.copy(
                                mpl_block.inter_b,
                                _h4h_b)
                            mpl_block.output_w = mp_replace.copy(
                                mpl_block.output_w,
                                _4hh_w)
                            mpl_block.output_b = mp_replace.copy(
                                mpl_block.output_b,
                                _4hh_b)
                else:
                    mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                    mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                    mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                    mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)

                if attn_nw is None:
                    new_module.mlp.attn_nw = attn_nw
                    new_module.mlp.attn_nb = attn_nb
                else:
                    if attn_nw.is_meta or attn_nw.numel() == 0:
                        if attn_nw.is_meta or attn_nw.ds_tensor.numel(
                        ) < new_module.mlp.attn_nw.numel():
                            pass
                        else:
                            with GatheredParameters([attn_nw, attn_nb], modifier_rank=0):
                                new_module.mlp.attn_nw.data.copy_(
                                    attn_nw.to(torch.cuda.current_device()))
                                new_module.mlp.attn_nb.data.copy_(
                                    attn_nb.to(torch.cuda.current_device()))
                    else:
                        new_module.mlp.attn_nw.data.copy_(
                            attn_nw.to(torch.cuda.current_device()))
                        new_module.mlp.attn_nb.data.copy_(
                            attn_nb.to(torch.cuda.current_device()))

            if input_nw.is_meta or input_nw.numel() == 0:
                if input_nw.is_meta or input_nw.ds_tensor.numel(
                ) < new_module.norm_w.numel():
                    pass
                else:
                    with GatheredParameters([input_nw, input_nb], modifier_rank=0):
                        new_module.norm_w.data.copy_(
                            input_nw.to(torch.cuda.current_device()))
                        new_module.norm_b.data.copy_(
                            input_nb.to(torch.cuda.current_device()))
            else:
                new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device()))
                new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device()))
        else:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size if micro_batch_size > 0 else 1,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config,
                    'layer_norm_eps') else 1e-12,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=policy.pre_attn_norm,
                return_tuple=return_tuple,
                local_rank=local_rank,
                stochastic_mode=stochastic_mode,
                normalize_invertible=True,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = dense_w
            new_module.attn_ob.data = dense_b

            new_module.attn_nw.data = attn_nw
            new_module.attn_nb.data = attn_nb
            new_module.norm_w.data = input_nw
            new_module.norm_b.data = input_nb

            new_module.inter_w.data = _h4h_w
            new_module.inter_b.data = _h4h_b
            new_module.output_w.data = _4hh_w
            new_module.output_b.data = _4hh_b
        return new_module
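All replace_with_policy variants above rely on the same small policy surface. The sketch below summarizes that interface as it can be inferred from the calls in these examples (method names and return orders follow the usage above; everything else, including the class name, is an assumption):

    class LayerPolicyBase:
        """Adapter exposing one client transformer layer's parameters."""

        def __init__(self, layer, inference=True):
            self.client_module = layer
            self.inference = inference

        def get_hidden_heads(self):
            # -> (hidden_size, num_attention_heads)
            raise NotImplementedError

        def attention(self):
            # -> (is_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention)
            # later variants (Example #5) also append a megatron_v2 flag
            raise NotImplementedError

        def mlp(self):
            # -> (is_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b)
            raise NotImplementedError

        def layerNorm(self):
            # -> (attn_nw, attn_nb, input_nw, input_nb)
            raise NotImplementedError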