Example 1
def print_json_dist(message, ranks=None, path=None):
    """Print message when one of the following conditions is met

    + not dist.is_initialized()
    + dist.get_rank() in ranks if ranks is not None or ranks == [-1]

    Args:
        message (dict): JSON-serializable record; the writer's rank is added under the 'rank' key
        ranks (list): ranks that should write the message
        path (str): output file path

    """
    from deepspeed import comm as dist
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        message['rank'] = my_rank
        import json
        import os
        with open(path, 'w') as outfile:
            json.dump(message, outfile)
            os.fsync(outfile)
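
A minimal usage sketch (the path and message values are illustrative): only rank 0 writes the record, and the writer's rank is added to the payload before it is dumped.

print_json_dist({"step": 10, "loss": 1.23}, ranks=[0], path="/tmp/metrics.json")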
Example 2
def see_memory_usage(message, force=False):
    if not force:
        return
    if dist.is_initialized() and not dist.get_rank() == 0:
        return

    # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
    gc.collect()

    # Only rank 0 reaches this point when distributed (see the early returns above)
    logger.info(message)
    # torch_memory_reserved / torch_max_memory_reserved are module-level
    # version-compatibility aliases for torch.cuda.(max_)memory_reserved
    logger.info(
        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        CA {round(torch_memory_reserved() / (1024 * 1024 * 1024), 2)} GB \
        Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024), 2)} GB "
    )

    vm_stats = psutil.virtual_memory()
    used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
    logger.info(
        f'CPU Virtual Memory:  used = {used_GB} GB, percent = {vm_stats.percent}%'
    )

    # reset the peak-memory counter so the next call reports a correct peak
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
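
A minimal usage sketch: the report is skipped unless force=True, and only rank 0 logs when torch.distributed is initialized.

see_memory_usage("after forward pass", force=True)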
Example 3
    def _dist_init(self, local_rank, num_procs, skip_msg):
        """Initialize deepspeed.comm and execute the user function. """
        if self.set_dist_env:
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = get_master_port()
            os.environ['LOCAL_RANK'] = str(local_rank)
            # NOTE: unit tests don't support multi-node so local_rank == global rank
            os.environ['RANK'] = str(local_rank)
            os.environ['WORLD_SIZE'] = str(num_procs)

        # turn off NCCL logging if set
        os.environ.pop('NCCL_DEBUG', None)

        set_cuda_visibile()

        if self.init_distributed:
            deepspeed.init_distributed(dist_backend=self.backend)
            dist.barrier()

        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)

        try:
            self.current_test(**self.test_kwargs)
        except Skipped as e:
            skip_msg.put(e.msg)

        if self.init_distributed or dist.is_initialized():
            # make sure all ranks finish at the same time
            dist.barrier()
            # tear down after test completes
            dist.destroy_process_group()
Example 4
    def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
        is_pipe_parallel = isinstance(self.module, PipelineModule)
        if is_pipe_parallel:
            raise RuntimeError(
                'pipeline parallelism is currently not supported in inference.')
        if os.path.isdir(load_dir):
            if tag is None:
                latest_path = os.path.join(load_dir, "latest")
                if os.path.isfile(latest_path):
                    with open(latest_path, "r") as fd:
                        tag = fd.read().strip()

            ckpt_list = self._get_all_ckpt_names(load_dir, tag)
            sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
        else:
            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)

        if type(sd_loader) is list:
            self.sd = torch.load(sd_loader[0], map_location='cpu')
            self.key_list = list(self.sd.keys())

            self.load_model_with_checkpoint(self.module)

            for i in range(1, len(sd_loader)):
                if not dist.is_initialized() or dist.get_rank() == 0:
                    print(f"loading checkpoint ({i})")
                self.sd = torch.load(sd_loader[i], map_location='cuda')
                self.key_list = list(self.sd.keys())
                self.load_model_with_checkpoint(self.module)
        else:
            mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()

            load_path, checkpoint, quantize_config = sd_loader.load(
                self.mp_world_size,
                mp_rank,
                is_pipe_parallel=is_pipe_parallel,
                quantize=(self.dtype is torch.int8),
                quantize_groups=self.quantize_groups,
                mlp_extra_grouping=self.mlp_extra_grouping)

            self.quantization_scales, self.quantize_merge_count = quantize_config

            moe, _ = has_moe_layers(self.module)
            if moe:
                from deepspeed.runtime.engine import DeepSpeedEngine
                old_moe_load = False
                if not isinstance(checkpoint['num_experts'], list):
                    old_moe_load = True
                DeepSpeedEngine.load_moe_state_dict(
                    load_dir,
                    tag,
                    state_dict=checkpoint[self._choose_module_key(checkpoint)],
                    old_moe_load=old_moe_load,
                    model=self.module,
                    mpu=self.mpu,
                    checkpoint_engine=self.checkpoint_engine)

            self.module.load_state_dict(
                state_dict=checkpoint[self._choose_module_key(checkpoint)],
                checkpoint_engine=self.checkpoint_engine,
                strict=load_module_strict)
Example 5
def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    if dist.is_initialized():
        # We assume up to one full node executing unit tests
        assert dist.get_world_size() <= torch.cuda.device_count()
        args.local_rank = dist.get_rank()
    return args
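
A minimal usage sketch for tests (the config attribute below is illustrative, not set by the helper itself):

args = create_deepspeed_args()
args.deepspeed_config = "ds_config.json"   # illustrative: unit tests typically attach a config path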
Example 6
def _get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert dist.is_initialized(), \
        'dist is not initialized'
    global mpu
    if mpu is not None:
        return mpu.get_data_parallel_group()
    # Return the clone of dist world group
    return _clone_world_group()
Example 7
def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu):
    """
        Create expert and data parallel groups based on MPU (model parallel) group.

        Note: Caller of this function is responsible to check if the groups already exist.

        Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4 # number of experts in same group
        mp_group = [0, 1], [2,3], [4,5] ...
        data_parallel_group =[0,2,4,6,8,10, 12,14],                 [1,3,5,7,9,11,13,15]
        expert_parallel_group = [0,2,4,6], [8,10,12,14]             [1,3,5,7], [9,11,13,15]
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14],    [1,9],[3,11],[5,13],[7,15]
    """
    assert dist.is_initialized(), "dist is not initialized"
    model_parallel_size_ = mpu.get_model_parallel_world_size()

    global expert_tensor_parallel_world_size
    expert_tensor_parallel_world_size = model_parallel_size_

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_world_size = mpu.get_data_parallel_world_size()
    dp_rank = mpu.get_data_parallel_rank()

    _ensure_divisibility(world_size, model_parallel_size_)
    _ensure_divisibility(dp_world_size, expert_parallel_size_)

    log_dist(
        f"Creating deepspeed groups with model parallel size {model_parallel_size_}, expert parallel size {expert_parallel_size_}, world size {world_size}, dp world size {dp_world_size}",
        [0])

    global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP

    # Reference the data and model parallel groups already created by the MPU
    _DATA_PARALLEL_GROUP = mpu.get_data_parallel_group()
    _MODEL_PARALLEL_GROUP = mpu.get_model_parallel_group()

    group_name = f"ep_size_{expert_parallel_size_}"

    # Only create groups if they don't already exist
    # Need to check conditions outside the group creation loop because of the way torch.dist group creation works
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP:
        expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
            world_size, model_parallel_size_, expert_parallel_size_)
        for ranks in expert_parallel_groups:
            group = dist.new_group(ranks)
            if rank in list(ranks):
                _EXPERT_PARALLEL_GROUP[group_name] = group

        for ranks in expert_data_parallel_groups:
            group = dist.new_group(ranks)
            if rank in list(ranks):
                _EXPERT_DATA_PARALLEL_GROUP[group_name] = group
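
The grouping arithmetic in the docstring can be reproduced with a small standalone sketch. This is illustrative only and does not call the DeepSpeed internals (_get_expert_parallel_ranks is what actually builds these lists above).

# Illustrative sketch: reproduce the E + M + D grouping from the docstring
# for world_size=16, model_degree=2, expert_degree=4.
def sketch_expert_groups(world_size, model_degree, expert_degree):
    # data-parallel groups: ranks with the same model-parallel position, stride = model_degree
    dp_groups = [list(range(i, world_size, model_degree)) for i in range(model_degree)]
    ep_groups, ep_dp_groups = [], []
    for dp in dp_groups:
        # expert-parallel: consecutive chunks of expert_degree ranks within a DP group
        ep_groups += [dp[i:i + expert_degree] for i in range(0, len(dp), expert_degree)]
        # expert-data-parallel: the same expert slot across those chunks
        ep_dp_groups += [dp[j::expert_degree] for j in range(expert_degree)]
    return dp_groups, ep_groups, ep_dp_groups

dp, ep, ep_dp = sketch_expert_groups(16, 2, 4)
# dp    -> [[0, 2, ..., 14], [1, 3, ..., 15]]
# ep    -> [[0, 2, 4, 6], [8, 10, 12, 14], [1, 3, 5, 7], [9, 11, 13, 15]]
# ep_dp -> [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], [7, 15]]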
Example 8
    def __init__(self, typename, *module_args, **module_kwargs):
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs

        if not issubclass(typename, nn.Module):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

        if dist.is_initialized():
            self.global_rank = dist.get_rank()
        else:
            self.global_rank = -1
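
A hedged usage sketch: a LayerSpec records how to build a module without instantiating it, so construction can be deferred until the owning pipeline stage is known. Only the attributes shown above are used here.

import torch.nn as nn

spec = LayerSpec(nn.Linear, 1024, 1024, bias=False)
layer = spec.typename(*spec.module_args, **spec.module_kwargs)   # builds nn.Linear(1024, 1024, bias=False)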
Example 9
def _create_expert_and_data_parallel(expert_parallel_size_):
    """
        Create expert and data parallel groups.

        Note: Caller of this function is responsible to check if the groups already exist.

        Example - E + D parallel
        world_size = 16
        expert_parallel_size = 2 # number of experts in same group
        expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
        expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all
        data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
    """
    assert dist.is_initialized()

    log_dist(
        f'Creating expert and data parallel groups with size {expert_parallel_size_}',
        ranks=[0])
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    _ensure_divisibility(world_size, expert_parallel_size_)

    group_name = f"ep_size_{expert_parallel_size_}"

    # Build the expert data parallel groups.
    global _EXPERT_DATA_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
        for i in range(expert_parallel_size_):
            ranks = range(i, world_size, expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank % expert_parallel_size_):
                _EXPERT_DATA_PARALLEL_GROUP[group_name] = group

    # Build the expert parallel groups.
    global _EXPERT_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_PARALLEL_GROUP:
        for i in range(world_size // expert_parallel_size_):
            ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank // expert_parallel_size_):
                _EXPERT_PARALLEL_GROUP[group_name] = group
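
For the E + D case, a rank's group memberships follow directly from the two loops above; a small illustrative check (not DeepSpeed code):

world_size, ep = 16, 2
ep_dp_groups = [list(range(i, world_size, ep)) for i in range(ep)]                 # mirrors the first loop
ep_groups = [list(range(i * ep, (i + 1) * ep)) for i in range(world_size // ep)]   # mirrors the second loop
rank = 6
print(ep_dp_groups[rank % ep])   # [0, 2, 4, 6, 8, 10, 12, 14]
print(ep_groups[rank // ep])     # [6, 7]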
Example 10
def _clone_world_group():
    """Create a clone of the world group
        Note: We need to clone the dist world group because we
        use dist.get_global_rank() utility function in DeepSpeed at many places.
        As that function does not work on dist.group.WORLD, we
        need to keep a clone of it.
    """
    assert dist.is_initialized(), "dist is not initialized"
    global _WORLD_GROUP
    if _WORLD_GROUP is None:
        # If not cloned already, clone the world group
        _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
    return _WORLD_GROUP
Example 11
def log_dist(message, ranks=None, level=logging.INFO):
    """Log message when one of the following conditions is met

    + not dist.is_initialized()
    + dist.get_rank() in ranks if ranks is not None or ranks == [-1]

    Args:
        message (str)
        ranks (list)
        level (int)

    """
    from deepspeed import comm as dist
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        final_message = "[Rank {}] {}".format(my_rank, message)
        logger.log(level, final_message)
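
A minimal usage sketch: log only from rank 0, or from every rank by passing ranks=[-1].

log_dist("initializing optimizer", ranks=[0])
log_dist("per-rank progress", ranks=[-1], level=logging.WARNING)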
Example 12
    def save_exp_results_to_database(self, message, ranks=None, path=None):
        """Save the experiment record when one of the following conditions is met

        + not dist.is_initialized()
        + dist.get_rank() in ranks if ranks is not None or ranks == [-1]

        Args:
            message (dict): JSON-serializable record; the writer's rank is added under the 'rank' key
            ranks (list)
            path (str): output file, written as one JSON object per line

        """
        should_log = not dist.is_initialized()
        ranks = ranks or []
        my_rank = dist.get_rank() if dist.is_initialized() else -1
        if ranks and not should_log:
            should_log = ranks[0] == -1
            should_log = should_log or (my_rank in set(ranks))
        logger.debug(f"*** Should log: {should_log}")
        if should_log:
            message['rank'] = my_rank
            with open(path, 'a') as outfile:
                json.dump(message, outfile)
                outfile.write('\n')
Example 13
def _create_model_parallel(model_parallel_size_):
    """
    Initialize model and data parallel groups.

    Arguments:
        model_parallel_size_: number of GPUs used to parallelize the model.

    Returns:
        Tuple of data parallel group and model parallel group

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    log_dist(f'Creating model parallel group with size {model_parallel_size_}',
             ranks=[0])
    # Get world size and rank. Ensure some consistencies.
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    model_parallel_size = min(model_parallel_size_, world_size)
    _ensure_divisibility(world_size, model_parallel_size)
    rank = dist.get_rank()

    _DATA_PARALLEL_GROUP = None
    _MODEL_PARALLEL_GROUP = None
    # Build the data parallel groups.
    for i in range(model_parallel_size):
        ranks = range(i, world_size, model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank % model_parallel_size):
            _DATA_PARALLEL_GROUP = group

    # Build the model parallel groups.
    for i in range(world_size // model_parallel_size):
        ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank // model_parallel_size):
            _MODEL_PARALLEL_GROUP = group

    return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
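
The 8-GPU, model-parallel-size-2 layout in the docstring can be checked with a small standalone sketch (illustrative only): a rank joins data parallel group rank % mp and model parallel group rank // mp, matching the two loops above.

world_size, mp = 8, 2
dp_groups = [list(range(i, world_size, mp)) for i in range(mp)]
mp_groups = [list(range(i * mp, (i + 1) * mp)) for i in range(world_size // mp)]
print(dp_groups)                                      # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(mp_groups)                                      # [[0, 1], [2, 3], [4, 5], [6, 7]]
rank = 5
print(dp_groups[rank % mp], mp_groups[rank // mp])    # [1, 3, 5, 7] [4, 5]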
Example 14
        def backup_attention(mixed_x_layer, layer_past, alibi, input_mask,
                             norm_factor):
            alibi = alibi.to(torch.cuda.current_device())
            head_dim = hidden_size_per_partition // num_attention_heads_per_partition
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                num_attention_heads_per_partition, 3 * head_dim)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            (query_layer, key_layer,
             value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

            if layer_past is not None:
                past_key, past_value = layer_past
                # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
                key_layer = torch.cat((past_key.type_as(key_layer), key_layer),
                                      dim=1)
                value_layer = torch.cat(
                    (past_value.type_as(value_layer), value_layer), dim=1)

            presents = (key_layer, value_layer)

            # [batch_size, head_dim, q_length, k_length]
            output_size = (query_layer.size(0), query_layer.size(2),
                           query_layer.size(1), key_layer.size(1))
            # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
            query_layer = query_layer.transpose(1, 0).reshape(
                output_size[2], output_size[0] * output_size[1], -1)
            # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
            key_layer = key_layer.transpose(1, 0).reshape(
                output_size[3], output_size[0] * output_size[1], -1)

            # Raw attention scores. [batch_size * num_heads, q_length, k_length]
            matmul_result = torch.matmul(
                query_layer.transpose(1, 0),
                key_layer.transpose(1, 0).transpose(1, 2))
            # change view to [batch_size, num_heads, q_length, k_length]
            attention_scores = matmul_result.view(*output_size)

            offset = dist.get_rank() * num_attention_heads_per_partition \
                if dist.is_initialized() else 0
            attention_probs = inference_cuda_module.softmax_fp16(
                attention_scores,
                ((1 - input_mask).half() *
                 minus_inf) if input_mask.dtype == torch.int64 else input_mask,
                alibi, (config.triangular_masking and
                        (attention_scores.shape[-2] > 1)), False, False, 1,
                False, 1 / (norm_factor * norm_factor), offset, config.mp_size)
            # change view [batch_size x num_heads, q_length, k_length]
            attention_probs_reshaped = attention_probs.view(
                *matmul_result.shape)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = torch.bmm(
                attention_probs_reshaped,
                value_layer.transpose(1, 2).reshape(-1, value_layer.size(1),
                                                    value_layer.size(3)))

            # change view [batch_size, num_heads, q_length, head_dim]
            context_layer = context_layer.view(
                context_layer.size(0) // num_attention_heads_per_partition,
                num_attention_heads_per_partition, context_layer.size(1),
                context_layer.shape[-1])

            context_layer = _transpose_for_context(context_layer)

            return context_layer, presents
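
A small shape-check sketch for the reshapes above (dimensions are illustrative and torch.chunk stands in for split_tensor_along_last_dim; this is not the DeepSpeed kernel path):

import torch

batch, seq, heads, head_dim = 2, 5, 4, 8
mixed_x_layer = torch.randn(batch, seq, heads * 3 * head_dim)
mixed_x_layer = mixed_x_layer.view(batch, seq, heads, 3 * head_dim)
q, k, v = torch.chunk(mixed_x_layer, 3, dim=-1)             # each [batch, seq, heads, head_dim]
q = q.transpose(1, 0).reshape(seq, batch * heads, head_dim)
k = k.transpose(1, 0).reshape(seq, batch * heads, head_dim)
scores = torch.matmul(q.transpose(1, 0), k.transpose(1, 0).transpose(1, 2))
print(scores.shape)                                         # torch.Size([8, 5, 5]) = [batch * heads, q_len, k_len]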
Example 15
def get_ma_status():
    if dist.is_initialized() and not dist.get_rank() == 0:
        return 0
    return torch.cuda.memory_allocated()
Example 16
def replace_transformer_layer(orig_layer_impl,
                              model,
                              policy=None,
                              micro_batch_size=-1,
                              config=None,
                              seed=-1,
                              hidden_size=-1,
                              num_attention_heads=-1,
                              mp_size=1,
                              training_mp_size=1,
                              mp_group=None,
                              ep_group=None,
                              expert_mp_group=None,
                              fp16=True,
                              local_rank=-1,
                              stochastic_mode=True,
                              training=True,
                              quantize=False,
                              quantize_settings=None,
                              triangular_masking=False,
                              return_tuple=True,
                              replace_with_kernel_inject=False,
                              linear_layer_setting=None,
                              moe=False,
                              moe_experts=1,
                              moe_type='standard',
                              checkpoint_dict=None,
                              save_mp_checkpoint_path=None):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        policy: shows the policy for mapping from the orig_layer_impl to transformer parameters when
            replace_with_kernel_inject is set, otherwise, it provides the names of two linear layers as
            a tuple: (attention_output projection, transformer output projection)
        micro_batch_size (int): micro batch size per gpu used during training/eval
        config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        hidden_size (int): hidden dimension
        num_attention_heads (int): number of attention heads
        mp_size (int): model_parallelism degree
        mp_group : model_parallel group initialized on the modeling side
        preln (bool): does the original layer implementation do pre or post layer norm?
        fp16 (bool): fp16 or fp32
        local_rank (int): GPU rank (optional),
        stochastic_mode (bool): whether to use stochastic mode
        training (bool): specifying whether kernel-injection is done for training/inference (set to false for inference-mode injection)
        quantize_settings (tuple): this setting shows how we can quantize a model for running it through the inference kernels.
                It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
        return_tuple (bool): if set, transformer layer returns a tuple as the output.
            Note: this flag needs to be set for huggingface models.
        replace_with_kernel_inject (bool): injection mode; if true, kernels will be added along with configuring
            Tensor-Parallelism
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers
            and embedding layers
        attention_params: (list of strings) [Optional]: shows the parameters in the attention part that need to
            be adjusted based on the model-parallelism
    Returns:
        Updated nn.module with replaced transformer layers
    """
    mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group,
                                          mp_size=mp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child,
                            policy_cls,
                            triangular_masking,
                            inference=False,
                            layer_id=0):
        policy = policy_cls(child, inference=inference)

        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0,\
                "To run the model parallel across the GPUs, the number of attention heads must be divisible by mp_size " +\
                "because the attention computation is partitioned evenly among the parallel GPUs."
        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
        if not moe or moe_type == 'standard':
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
        else:
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
                _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

        attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
        if quantize:
            if policy_cls is not HFBertLayerPolicy:
                qkvw = qkvw.to(torch.int8)
            dense_w = dense_w.to(torch.int8)
            _h4h_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
            _4hh_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
        elif fp16:
            qkvw = qkvw.half()
            dense_w = dense_w.half()
            _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
            _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
        if quantize or fp16:
            qkvb = qkvb if qkvb is None else qkvb.half()
            dense_b = dense_b if dense_b is None else dense_b.half()
            _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
            _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
            attn_nw = attn_nw if attn_nw is None else attn_nw.half()
            attn_nb = attn_nb if attn_nb is None else attn_nb.half()
            input_nw = input_nw.half()
            input_nb = input_nb.half()

        if moe and moe_type == 'residual' and fp16:
            _res_h4h_b = _res_h4h_b.half()
            _res_4hh_b = _res_4hh_b.half()
            _res_h4h_w = _res_h4h_w.half()
            _res_4hh_w = _res_4hh_w.half()
            _res_coef = _res_coef.half()

        #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

        if inference:
            if moe:
                ep_world_size = dist.get_world_size()
                local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

                transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else 1e-12,
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    moe_experts=local_ep_size,
                    global_experts=num_experts,
                    mlp_type=moe_type)
            else:
                rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') else child.attention.rotary_ndims \
                                            if hasattr(child, 'attention') and hasattr(child.attention,'rotary_ndims') else -1
                bigscience_bloom = policy_cls is BLOOMLayerPolicy
                transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else
                    (config.layer_norm_epsilon
                     if hasattr(config,
                                'layer_norm_epsilon') else config.layernorm_epsilon
                     if hasattr(config,
                                'layernorm_epsilon') else 1.0e-12),
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                    triangular_masking=(policy_cls is not HFBertLayerPolicy),
                    local_attention=((config.attention_layers[layer_id] == "local")
                                     if hasattr(config,
                                                'attention_layers') else False),
                    window_size=(config.window_size if hasattr(config,
                                                               'window_size') else 1),
                    rotary_dim=rotary_dim,
                    mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
                    mlp_act_func_type=policy.mlp_act_func_type,
                    training_mp_size=training_mp_size,
                    bigscience_bloom=bigscience_bloom)

            if quantize and quantize_settings is not None:
                (quantization_scales,
                 merge_count,
                 mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                if quantize and qkvw.dtype != torch.int8:
                    quantize_bits = 8
                    quantizer = WeightQuantization()
                    if policy_cls is HFBertLayerPolicy:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
                    else:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
                    qkvw.data.copy_(data_quantized)
                    qkvw.data = qkvw.data.to(torch.int8)
            else:

                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                    )

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                    )
            new_module.config.scale_attention = scale_attention

            # we want the weights in [input, output] shape
            # linear layer is created with [input, output] shape
            # transpose it here to reduce inference cost!
            def transpose(data):
                # temp move to cpu to avoid requiring extra GPU memory during the reshape
                data = data.to('cpu')
                data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
                data = data.reshape(data.shape[-1], data.shape[-2])
                data = data.to(torch.cuda.current_device())
                return data

            attn_block = new_module.attention
            mpl_block = new_module.mlp

            if attn_linear_layer:
                if qkvw.numel() == 0 or qkvw.is_meta:
                    if qkvw.is_meta or qkvw.ds_tensor.numel(
                    ) < attn_block.attn_qkvw.numel():
                        pass
                    else:
                        with GatheredParameters([qkvw,
                                                 dense_w,
                                                 qkvb,
                                                 dense_b],
                                                modifier_rank=0):
                            qkvw = transpose(qkvw.data)
                            dense_w = transpose(dense_w.data)
                            qkvb = qkvb.data
                            dense_b = dense_b.data
                else:
                    qkvw.data = transpose(qkvw.data)
                    dense_w.data = transpose(dense_w.data)

            def _transpose(x):
                num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size
                attention_head_size = x.shape[-1] // num_attention_heads_per_partition
                new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                               attention_head_size)
                x_1 = x.view(*new_x_shape)
                (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
                if len(q.shape) > 2:
                    return torch.cat((q.reshape(q.shape[0],
                                                -1),
                                      k.reshape(q.shape[0],
                                                -1),
                                      v.reshape(q.shape[0],
                                                -1)),
                                     dim=-1).reshape(x.shape)
                else:
                    return torch.cat((q.reshape(-1),
                                      k.reshape(-1),
                                      v.reshape(-1)),
                                     dim=-1).reshape(x.shape)

            if megatron_v2:
                new_module.config.rotate_half = True
                new_module.config.rotate_every_two = False

                # Note: this part needs to be added for BLOOM architecture
                qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous())
                qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous())

            # NOTE: This part caused instability in the multi-GPU inference!
            # TODO: This needs to be incorporated in the kernels.
            #dense_b = dense_b if dense_b is None else dense_b * (
            #    transformer_config.training_mp_size / transformer_config.mp_size)
            #_4hh_b = _4hh_b * (transformer_config.training_mp_size /
            #                   transformer_config.mp_size)

            if mlp_linear_layer:
                if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta):
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _4hh_b,
                                                 _h4h_b],
                                                modifier_rank=0):
                            _h4h_w = transpose(_h4h_w.data)
                            _4hh_w = transpose(_4hh_w.data)
                            _h4h_b = _h4h_b.data
                            _4hh_b = _4hh_b.data
                else:
                    _h4h_w = [transpose(moe_w1.data)
                              for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
                    _4hh_w = [transpose(moe_w1.data)
                              for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

            if moe and moe_type == 'residual':
                _res_h4h_w.data = transpose(_res_h4h_w.data)
                _res_4hh_w.data = transpose(_res_4hh_w.data)
                _res_coef.data = transpose(_res_coef.data)

            if qkvw.is_meta or qkvw.numel() == 0:
                if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                    pass
                else:
                    with GatheredParameters([
                            attn_block.attn_qkvw,
                            attn_block.attn_qkvb,
                            attn_block.attn_ow,
                            attn_block.attn_ob
                    ],
                                            modifier_rank=0):
                        attn_block.attn_qkvw = mp_replace.copy(
                            attn_block.attn_qkvw,
                            qkvw)
                        attn_block.attn_qkvb = mp_replace.copy(
                            attn_block.attn_qkvb,
                            qkvb)

                        attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                        attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
            else:
                if bigscience_bloom:
                    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
                else:
                    attn_block.attn_qkvw = mp_replace.qkv_copy(
                        attn_block.attn_qkvw,
                        qkvw)
                    attn_block.attn_qkvb = mp_replace.qkv_copy(
                        attn_block.attn_qkvb,
                        qkvb)

                attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

            if moe:
                # NOTE: expert weights are indexed locally here, so the rank-based index is not used
                gpu_index = 0
                for ep_index in range(local_ep_size):
                    mpl_block[ep_index].inter_w.data = _h4h_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].inter_b.data = _h4h_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_w.data = _4hh_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_b.data = _4hh_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
                if moe_type == 'residual':
                    new_module.res_mlp.inter_w.data = _res_h4h_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.inter_b.data = _res_h4h_b.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_w.data = _res_4hh_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_b.data = _res_4hh_b.to(
                        torch.cuda.current_device())
                    new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
            else:

                if _4hh_w.numel() == 0 or _4hh_w.is_meta:
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _4hh_b,
                                                 _h4h_b],
                                                modifier_rank=0):
                            mpl_block.inter_w = mp_replace.copy(
                                mpl_block.inter_w,
                                _h4h_w)
                            mpl_block.inter_b = mp_replace.copy(
                                mpl_block.inter_b,
                                _h4h_b)
                            mpl_block.output_w = mp_replace.copy(
                                mpl_block.output_w,
                                _4hh_w)
                            mpl_block.output_b = mp_replace.copy(
                                mpl_block.output_b,
                                _4hh_b)
                else:
                    mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                    mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                    mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                    mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)

                if attn_nw is None:
                    new_module.mlp.attn_nw = attn_nw
                    new_module.mlp.attn_nb = attn_nb
                else:
                    if attn_nw.is_meta or attn_nw.numel() == 0:
                        if attn_nw.is_meta or attn_nw.ds_tensor.numel(
                        ) < new_module.mlp.attn_nw.numel():
                            pass
                        else:
                            with GatheredParameters([attn_nw, attn_nb], modifier_rank=0):
                                new_module.mlp.attn_nw.data.copy_(
                                    attn_nw.to(torch.cuda.current_device()))
                                new_module.mlp.attn_nb.data.copy_(
                                    attn_nb.to(torch.cuda.current_device()))
                    else:
                        new_module.mlp.attn_nw.data.copy_(
                            attn_nw.to(torch.cuda.current_device()))
                        new_module.mlp.attn_nb.data.copy_(
                            attn_nb.to(torch.cuda.current_device()))

            if input_nw.is_meta or input_nw.numel() == 0:
                if input_nw.is_meta or input_nw.ds_tensor.numel(
                ) < new_module.norm_w.numel():
                    pass
                else:
                    with GatheredParameters([input_nw, input_nb], modifier_rank=0):
                        new_module.norm_w.data.copy_(
                            input_nw.to(torch.cuda.current_device()))
                        new_module.norm_b.data.copy_(
                            input_nb.to(torch.cuda.current_device()))
            else:
                new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device()))
                new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device()))
        else:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size if micro_batch_size > 0 else 1,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config,
                    'layer_norm_eps') else 1e-12,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=policy.pre_attn_norm,
                return_tuple=return_tuple,
                local_rank=local_rank,
                stochastic_mode=stochastic_mode,
                normalize_invertible=True,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = dense_w
            new_module.attn_ob.data = dense_b

            new_module.attn_nw.data = attn_nw
            new_module.attn_nb.data = attn_nb
            new_module.norm_w.data = input_nw
            new_module.norm_b.data = input_nb

            new_module.inter_w.data = _h4h_w
            new_module.inter_b.data = _h4h_b
            new_module.output_w.data = _4hh_w
            new_module.output_b.data = _4hh_b
        return new_module

    def replace_wo_policy(module, all_reduce_linears):
        def _replace(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            z_inference = (len(list(child.parameters())) > 0) and (list(
                child.parameters())[0].numel() == 0)
            if z_inference:
                weight_shape = child.weight.ds_shape
            else:
                weight_shape = child.weight.shape
            if name in all_reduce_linears:
                new_weight = torch.empty((
                    weight_shape[1] if conv_linear_layer else weight_shape[0],
                    (weight_shape[0] if conv_linear_layer else weight_shape[1]) //
                    mp_size,
                ),
                                         device=child.weight.device,
                                         dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.weight,
                                                           modifier_rank=0):
                        data = child.weight.data.to(new_weight.device)
                        if conv_linear_layer:
                            data = data.transpose(-1, -2).contiguous()
                        data = mp_replace.copy(new_weight, data)
                    child.weight.ds_tensor = torch.empty(1)
                else:
                    if conv_linear_layer:
                        child.weight.data = child.weight.data.transpose(-1,
                                                                        -2).contiguous()
                    data = mp_replace.copy(new_weight, child.weight.data)
                new_bias = torch.empty((weight_shape[0]),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0):
                        new_bias.data.copy_(child.bias.data)
                elif child.bias is not None:
                    new_bias.data.copy_(child.bias.data)
                return LinearAllreduce(data, child.bias if child.bias is None else \
                            torch.nn.parameter.Parameter(new_bias.to(torch.cuda.current_device())), mp_group)
            else:
                new_weight = torch.empty((
                    (weight_shape[1] if conv_linear_layer else weight_shape[0]) //
                    mp_size,
                    weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1],
                ),
                                         device=child.weight.device,
                                         dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.weight,
                                                           modifier_rank=0):
                        data = child.weight.data.to(new_weight.device)
                        if conv_linear_layer:
                            data = data.transpose(-1, -2).contiguous()
                        data = mp_replace.copy(new_weight, data)
                    child.weight.ds_tensor = torch.empty(1)
                else:
                    if conv_linear_layer:
                        child.weight.data = child.weight.data.transpose(-1,
                                                                        -2).contiguous()
                    data = mp_replace.copy(new_weight, child.weight.data)

                new_bias = torch.empty((weight_shape[0] // mp_size),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0):
                        bias_data = None if child.bias is None else mp_replace.copy(
                            new_bias,
                            child.bias.data).to(torch.cuda.current_device())
                else:
                    bias_data = None if child.bias is None else mp_replace.copy(
                        new_bias,
                        child.bias.data).to(torch.cuda.current_device())
                return LinearLayer(weight=data.to(torch.cuda.current_device()),
                                   bias=bias_data)

        def _slice_embedding(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            new_weight = torch.empty((child.weight.shape[0],
                                      child.weight.shape[1] // mp_size),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            data = mp_replace.copy(new_weight,
                                   child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
                                   child.weight.data)
            new_embedding = nn.Embedding(child.weight.shape[0],
                                         child.weight.shape[1] // mp_size)
            new_embedding.weight.data.copy_(data)
            return new_embedding

        def update_mp_params(child):
            if hasattr(child, 'n_heads'):
                child.n_heads = child.n_heads // mp_size
            if hasattr(child, 'inner_dim'):
                child.inner_dim = child.inner_dim // mp_size
            if hasattr(child, 'num_heads'):
                child.num_heads = child.num_heads // mp_size
            if hasattr(child, 'num_attention_heads'):
                child.num_attention_heads = child.num_attention_heads // mp_size
            if hasattr(child, 'all_head_size'):
                child.all_head_size = child.all_head_size // mp_size
            if hasattr(child, 'embed_dim'):
                child.embed_dim = child.embed_dim // mp_size
            if hasattr(child, 'hidden_size'):
                child.hidden_size = child.hidden_size // mp_size

        conv_linear_layer = False
        if linear_layer_setting is not None:
            linear_policies = {linear_layer_setting[0]: _replace}
            if len(linear_layer_setting) == 2:
                linear_policies.update({linear_layer_setting[1]: _slice_embedding})
        else:
            if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
                try:
                    import transformers
                    conv_linear_layer = True
                    linear_policies = {transformers.model_utils.Conv1D: _replace}
                except ImportError:
                    linear_policies = {nn.Linear: _replace}
            else:
                linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}

        def _replace_module(r_module, prev_name=''):
            for name, child in r_module.named_children():
                if child.__class__ in linear_policies:
                    setattr(
                        r_module,
                        name,
                        linear_policies[child.__class__](child,
                                                         prev_name + '.' + name,
                                                         conv_linear_layer))
                else:
                    update_mp_params(child)
                    _replace_module(child, name)
            return r_module

        return _replace_module(module)

    def replace_fn(child, _policy, layer_id=0):
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, triangular_masking)

        else:
            # copy relevant state from child -> new module
            if replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy)

        return new_module

    replaced_module = replace_module(model=model,
                                     orig_class=orig_layer_impl,
                                     replace_fn=replace_fn,
                                     _replace_policy=policy)

    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None:
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
        base_dir = checkpoint_dict.get('base_dir', '')

        if ckpt_type == 'pp':
            pbar = tqdm.tqdm(total=len(checkpoint),
                             desc=f"Loading {len(checkpoint)} checkpoint shards")
            for i in range(len(checkpoint)):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)
                sd = torch.load(checkpoint[i], map_location='cpu')
                load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type)
        else:
            num_checkpoints = len(checkpoint) // ckpt_mp_size
            assert world_size >= ckpt_mp_size,\
                "Merging checkpoints is currently not supported (world_size must be at least the checkpoint mp_size)!"
            checkpoint_stride = world_size // ckpt_mp_size
            if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                pbar = tqdm.tqdm(total=num_checkpoints,
                                 desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)

                ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride)
                ckpt_file = os.path.join(
                    base_dir,
                    checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index]
                sd = torch.load(ckpt_file, map_location='cpu')
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           rank % (world_size // ckpt_mp_size))
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
        ckpt_files = [non_tp_ckpt_name] * world_size
        os.makedirs(save_mp_checkpoint_path, exist_ok=True)
        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({
                    k: v
                    for k, v in dict(replaced_module.state_dict()).items()
                    if transformer_name not in k
                }),
                f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}')
            ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)]
            config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{save_mp_checkpoint_path}',
                'checkpoints': ckpt_files,
                'version': 1.0,
                'parallelization': 'tp',
                'mp_size': world_size
            })
            with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json",
                      "w") as cfg:
                cfg.write(config)
        torch.save(
            OrderedDict({
                k: v
                for k, v in dict(replaced_module.state_dict()).items()
                if transformer_name in k
            }),
            f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt')

    return replaced_module
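
The save branch above writes a ds-inference config JSON next to the shards, containing exactly the keys the loading branch consumes ('type', 'base_dir', 'checkpoints', 'version', 'parallelization', 'mp_size'). A minimal sketch of reading that file back to obtain a checkpoint_dict for a later run; load_inference_config is a hypothetical helper, not part of DeepSpeed:

import json

def load_inference_config(config_path):
    # Hypothetical helper: returns the dict written above, whose keys match
    # what the checkpoint-loading branch expects as checkpoint_dict.
    with open(config_path) as f:
        return json.load(f)

# e.g. checkpoint_dict = load_inference_config(
#     "<save_mp_checkpoint_path>/<ckpt_name>_ds-inference_config.json")
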
Esempio n. 17
0
    def __init__(self,
                 params,
                 deepspeed=None,
                 lr=1e-3,
                 freeze_step=100000,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
                 max_grad_norm=0.,
                 max_coeff=10.0,
                 min_coeff=0.01,
                 amsgrad=False,
                 cuda_aware=False,
                 comm_backend_name='nccl',
                 coeff_beta=0.9,
                 factor_max=4.0,
                 factor_min=0.5,
                 factor_threshold=0.1):

        if amsgrad:
            raise RuntimeError(
                '1-bit Lamb does not support the AMSGrad variant.')

        defaults = dict(lr=lr,
                        bias_correction=bias_correction,
                        betas=betas,
                        eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm,
                        max_coeff=max_coeff,
                        min_coeff=min_coeff)

        super(OnebitLamb, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        assert (dist.is_initialized())

        self.deepspeed = deepspeed
        self.lamb_freeze_key = False
        self.initialize = False
        self.freeze_step = freeze_step
        self.cuda_aware = cuda_aware
        self.coeff_beta = coeff_beta
        self.factor_max = factor_max
        self.factor_min = factor_min
        self.factor_threshold = factor_threshold
        self.using_pipeline = False

        self.comm_backend_name = comm_backend_name

        # Placeholder; the handle is set below based on the selected comm backend.
        self.comm_backend_handle = None

        if self.comm_backend_name == 'nccl':
            TORCH_MAJOR = int(torch.__version__.split('.')[0])
            TORCH_MINOR = int(torch.__version__.split('.')[1])
            assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 8), "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Lamb. Alternatively, please specify 'mpi' as the 'comm_backend_name' in the config file to proceed with the MPI backend"
            assert dist.is_initialized(), "Please initialize the torch distributed backend."
            from deepspeed.runtime.comm.nccl import NcclBackend
            self.using_pipeline = hasattr(
                self.deepspeed, 'pipeline_enable_backward_allreduce')
            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

        elif self.comm_backend_name == 'mpi':
            from deepspeed.runtime.comm.mpi import MpiBackend
            self.comm_backend_handle = MpiBackend(cuda_aware)

        self.size = self.comm_backend_handle.size

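        # int(size * 8 / gcd(size, 8)) == lcm(size, 8); flattened buffers are padded
        # to a multiple of this so chunks split evenly across workers in the
        # compressed allreduce.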
        self.divider = int(self.size * 8 / np.gcd(self.size, 8))

        self.exp_avg_flat = []
        self.dummy_exp_avg = {}
        self.corrected_tensor_sizes = []
        self.server_chunk_sizes = []
        self.worker_errors = []
        self.server_errors = []

        self.lamb_coeffs = []
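
In practice this optimizer is usually selected through the DeepSpeed config rather than constructed directly. A sketch of such a config, assuming the documented 'OneBitLamb' type name; the params mirror the constructor arguments above, but the exact schema should be verified against your DeepSpeed version:

ds_config = {
    "train_batch_size": 32,
    "optimizer": {
        "type": "OneBitLamb",  # assumed type name; check your DeepSpeed docs
        "params": {
            "lr": 1e-3,
            "freeze_step": 100000,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.0,
            "max_coeff": 10.0,
            "min_coeff": 0.01,
            "cuda_aware": False,
            "comm_backend_name": "nccl",
            "coeff_beta": 0.9,
            "factor_max": 4.0,
            "factor_min": 0.5,
            "factor_threshold": 0.1,
        },
    },
}
# model_engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config=ds_config)
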
Esempio n. 18
0
def test_scattered_init_dist():
    setup_serial_env()
    assert not dist.is_initialized()
    with deepspeed.zero.Init():
        assert dist.is_initialized()
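
The test above relies on deepspeed.zero.Init() bringing up the communication backend lazily on entry. A minimal sketch of the same context manager used for its main purpose, allocating a module whose parameters are partitioned at construction time; it assumes the single-process environment that setup_serial_env() configures:

import torch
import deepspeed

with deepspeed.zero.Init():
    # dist is initialized on entry (which is what the test asserts); parameters
    # created inside the context are partitioned across ranks as they are allocated.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
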
Esempio n. 19
0
        def compute_attention(qkv_out, input_mask):
            no_masking = input_mask is None

            head_size = (qkv_out.shape[-1] // 3 //
                         num_attention_heads_per_partition)
            if no_masking:
                input_mask = torch.empty(1)
            if merge_count > 0 and config.q_int8:
                split_dim = (qkv_out.dim() - 1)
                qkv_split = torch.split(qkv_out, (qkv_out.shape[-1] //
                                                  (2**merge_count)),
                                        dim=split_dim)
                qkv_split = [
                    torch.split(s, (s.shape[-1] // 3), dim=split_dim)
                    for s in qkv_split
                ]
                (mixed_query, key_layer, value_layer) = [
                    torch.cat([s[i] for s in qkv_split], axis=-1)
                    for i in range(len(qkv_split[0]))
                ]

                if config.rotary_dim > 0:
                    mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
                        mixed_query, key_layer, config.rotary_dim,
                        0 if layer_past is None else layer_past[0].shape[-2],
                        num_attention_heads_per_partition, config.rotate_half,
                        config.rotate_every_two)
                if layer_past is not None:
                    past_key, past_value = layer_past
                    key_layer = torch.cat(
                        (past_key.type_as(key_layer), key_layer), dim=-2)
                    value_layer = torch.cat(
                        (past_value.type_as(value_layer), value_layer), dim=-2)
                presents = (key_layer, value_layer)
                mixed_query = _transpose_for_scores(mixed_query, False, True)
                key_layer = _transpose_for_scores(key_layer, True, True) / (
                    norm_factor if config.scale_attention else 1.0)
                value_layer = _transpose_for_scores(value_layer, False, True)
                if layer_past is None:
                    attn_key_value = score_context_func(
                        mixed_query,
                        key_layer,
                        torch.empty(1),
                        ((1 - input_mask).half() * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        value_layer,
                        torch.empty(1),
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        (not unfused_mode),  # noqa: F821
                        config.triangular_masking,
                        config.local_attention,
                        config.window_size,
                        no_masking)
                else:
                    attn_key_value = score_context_func(
                        mixed_query,
                        (key_layer if unfused_mode else
                         past_key.type_as(key_layer)),  # noqa: F821
                        key_layer,
                        ((1 - input_mask).half() * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        (value_layer if unfused_mode else
                         past_value.type_as(value_layer)),  # noqa: F821
                        value_layer,
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        (not unfused_mode),  # noqa: F821
                        config.triangular_masking,
                        config.local_attention,
                        config.window_size,
                        no_masking)
                if unfused_mode:  # noqa: F821
                    context_layer, _, _ = attn_key_value
                else:
                    context_layer, key_layer, value_layer = attn_key_value

                # Transpose Context
                context_layer = _transpose_for_context(context_layer)

                return context_layer, presents[0], presents[1]  # attn_output, key_layer, value_layer
            else:
                # Note: This modification is added for the BLOOM-176B model and will be removed later!
                if config.bigscience_bloom:
                    context_layer, presents = backup_attention(
                        qkv_out, layer_past, alibi, input_mask, norm_factor)
                    return context_layer, presents[0], presents[1]  # key_layer, value_layer
                else:
                    if alibi is not None:
                        batch_heads = qkv_out.shape[0] * num_attention_heads_per_partition
                        offset = (dist.get_rank() * batch_heads
                                  if dist.is_initialized() else 0)
                        sliced_alibi = alibi[offset:batch_heads + offset, :, :]

                    attn_key_value = score_context_func(
                        qkv_out,
                        ((1 - input_mask).to(qkv_out.dtype) * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        config.rotary_dim, config.rotate_half,
                        config.rotate_every_two,
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        config.triangular_masking, config.local_attention,
                        config.window_size, no_masking, config.layer_id,
                        DeepSpeedTransformerInference.layer_id,
                        sliced_alibi if alibi is not None else torch.empty(1))
                    context_layer, key_layer, value_layer = attn_key_value
                    return context_layer, key_layer, value_layer
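
Several branches above turn an int64 padding mask into an additive mask via (1 - input_mask) * minus_inf. A small self-contained illustration of that conversion; the minus_inf value here is a stand-in for the constant defined elsewhere in the module:

import torch

minus_inf = -10000.0  # stand-in value; the real constant is defined in the module
input_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.int64)  # 1 = token, 0 = padding
additive_mask = (1 - input_mask).half() * minus_inf
# additive_mask is 0 at real tokens and -10000 at padded positions, so adding it
# to the raw attention scores suppresses padded positions before the softmax.
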
Esempio n. 20
0
    def test(self):
        assert dist.is_initialized()
        assert dist.get_world_size() == 3
        assert dist.get_rank() < 3