Example #1
        def __init__(self) -> None:
            super().__init__()

            self.modulelist = ModuleList(
                EltwiseMultiplicationModule(weight=Parameter(
                    torch.empty((param_sz, ), dtype=torch.float32)))
                for _ in range(n_layers))

            for layer_num, module in enumerate(self.modulelist):
                with deepspeed.zero.GatheredParameters(module.weight,
                                                       modifier_rank=0):
                    param: Parameter = module.weight
                    partition_sz = math.ceil(param.numel() /
                                             dist.get_world_size())
                    offset = 0
                    for rank in range(dist.get_world_size()):
                        with torch.no_grad():
                            param[offset:offset + partition_sz].fill_(
                                2 * layer_num * rank)
                        offset += partition_sz
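
The snippet above gathers each ZeRO-partitioned weight and fills it partition by partition, so that rank r ends up owning the value 2 * layer_num * r after re-partitioning. A minimal single-process sketch of just the offset arithmetic (the sizes below are made up for illustration and there is no DeepSpeed or torch.distributed involved):

import math
import torch

param_sz, world_size = 10, 4   # assumed values for illustration
layer_num = 3

param = torch.empty(param_sz)
partition_sz = math.ceil(param_sz / world_size)   # 3 elements per "rank"; the last slice is shorter
offset = 0
for rank in range(world_size):
    # Same fill pattern as the test above, without DeepSpeed or torch.distributed.
    param[offset:offset + partition_sz].fill_(2 * layer_num * rank)
    offset += partition_sz

print(param)   # tensor([ 0.,  0.,  0.,  6.,  6.,  6., 12., 12., 12., 18.])
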
Example #2
def _gather(input_):
    """Gather tensors and concatinate along the last dimension."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    dist.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already returns a contiguous tensor, so the extra
    # .contiguous() call below is a cheap, defensive no-op.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
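
A non-distributed sketch of the shape bookkeeping `_gather` performs: every rank contributes a tensor of the same shape and the results are concatenated along the last dimension, so that dimension grows by a factor of the world size. The helper name below is illustrative, not DeepSpeed API:

import torch

def fake_gather(input_, world_size):
    # Single-process stand-in for dist.all_gather followed by torch.cat:
    # every "rank" contributes a same-shaped tensor, joined on the last dim.
    tensor_list = [input_.clone() for _ in range(world_size)]
    return torch.cat(tensor_list, dim=input_.dim() - 1)

x = torch.randn(4, 8)
out = fake_gather(x, world_size=4)
assert out.shape == (4, 32)   # last dimension scales with the world size
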
Example #3
def _reduce(input_):
    """All-reduce the the input tensor across model parallel group."""
    group = g_mpu.get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if dist.get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    dist.all_reduce(input_, group=group)

    return input_
Example #4
    def _create_ep_parallel_group(self, moe_experts):
        # Build the expert parallel (EP) and expert model parallel process groups.
        self.ep_group = {}
        self.expert_mp_group = {}
        moe_experts = moe_experts if type(moe_experts) is list else [moe_experts]
        for e in moe_experts:
            self.ep_group.update({e: None})
            self.expert_mp_group.update({e: None})
        for moe_ep_size in self.ep_group.keys():
            num_ep_groups = dist.get_world_size() // moe_ep_size
            for i in range(num_ep_groups):
                ep_cnt = i * moe_ep_size
                size = min(moe_ep_size, dist.get_world_size())
                ranks = list(range(ep_cnt, ep_cnt + size))
                _ep_group = dist.new_group(ranks)
                if dist.get_rank() in ranks:
                    self.ep_group.update({moe_ep_size: _ep_group})

            if dist.get_world_size() > moe_ep_size:
                num_expert_mp_groups = dist.get_world_size() // num_ep_groups
                expert_mp_size = dist.get_world_size() // moe_ep_size
                for i in range(num_expert_mp_groups):
                    expert_mp_comm_ranks = [
                        i + nr * moe_ep_size for nr in range(expert_mp_size)
                    ]
                    _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
                    if dist.get_rank() in expert_mp_comm_ranks:
                        self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
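
The index arithmetic above can be reproduced without torch.distributed to see which ranks land in each group. The sketch below assumes a world size of 8 and moe_ep_size = 4; process groups are replaced by plain rank lists:

world_size = 8      # assumed for illustration
moe_ep_size = 4

# Expert-parallel groups: contiguous blocks of moe_ep_size ranks.
num_ep_groups = world_size // moe_ep_size
ep_groups = [list(range(i * moe_ep_size, (i + 1) * moe_ep_size))
             for i in range(num_ep_groups)]
print(ep_groups)            # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Expert model-parallel groups: ranks strided across the expert-parallel blocks.
num_expert_mp_groups = world_size // num_ep_groups
expert_mp_size = world_size // moe_ep_size
expert_mp_groups = [[i + nr * moe_ep_size for nr in range(expert_mp_size)]
                    for i in range(num_expert_mp_groups)]
print(expert_mp_groups)     # [[0, 4], [1, 5], [2, 6], [3, 7]]
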
Example #5
def _create_expert_and_data_parallel(expert_parallel_size_):
    """
        Create expert and data parallel groups.

        Note: The caller of this function is responsible for checking whether the groups already exist.

        Example - E + D parallel
        world_size = 16
        expert_parallel_size = 2 # expert parallel degree, i.e. the number of ranks in each expert parallel group
        expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params
        expert_parallel_group = [0,1], [2,3], [4,5], ..., [14,15] - no all reduce, but all to all
        data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE params
    """
    assert dist.is_initialized()

    log_dist(
        f'Creating expert and data parallel groups with size {expert_parallel_size_}',
        ranks=[0])
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    _ensure_divisibility(world_size, expert_parallel_size_)

    group_name = f"ep_size_{expert_parallel_size_}"

    # Build the expert data parallel groups.
    global _EXPERT_DATA_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
        for i in range(expert_parallel_size_):
            ranks = range(i, world_size, expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank % expert_parallel_size_):
                _EXPERT_DATA_PARALLEL_GROUP[group_name] = group

    # Build the expert parallel groups.
    global _EXPERT_PARALLEL_GROUP

    # Only create group if it does not already exist
    if group_name not in _EXPERT_PARALLEL_GROUP:
        for i in range(world_size // expert_parallel_size_):
            ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
            group = dist.new_group(ranks)
            log_dist(
                f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}',
                [0])
            if i == (rank // expert_parallel_size_):
                _EXPERT_PARALLEL_GROUP[group_name] = group
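
The layout from the docstring (world_size = 16, expert_parallel_size = 2) follows directly from the two range expressions; a plain-Python sketch without process groups:

world_size = 16
expert_parallel_size_ = 2

# Expert data parallel groups: one strided group per expert-parallel index.
expert_data_parallel = [list(range(i, world_size, expert_parallel_size_))
                        for i in range(expert_parallel_size_)]
# [[0, 2, 4, ..., 14], [1, 3, 5, ..., 15]] - the all-reduce of MoE params happens here

# Expert parallel groups: contiguous blocks of expert_parallel_size_ ranks.
expert_parallel = [list(range(i * expert_parallel_size_,
                              (i + 1) * expert_parallel_size_))
                   for i in range(world_size // expert_parallel_size_)]
# [[0, 1], [2, 3], ..., [14, 15]] - the all-to-all happens here
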
Example #6
def _clone_world_group():
    """Create a clone of the world group
        Note: We need to clone the dist world group because we
        use dist.get_global_rank() utility function in DeepSpeed at many places.
        As that function does not work on dist.group.WORLD, we
        need to keep a clone of it.
    """
    assert dist.is_initialized(), "dist is not initialized"
    global _WORLD_GROUP
    if _WORLD_GROUP is None:
        # If not cloned already, clone the world group
        _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
    return _WORLD_GROUP
Example #7
    def _test_batch_config(num_ranks, batch, micro_batch, gas, success):
        assert dist.get_world_size() == num_ranks, \
            f'The test assumes a world size of {num_ranks}'

        ds_batch_config = get_test_path('ds_batch_config.json')
        ds_config = DeepSpeedConfig(ds_batch_config)

        #test cases when all parameters are provided
        status = _run_batch_config(ds_config,
                                   train_batch=batch,
                                   micro_batch=micro_batch,
                                   gas=gas)
        _batch_assert(status, ds_config, batch, micro_batch, gas, success)

        #test cases when two out of three parameters are provided
        status = _run_batch_config(ds_config,
                                   train_batch=batch,
                                   micro_batch=micro_batch)
        _batch_assert(status, ds_config, batch, micro_batch, gas, success)

        if success:
            #when gas is provided with one more parameter
            status = _run_batch_config(ds_config, train_batch=batch, gas=gas)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)

            status = _run_batch_config(ds_config,
                                       micro_batch=micro_batch,
                                       gas=gas)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)

            #test the case when only micro_batch or train_batch is provided
            if gas == 1:
                status = _run_batch_config(ds_config, micro_batch=micro_batch)
                _batch_assert(status, ds_config, batch, micro_batch, gas,
                              success)

                status = _run_batch_config(ds_config, train_batch=batch)
                _batch_assert(status, ds_config, batch, micro_batch, gas,
                              success)
        else:
            #when only gas is provided
            status = _run_batch_config(ds_config, gas=gas)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)

            #when gas is provided with something else and gas does not divide batch
            if gas != 1:
                status = _run_batch_config(ds_config,
                                           train_batch=batch,
                                           gas=gas)
                _batch_assert(status, ds_config, batch, micro_batch, gas,
                              success)
Example #8
def torch_sim(a):
    a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
    scale = a.norm() / np.sqrt(a.numel())
    a_compressed = scale * a_sign
    a_sign = None
    worker_error = a - a_compressed
    dist.all_reduce(a_compressed)
    a_compressed.mul_(1 / dist.get_world_size())
    a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(
        2.0)
    a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
    server_scale = [
        chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list
    ]
    a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
    a_server_compressed = torch.cat([
        server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())
    ])
    rank = dist.get_rank()
    server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
    torch.cuda.synchronize()
    dist.barrier()
    return a_server_compressed, worker_error, server_error
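
torch_sim simulates 1-bit compression: the tensor is replaced by scale * sign(a) with scale = ||a|| / sqrt(numel), and the residual is kept as error feedback. A single-process sketch of the worker-side step only (the all-reduce and server stage are omitted):

import torch

a = torch.randn(1024)

# Keep only the sign; map exact zeros to +1, as sign().add_(1).bool() does above.
a_sign = a.sign()
a_sign[a_sign == 0] = 1.0

# Rescale so the compressed tensor has the same l2 norm as the original.
scale = a.norm() / a.numel() ** 0.5
a_compressed = scale * a_sign

# Error feedback: the part of `a` lost by compression, carried to the next step.
worker_error = a - a_compressed
assert torch.allclose(a_compressed.abs(), torch.full_like(a, scale.item()))
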
Example #9
def _create_model_parallel(model_parallel_size_):
    """
    Initialize model and data parallel groups.

    Arguments:
        model_parallel_size_: number of GPUs used to parallelize the model.

    Returns:
        Tuple of data parallel group and model parallel group

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model. The present function will
    create 4 model parallel groups and 2 data parallel groups as:
        4 model parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 data parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    log_dist(f'Creating model parallel group with size {model_parallel_size_}',
             ranks=[0])
    # Get world size and rank. Ensure some consistencies.
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    model_parallel_size = min(model_parallel_size_, world_size)
    _ensure_divisibility(world_size, model_parallel_size)
    rank = dist.get_rank()

    _DATA_PARALLEL_GROUP = None
    _MODEL_PARALLEL_GROUP = None
    # Build the data parallel groups.
    for i in range(model_parallel_size):
        ranks = range(i, world_size, model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank % model_parallel_size):
            _DATA_PARALLEL_GROUP = group

    # Build the model parallel groups.
    for i in range(world_size // model_parallel_size):
        ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
        group = dist.new_group(ranks)
        if i == (rank // model_parallel_size):
            _MODEL_PARALLEL_GROUP = group

    return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP
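
For the 8-GPU example in the docstring, the modulo/div indexing determines which of the groups a given rank joins. A sketch with plain rank lists in place of process groups:

world_size = 8
model_parallel_size = 2

data_parallel = [list(range(i, world_size, model_parallel_size))
                 for i in range(model_parallel_size)]
model_parallel = [list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
                  for i in range(world_size // model_parallel_size)]

rank = 5   # example rank
my_dp_group = data_parallel[rank % model_parallel_size]    # [1, 3, 5, 7]
my_mp_group = model_parallel[rank // model_parallel_size]  # [4, 5]
assert rank in my_dp_group and rank in my_mp_group
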
Example #10
    def forward(ctx,
                input,
                residual,
                residual_norm,
                bias,
                inter_w,
                inter_b,
                attn_nw,
                attn_nb,
                config,
                mp_group,
                output_b,
                output_w,
                q_scales,
                q_groups,
                merge_count,
                mlp_gemm_func,
                fused_gemm_gelu,
                vector_matmul_func,
                bias_residual_func,
                activation_func_type=ActivationFuncType.GELU):

        if config.q_int8:
            (intermediate, residual_add) = inference_cuda_module.mlp_gemm_int8(
                input, residual, bias, inter_w, inter_b, attn_nw, attn_nb,
                config.epsilon, q_scales[2], (q_groups * (2**merge_count)),
                config.pre_layer_norm)
            output = inference_cuda_module.vector_matmul_int8(
                intermediate, output_w, q_scales[3], q_groups, (merge_count))
        else:
            if attn_nw is None:
                output = fused_gemm_gelu(residual_norm, inter_w, inter_b,
                                         output_w, config.epsilon,
                                         config.pre_layer_norm, False)
            else:
                intermediate, residual_add = mlp_gemm_func(
                    input, residual, bias, inter_w, inter_b, attn_nw, attn_nb,
                    config.epsilon, config.pre_layer_norm,
                    config.mlp_after_attn, config.mlp_act_func_type)
                output = vector_matmul_func(intermediate, output_w, False)

        inference_cuda_module.residual_add(
            output, residual if config.pre_layer_norm else residual_add, input,
            output_b, bias if bias is not None else output_b, config.mp_size,
            config.mlp_after_attn, bias is not None, config.pre_layer_norm)
        if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
            dist.all_reduce(output, group=mp_group)
        return output
Example #11
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    parameter_parallel_size = parameter_parallel_size or data_parallel_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
                data_parallel_size,
                parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(data_parallel_size // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
Example #12
def test_partitioned_tensor_meta():
    world = dist.get_world_size()
    rank = dist.get_rank()

    group = dist.new_group(ranks=list(range(world)))

    rows = world * 7
    cols = 3

    full = torch.rand(rows, cols).cuda()
    dist.broadcast(full, src=0, group=group)
    part = PartitionedTensor(full, group=group)

    my_meta = PartitionedTensor.from_meta(part.to_meta(), part.local_data,
                                          group)
    assert torch.equal(full, my_meta.full())
Example #13
    def _get_norm_with_moe_layers(self, all_groups_norm):
        #all_groups_norm_old = all_groups_norm
        # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
        if self.using_pipeline:
            pg = self.deepspeed.mpu.get_data_parallel_group()
        else:
            pg = groups._get_data_parallel_group()
        scaled_norm = all_groups_norm * 1.0 / float(
            dist.get_world_size(group=pg))
        scaled_norm_tensor = torch.tensor(
            scaled_norm,
            device=self.fp32_groups_flat[0].device,
            dtype=torch.float)
        dist.all_reduce(scaled_norm_tensor, group=pg)
        all_groups_norm = scaled_norm_tensor.item()
        #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
        return all_groups_norm
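
The pre-division by the group size is what turns the SUM all-reduce into a mean. A single-process sketch of that arithmetic with made-up per-rank norms:

# Hypothetical per-rank norms; in practice they differ because of MoE params.
per_rank_norms = [3.0, 5.0, 4.0, 4.0]
world_size = len(per_rank_norms)

# Each rank contributes its norm divided by the group size ...
contributions = [n / world_size for n in per_rank_norms]
# ... so a SUM all-reduce leaves every rank holding the mean norm.
averaged_norm = sum(contributions)
assert averaged_norm == sum(per_rank_norms) / world_size   # 4.0
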
Example #14
def test_partitioned_tensor():
    world = dist.get_world_size()
    rank = dist.get_rank()

    group = dist.new_group(ranks=list(range(world)))

    rows = world * 4
    cols = 3

    full = torch.rand(rows, cols).cuda()
    dist.broadcast(full, src=0, group=group)
    part = PartitionedTensor(full, group=group)

    assert len(part.local_size()) == 1
    assert part.local_size()[0] * world == full.numel()

    reconstructed = part.full()
    assert torch.equal(full, reconstructed)
Example #15
    def forward(ctx, input, inter_w, inter_b, config, output_b, output_w,
                q_scales, q_groups, merge_count, mp_group, async_op):
        if config.q_int8:
            intermediate = inference_cuda_module.fused_gemm_gelu_int8(
                input, inter_w, inter_b, config.epsilon, q_scales[2],
                (q_groups * (2**merge_count)), config.pre_layer_norm)
            output = inference_cuda_module.vector_matmul_int8(
                intermediate, output_w, q_scales[3], q_groups, (merge_count))
        else:
            mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
                                    inference_cuda_module.fused_gemm_gelu_fp32

            output = mlp_gemm_func(input, inter_w, inter_b, output_w,
                                   config.epsilon, config.pre_layer_norm,
                                   async_op)
        if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
            dist.all_reduce(output, group=mp_group, async_op=async_op)

        return output + output_b
Example #16
    def test(self):
        param1 = torch.nn.Parameter(torch.Tensor([0]))
        param1.grad = torch.Tensor([1])
        param2 = torch.nn.Parameter(torch.Tensor([0]))
        param2.grad = torch.Tensor([dist.get_rank() + 1])
        # param2 is now an MoE parameter
        param2.allreduce = False

        parameters = [param1, param2]

        groups._create_expert_and_data_parallel(2)

        norm = ds_utils.clip_grad_norm_(parameters, max_norm=0.1)
        norm = torch.Tensor([norm]).to(dist.get_rank())

        world_size = dist.get_world_size()
        gathered_norm = [torch.zeros(1).cuda() for i in range(world_size)]

        dist.all_gather(gathered_norm, norm)

        assert gathered_norm[0] == gathered_norm[
            1], "norm at rank 0 does not match the norm at rank 1"
Example #17
    def from_meta(cls, meta, local_part, group, device='cuda'):
        assert meta.dtype == torch.long
        dummy = torch.ones(dist.get_world_size(group=group))
        part_obj = cls(tensor=dummy, group=group)

        meta = meta.tolist()

        # [N, list0, ..., listN-1]
        part_obj.orig_size = meta[1:(1 + meta[0])]
        meta = meta[1 + meta[0]:]

        part_obj.orig_device = device
        part_obj.local_data = local_part.detach()

        part_obj.group = group

        # Partition is encoded like the rowptr of a CSR matrix:
        # [num_parts, rank, 0, part_1, ..., part_num_parts]
        # TODO: support shuffle between different partition granularities
        assert part_obj.num_parts == meta[0]
        assert part_obj.rank == meta[1]
        part_obj.partition = meta[2:]  # length num_parts+1

        return part_obj
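
The meta vector decoded above is a flat torch.long tensor: the number of dimensions and shape of the original tensor, followed by the CSR-style record [num_parts, rank, 0, part_1, ..., part_num_parts]. A sketch that packs and unpacks such a vector under that assumed layout (the helpers are illustrative, not DeepSpeed API):

import torch

def pack_meta(orig_shape, num_parts, rank, partition):
    # [ndims, *shape, num_parts, rank, *partition]
    return torch.tensor(
        [len(orig_shape), *orig_shape, num_parts, rank, *partition],
        dtype=torch.long)

def unpack_meta(meta):
    meta = meta.tolist()
    ndims, meta = meta[0], meta[1:]
    orig_shape, meta = meta[:ndims], meta[ndims:]
    return orig_shape, meta[0], meta[1], meta[2:]   # shape, num_parts, rank, partition

meta = pack_meta([28, 3], num_parts=4, rank=1, partition=[0, 21, 42, 63, 84])
assert unpack_meta(meta) == ([28, 3], 4, 1, [0, 21, 42, 63, 84])
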
Example #18
    def __init__(self, topology=None, process_group=None):
        # TODO use process_group if provided
        self.global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        if topology is not None:
            if self.global_rank == 0:
                print('Using topology:', topology)
            self._topo = topology
        else:
            num_pp = 1
            num_dp = 1
            for idx, prime in enumerate(_prime_factors(self.world_size)):
                if idx % 2 == 0:
                    num_pp *= prime
                else:
                    num_dp *= prime
            self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp)
        self.data_parallel_size = max(self._topo.get_dim('data'), 1)
        self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1)
        self.model_parallel_size = max(self._topo.get_dim('model'), 1)
        self.slice_parallel_size = self.model_parallel_size
        assert self._is_grid_valid(), "Invalid Grid"

        self.stage_id = self.get_stage_id()
        self.data_parallel_id = self.get_data_parallel_id()

        # Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these
        # to detect overflow, etc.
        self.ds_model_proc_group = None
        self.ds_model_rank = -1
        for dp in range(self.data_parallel_size):
            ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))
            if self.global_rank == 0:
                #print(f'RANK={self.global_rank} building DeepSpeed model group: {ranks}')
                pass
            proc_group = dist.new_group(ranks=ranks)
            if self.global_rank in ranks:
                self.ds_model_proc_group = proc_group
                self.ds_model_world_size = len(ranks)
                self.ds_model_rank = ranks.index(self.global_rank)
        assert self.ds_model_rank > -1
        assert self.ds_model_proc_group is not None

        # Create new ProcessGroup for gradient all-reduces - these are the data parallel groups
        self.dp_group = []
        self.dp_groups = self._topo.get_axis_comm_lists('data')
        for g in self.dp_groups:
            proc_group = dist.new_group(ranks=g)
            if self.global_rank in g:
                self.dp_group = g
                self.dp_proc_group = proc_group

        self.is_first_stage = (self.stage_id == 0)
        self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))

        self.p2p_groups = self._build_p2p_groups()

        # Create new ProcessGroup for pipeline collectives - these are pipe parallel groups
        self.pp_group = []
        self.pp_proc_group = None
        self.pipe_groups = self._topo.get_axis_comm_lists('pipe')
        for ranks in self.pipe_groups:
            if self.global_rank == 0:
                #print(f'RANK={self.global_rank} building pipeline group: {ranks}')
                pass
            proc_group = dist.new_group(ranks=ranks)
            if self.global_rank in ranks:
                self.pp_group = ranks
                self.pp_proc_group = proc_group
        assert self.pp_proc_group is not None

        # Create new ProcessGroup for model (tensor-slicing) collectives

        # Short circuit case without model parallelism.
        # TODO: it would be nice if topology had bcast semantics to avoid this branching
        # case?
        if self.model_parallel_size == 1:
            for group_rank in range(self.world_size):
                group_rank = [group_rank]
                group = dist.new_group(ranks=group_rank)
                if group_rank[0] == self.global_rank:
                    self.slice_group = group_rank
                    self.slice_proc_group = group
            return
        else:
            self.mp_group = []
            self.model_groups = self._topo.get_axis_comm_lists('model')
            for g in self.model_groups:
                proc_group = dist.new_group(ranks=g)
                if self.global_rank in g:
                    self.slice_group = g
                    self.slice_proc_group = proc_group
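
When no topology is passed in, the constructor splits the world size into pipeline and data dimensions by alternating its prime factors. A sketch of that split with a small factorization helper written here for illustration (the real _prime_factors lives elsewhere in DeepSpeed and may order factors differently, which would change the split):

def _prime_factors(n):
    # Smallest-first prime factorization, e.g. 12 -> [2, 2, 3].
    factors, d = [], 2
    while d * d <= n:
        while n % d == 0:
            factors.append(d)
            n //= d
        d += 1
    if n > 1:
        factors.append(n)
    return factors

world_size = 12
num_pp, num_dp = 1, 1
for idx, prime in enumerate(_prime_factors(world_size)):
    if idx % 2 == 0:
        num_pp *= prime   # even positions feed the pipeline dimension
    else:
        num_dp *= prime   # odd positions feed the data dimension

assert num_pp * num_dp == world_size
print(num_pp, num_dp)     # 6 2 for factors [2, 2, 3]
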
Example #19
def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
    """Clips gradient norm of an iterable of parameters.

    This has been adapted from Nvidia megatron. We add norm averaging
    to consider MoE params when calculating norm as they will result
    in different norms across different ranks.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with added
    functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all GPUs.
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.MAX,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        total_norm = 0
        for p in parameters:
            if mpu is not None:
                if (mpu.get_model_parallel_rank()
                        == 0) or is_model_parallel_parameter(p):
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm.item()**norm_type
            else:
                param_norm = p.grad.data.float().norm(norm_type)
                total_norm += param_norm.item()**norm_type

        # Sum across all model parallel GPUs.
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        if mpu is not None:
            dist.all_reduce(total_norm_cuda,
                            op=dist.ReduceOp.SUM,
                            group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

    # Need to average total_norm across different GPUs due to the presence of moe params
    pg = groups._get_data_parallel_group()
    scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg))

    scaled_norm_tensor = torch.cuda.FloatTensor([float(scaled_norm)])
    dist.all_reduce(scaled_norm_tensor, group=pg)
    total_norm = scaled_norm_tensor.item()

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm
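
Once the (averaged) total norm is known, clipping is a single in-place scale of every gradient. A small numeric sketch of that final step:

import torch

max_norm = 1.0
grads = [torch.tensor([3.0, 0.0]), torch.tensor([0.0, 4.0])]

# 2-norm of all gradients viewed as one vector: sqrt(3^2 + 4^2) = 5.
total_norm = sum(g.norm(2).item() ** 2 for g in grads) ** 0.5

clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
    for g in grads:
        g.mul_(clip_coef)   # in-place, as in the function above

# The combined norm is now ~max_norm.
assert abs(sum(g.norm(2).item() ** 2 for g in grads) ** 0.5 - max_norm) < 1e-4
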
Example #20
def reduce_scatter_coalesced(
    tensors: List[Tensor],
    group: ProcessGroup = None,
) -> List[Tensor]:
    """simultaneously reduce-scatter a list of tensors - this can be done more
    efficiently than individual reduce scatter calls

    TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
    """
    this_rank = dist.get_rank(group)
    world_sz = dist.get_world_size(group)

    partition_lst_for_each_tensor = [None] * len(tensors)
    for tensor_idx, tensor in enumerate(tensors):
        flattened_tensor = tensor.view(-1)
        chunk_sz = math.ceil(tensor.numel() / world_sz)
        partition_lst_for_each_tensor[tensor_idx] = [
            flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz]
            for rank in range(0, world_sz)
        ]

    padded_partition_sz_for_each_tensor = tuple(
        math.ceil(t.numel() / world_sz) for t in tensors)

    if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
        # if there's only one tensor being reduced and we don't need to pad
        # we have an opportunity to avoid a memory allocation
        tensor_partition_flat_buffer = tensors[0].view(-1)
    else:
        # interleave tensor partitions such that the correct reduced partitions of each tensor
        # end up at each rank
        tensor_partitions_lst_with_padding = []
        for rank in range(world_sz):
            for tensor_idx in range(len(tensors)):
                # add tensor content
                tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
                tensor_partitions_lst_with_padding.append(tensor_chunk)

                # add padding if necessary
                padding_sz = padded_partition_sz_for_each_tensor[
                    tensor_idx] - tensor_chunk.numel()
                if padding_sz > 0:
                    tensor_partitions_lst_with_padding.append(
                        torch.empty(padding_sz,
                                    dtype=tensor_chunk.dtype,
                                    device=tensor_chunk.device))

        tensor_partition_flat_buffer = instrument_w_nvtx(
            torch.cat)(tensor_partitions_lst_with_padding)

    tensor_partition_flat_buffer.div_(world_sz)  # pre-divide
    tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
        tensor_partition_flat_buffer, world_sz)

    # batched reduce-scatter call
    _torch_reduce_scatter_fn(tensor_partition_flat_buffer,
                             tensor_partition_buffer_for_each_rank[this_rank],
                             group=group)

    # reverse procedure of the interleaving done previously, done on the
    # result of the batched reduce-scatter
    output_lst: List[Tensor] = [None] * len(tensors)
    offset = 0
    for tensor_idx in range(len(tensors)):
        output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[
            this_rank].narrow(
                0, offset,
                partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

        offset += padded_partition_sz_for_each_tensor[tensor_idx]

    return output_lst
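
The interleaving lays out, rank by rank, each rank's (padded) chunk of every tensor, so that a single flat reduce-scatter hands each rank the right pieces of all tensors at once. A sketch of just that layout for two tensors and two "ranks", with no collectives (the helper name is illustrative):

import math
import torch

def interleave_for_reduce_scatter(tensors, world_sz):
    # Build the flat buffer whose i-th of world_sz chunks holds rank i's
    # (zero-padded) partition of every input tensor, in order.
    chunks = []
    for rank in range(world_sz):
        for t in tensors:
            flat = t.view(-1)
            chunk_sz = math.ceil(t.numel() / world_sz)
            chunk = flat[rank * chunk_sz:(rank + 1) * chunk_sz]
            chunks.append(chunk)
            pad = chunk_sz - chunk.numel()
            if pad > 0:
                chunks.append(torch.zeros(pad, dtype=t.dtype))
    return torch.cat(chunks)

a = torch.arange(4.)        # even split: [0, 1] for rank 0, [2, 3] for rank 1
b = torch.arange(3.) + 10.  # odd size:   [10, 11] for rank 0, [12, pad] for rank 1
buf = interleave_for_reduce_scatter([a, b], world_sz=2)
print(buf)                  # tensor([ 0.,  1., 10., 11.,  2.,  3., 12.,  0.])
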
Example #21
    def replace_with_policy(child,
                            policy_cls,
                            triangular_masking,
                            inference=False,
                            layer_id=0):
        policy = policy_cls(child, inference=inference)

        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0,\
                "To run the model in parallel across GPUs, the number of attention heads must be divisible by the world size, " +\
                "because the attention computation is partitioned evenly among the parallel GPUs."
        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
        if not moe or moe_type == 'standard':
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
        else:
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
                _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

        attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
        if quantize:
            if policy_cls is not HFBertLayerPolicy:
                qkvw = qkvw.to(torch.int8)
            dense_w = dense_w.to(torch.int8)
            _h4h_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
            _4hh_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
        elif fp16:
            qkvw = qkvw.half()
            dense_w = dense_w.half()
            _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
            _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
        if quantize or fp16:
            qkvb = qkvb if qkvb is None else qkvb.half()
            dense_b = dense_b if dense_b is None else dense_b.half()
            _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
            _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
            attn_nw = attn_nw if attn_nw is None else attn_nw.half()
            attn_nb = attn_nb if attn_nb is None else attn_nb.half()
            input_nw = input_nw.half()
            input_nb = input_nb.half()

        if moe and moe_type == 'residual' and fp16:
            _res_h4h_b = _res_h4h_b.half()
            _res_4hh_b = _res_4hh_b.half()
            _res_h4h_w = _res_h4h_w.half()
            _res_4hh_w = _res_4hh_w.half()
            _res_coef = _res_coef.half()

        #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

        if inference:
            if moe:
                ep_world_size = dist.get_world_size()
                local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

                transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else 1e-12,
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    moe_experts=local_ep_size,
                    global_experts=num_experts,
                    mlp_type=moe_type)
            else:
                rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') else child.attention.rotary_ndims \
                                            if hasattr(child, 'attention') and hasattr(child.attention,'rotary_ndims') else -1
                bigscience_bloom = policy_cls is BLOOMLayerPolicy
                transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else
                    (config.layer_norm_epsilon
                     if hasattr(config,
                                'layer_norm_epsilon') else config.layernorm_epsilon
                     if hasattr(config,
                                'layernorm_epsilon') else 1.0e-12),
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                    triangular_masking=(policy_cls is not HFBertLayerPolicy),
                    local_attention=((config.attention_layers[layer_id] == "local")
                                     if hasattr(config,
                                                'attention_layers') else False),
                    window_size=(config.window_size if hasattr(config,
                                                               'window_size') else 1),
                    rotary_dim=rotary_dim,
                    mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
                    mlp_act_func_type=policy.mlp_act_func_type,
                    training_mp_size=training_mp_size,
                    bigscience_bloom=bigscience_bloom)

            if quantize and quantize_settings is not None:
                (quantization_scales,
                 merge_count,
                 mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                if quantize and qkvw.dtype != torch.int8:
                    quantize_bits = 8
                    quantizer = WeightQuantization()
                    if policy_cls is HFBertLayerPolicy:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
                    else:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
                    qkvw.data.copy_(data_quantized)
                    qkvw.data = qkvw.data.to(torch.int8)
            else:

                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                    )

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                    )
            new_module.config.scale_attention = scale_attention

            # we want the weights in [input, output] shape
            # linear layer is created with [input, output] shape
            # transpose it here to reduce inference cost!
            def transpose(data):
                # temporarily move to CPU to avoid requiring extra GPU memory during the reshape
                data = data.to('cpu')
                data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
                data = data.reshape(data.shape[-1], data.shape[-2])
                # move the transposed weight back to the current GPU before returning it
                data = data.to(torch.cuda.current_device())
                return data

            attn_block = new_module.attention
            mpl_block = new_module.mlp

            if attn_linear_layer:
                if qkvw.numel() == 0 or qkvw.is_meta:
                    if qkvw.is_meta or qkvw.ds_tensor.numel(
                    ) < attn_block.attn_qkvw.numel():
                        pass
                    else:
                        with GatheredParameters([qkvw,
                                                 dense_w,
                                                 qkvb,
                                                 dense_b],
                                                modifier_rank=0):
                            qkvw = transpose(qkvw.data)
                            dense_w = transpose(dense_w.data)
                            qkvb = qkvb.data
                            dense_b = dense_b.data
                else:
                    qkvw.data = transpose(qkvw.data)
                    dense_w.data = transpose(dense_w.data)

            def _transpose(x):
                num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size
                attention_head_size = x.shape[-1] // num_attention_heads_per_partition
                new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                               attention_head_size)
                x_1 = x.view(*new_x_shape)
                (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
                if len(q.shape) > 2:
                    return torch.cat((q.reshape(q.shape[0],
                                                -1),
                                      k.reshape(q.shape[0],
                                                -1),
                                      v.reshape(q.shape[0],
                                                -1)),
                                     dim=-1).reshape(x.shape)
                else:
                    return torch.cat((q.reshape(-1),
                                      k.reshape(-1),
                                      v.reshape(-1)),
                                     dim=-1).reshape(x.shape)

            if megatron_v2:
                new_module.config.rotate_half = True
                new_module.config.rotate_every_two = False

                # Note: this part needs to be added for BLOOM architecture
                qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous())
                qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous())

            # NOTE: This part caused instability in the multi-GPU inference!
            # TODO: This needs to be incorporated in the kernels.
            #dense_b = dense_b if dense_b is None else dense_b * (
            #    transformer_config.training_mp_size / transformer_config.mp_size)
            #_4hh_b = _4hh_b * (transformer_config.training_mp_size /
            #                   transformer_config.mp_size)

            if mlp_linear_layer:
                if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta):
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _4hh_b,
                                                 _h4h_b],
                                                modifier_rank=0):
                            _h4h_w = transpose(_h4h_w.data)
                            _4hh_w = transpose(_4hh_w.data)
                            _h4h_b = _h4h_b.data
                            _4hh_b = _4hh_b.data
                else:
                    _h4h_w = [transpose(moe_w1.data)
                              for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
                    _4hh_w = [transpose(moe_w1.data)
                              for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

            if moe and moe_type == 'residual':
                _res_h4h_w.data = transpose(_res_h4h_w.data)
                _res_4hh_w.data = transpose(_res_4hh_w.data)
                _res_coef.data = transpose(_res_coef.data)

            if qkvw.is_meta or qkvw.numel() == 0:
                if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                    pass
                else:
                    with GatheredParameters([
                            attn_block.attn_qkvw,
                            attn_block.attn_qkvb,
                            attn_block.attn_ow,
                            attn_block.attn_ob
                    ],
                                            modifier_rank=0):
                        attn_block.attn_qkvw = mp_replace.copy(
                            attn_block.attn_qkvw,
                            qkvw)
                        attn_block.attn_qkvb = mp_replace.copy(
                            attn_block.attn_qkvb,
                            qkvb)

                        attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                        attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
            else:
                if bigscience_bloom:
                    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
                else:
                    attn_block.attn_qkvw = mp_replace.qkv_copy(
                        attn_block.attn_qkvw,
                        qkvw)
                    attn_block.attn_qkvb = mp_replace.qkv_copy(
                        attn_block.attn_qkvb,
                        qkvb)

                attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

            if moe:
                gpu_index = dist.get_rank()
                gpu_index = 0
                for ep_index in range(local_ep_size):
                    mpl_block[ep_index].inter_w.data = _h4h_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].inter_b.data = _h4h_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_w.data = _4hh_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_b.data = _4hh_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
                if moe_type == 'residual':
                    new_module.res_mlp.inter_w.data = _res_h4h_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.inter_b.data = _res_h4h_b.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_w.data = _res_4hh_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_b.data = _res_4hh_b.to(
                        torch.cuda.current_device())
                    new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
            else:

                if _4hh_w.numel() == 0 or _4hh_w.is_meta:
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _h4h_b,
                                                 _4hh_b],
                                                modifier_rank=0):
                            mpl_block.inter_w = mp_replace.copy(
                                mpl_block.inter_w,
                                _h4h_w)
                            mpl_block.inter_b = mp_replace.copy(
                                mpl_block.inter_b,
                                _h4h_b)
                            mpl_block.output_w = mp_replace.copy(
                                mpl_block.output_w,
                                _4hh_w)
                            mpl_block.output_b = mp_replace.copy(
                                mpl_block.output_b,
                                _4hh_b)
                else:
                    mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                    mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                    mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                    mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)

                if attn_nw is None:
                    new_module.mlp.attn_nw = attn_nw
                    new_module.mlp.attn_nb = attn_nb
                else:
                    if attn_nw.is_meta or attn_nw.numel() == 0:
                        if attn_nw.is_meta or attn_nw.ds_tensor.numel(
                        ) < new_module.mlp.attn_nw.numel():
                            pass
                        else:
                            with GatheredParameters([attn_nw, attn_nb], modifier_rank=0):
                                new_module.mlp.attn_nw.data.copy_(
                                    attn_nw.to(torch.cuda.current_device()))
                                new_module.mlp.attn_nb.data.copy_(
                                    attn_nb.to(torch.cuda.current_device()))
                    else:
                        new_module.mlp.attn_nw.data.copy_(
                            attn_nw.to(torch.cuda.current_device()))
                        new_module.mlp.attn_nb.data.copy_(
                            attn_nb.to(torch.cuda.current_device()))

            if input_nw.is_meta or input_nw.numel() == 0:
                if input_nw.is_meta or input_nw.ds_tensor.numel(
                ) < new_module.norm_w.numel():
                    pass
                else:
                    with GatheredParameters([input_nw, input_nb], modifier_rank=0):
                        new_module.norm_w.data.copy_(
                            input_nw.to(torch.cuda.current_device()))
                        new_module.norm_b.data.copy_(
                            input_nb.to(torch.cuda.current_device()))
            else:
                new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device()))
                new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device()))
        else:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size if micro_batch_size > 0 else 1,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config,
                    'layer_norm_eps') else 1e-12,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=policy.pre_attn_norm,
                return_tuple=return_tuple,
                local_rank=local_rank,
                stochastic_mode=stochastic_mode,
                normalize_invertible=True,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = dense_w
            new_module.attn_ob.data = dense_b

            new_module.attn_nw.data = attn_nw
            new_module.attn_nb.data = attn_nb
            new_module.norm_w.data = input_nw
            new_module.norm_b.data = input_nb

            new_module.inter_w.data = _h4h_w
            new_module.inter_b.data = _h4h_b
            new_module.output_w.data = _4hh_w
            new_module.output_b.data = _4hh_b
        return new_module
Example #22
def replace_transformer_layer(orig_layer_impl,
                              model,
                              policy=None,
                              micro_batch_size=-1,
                              config=None,
                              seed=-1,
                              hidden_size=-1,
                              num_attention_heads=-1,
                              mp_size=1,
                              training_mp_size=1,
                              mp_group=None,
                              ep_group=None,
                              expert_mp_group=None,
                              fp16=True,
                              local_rank=-1,
                              stochastic_mode=True,
                              training=True,
                              quantize=False,
                              quantize_settings=None,
                              triangular_masking=False,
                              return_tuple=True,
                              replace_with_kernel_inject=False,
                              linear_layer_setting=None,
                              moe=False,
                              moe_experts=1,
                              moe_type='standard',
                              checkpoint_dict=None,
                              save_mp_checkpoint_path=None):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        policy: the policy used to map from orig_layer_impl to transformer parameters when
            replace_with_kernel_inject is set; otherwise, it provides the names of two linear layers as
            a tuple: (attention_output projection, transformer output projection)
        micro_batch_size (int): micro batch size per gpu used during training/eval
        config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        hidden_size (int): hidden dimension
        num_attention_heads (int): number of attention heads
        mp_size (int): model_parallelism degree
        mp_group : model_parallel group initialized on the modeling side
        preln (bool): whether the original layer implementation uses pre- or post-layer norm
        fp16 (bool): fp16 or fp32
        local_rank (int): GPU rank (optional),
        stochastic_mode (bool): whether to use stochastic mode
        training (bool): whether kernel injection is done for training or inference (set to False for inference-mode injection)
        quantize_settings (tuple): this setting shows how we can quantize a model for running it through the inference kernels.
                It includes (quantization_scales, merge_count, mlp_extra_grouping, quantize_groups).
        return_tuple (bool): if set, transformer layer returns a tuple as the output.
            Note: this flag needs to be set for huggingface models.
        replace_with_kernel_inject (bool): injection mode; if True, kernels are added along with configuring
            Tensor-Parallelism
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers
            and embedding layers
        attention_params: (list of strings) [Optional]: lists the parameters in the attention part that need to
            be adjusted based on the model-parallelism
    Returns:
        Updated nn.module with replaced transformer layers
    """
    mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group,
                                          mp_size=mp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child,
                            policy_cls,
                            triangular_masking,
                            inference=False,
                            layer_id=0):
        policy = policy_cls(child, inference=inference)

        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0,\
                "To run the model in parallel across GPUs, the number of attention heads must be divisible by the world size, " +\
                "because the attention computation is partitioned evenly among the parallel GPUs."
        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
        if not moe or moe_type == 'standard':
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
        else:
            mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b, \
                _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)

        attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()
        if quantize:
            if policy_cls is not HFBertLayerPolicy:
                qkvw = qkvw.to(torch.int8)
            dense_w = dense_w.to(torch.int8)
            _h4h_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _h4h_w] if moe else _h4h_w.to(torch.int8)
            _4hh_w = [moe_w1.to(torch.int8)
                      for moe_w1 in _4hh_w] if moe else _4hh_w.to(torch.int8)
        elif fp16:
            qkvw = qkvw.half()
            dense_w = dense_w.half()
            _h4h_w = [moe_w1.half() for moe_w1 in _h4h_w] if moe else _h4h_w.half()
            _4hh_w = [moe_w1.half() for moe_w1 in _4hh_w] if moe else _4hh_w.half()
        if quantize or fp16:
            qkvb = qkvb if qkvb is None else qkvb.half()
            dense_b = dense_b if dense_b is None else dense_b.half()
            _h4h_b = [moe_b1.half() for moe_b1 in _h4h_b] if moe else _h4h_b.half()
            _4hh_b = [moe_b1.half() for moe_b1 in _4hh_b] if moe else _4hh_b.half()
            attn_nw = attn_nw if attn_nw is None else attn_nw.half()
            attn_nb = attn_nb if attn_nb is None else attn_nb.half()
            input_nw = input_nw.half()
            input_nb = input_nb.half()

        if moe and moe_type == 'residual' and fp16:
            _res_h4h_b = _res_h4h_b.half()
            _res_4hh_b = _res_4hh_b.half()
            _res_h4h_w = _res_h4h_w.half()
            _res_4hh_w = _res_4hh_w.half()
            _res_coef = _res_coef.half()

        #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

        if inference:
            if moe:
                ep_world_size = dist.get_world_size()
                local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size

                transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else 1e-12,
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    moe_experts=local_ep_size,
                    global_experts=num_experts,
                    mlp_type=moe_type)
            else:
                rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') else child.attention.rotary_ndims \
                                            if hasattr(child, 'attention') and hasattr(child.attention,'rotary_ndims') else -1
                bigscience_bloom = policy_cls is BLOOMLayerPolicy
                transformer_config = transformer_inference.DeepSpeedInferenceConfig(
                    hidden_size=hidden_size,
                    heads=num_attention_heads,
                    layer_norm_eps=config.layer_norm_eps if hasattr(
                        config,
                        'layer_norm_eps') else
                    (config.layer_norm_epsilon
                     if hasattr(config,
                                'layer_norm_epsilon') else config.layernorm_epsilon
                     if hasattr(config,
                                'layernorm_epsilon') else 1.0e-12),
                    fp16=fp16,
                    pre_layer_norm=policy.pre_attn_norm,
                    mp_size=mp_size,
                    q_int8=quantize,
                    return_tuple=(return_tuple or (policy_cls is HFBertLayerPolicy)),
                    triangular_masking=(policy_cls is not HFBertLayerPolicy),
                    local_attention=((config.attention_layers[layer_id] == "local")
                                     if hasattr(config,
                                                'attention_layers') else False),
                    window_size=(config.window_size if hasattr(config,
                                                               'window_size') else 1),
                    rotary_dim=rotary_dim,
                    mlp_after_attn=(rotary_dim is None or rotary_dim < 0),
                    mlp_act_func_type=policy.mlp_act_func_type,
                    training_mp_size=training_mp_size,
                    bigscience_bloom=bigscience_bloom)

            if quantize and quantize_settings is not None:
                (quantization_scales,
                 merge_count,
                 mlp_extra_grouping,
                 quantize_groups) = quantize_settings
                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                        quantize_scales=quantization_scales[layer_id],
                        quantize_groups=quantize_groups,
                        merge_count=merge_count,
                        mlp_extra_grouping=mlp_extra_grouping,
                        qkv_merging=(policy_cls is HFBertLayerPolicy))

                if quantize and qkvw.dtype != torch.int8:
                    quantize_bits = 8
                    quantizer = WeightQuantization()
                    if policy_cls is HFBertLayerPolicy:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
                    else:
                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
                    qkvw.data.copy_(data_quantized)
                    qkvw.data = qkvw.data.to(torch.int8)
            else:

                if moe:
                    new_module = transformer_inference.DeepSpeedMoEInference(
                        transformer_config,
                        mp_group=mp_group,
                        ep_group=None if ep_group is None else ep_group[num_experts],
                        expert_mp_group=None
                        if expert_mp_group is None else expert_mp_group[num_experts],
                    )

                else:
                    new_module = transformer_inference.DeepSpeedTransformerInference(
                        transformer_config,
                        mp_group=mp_group,
                    )
            new_module.config.scale_attention = scale_attention

            # The inference kernels expect the weights in [input, output] shape,
            # while the original linear layer stores them as [output, input];
            # transpose them here once to reduce inference cost!
            def transpose(data):
                # temporarily move to CPU to avoid requiring extra GPU memory during the reshape
                data = data.to('cpu')
                data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
                data = data.reshape(data.shape[-1], data.shape[-2])
                data = data.to(torch.cuda.current_device())
                return data

            attn_block = new_module.attention
            mpl_block = new_module.mlp

            if attn_linear_layer:
                if qkvw.numel() == 0 or qkvw.is_meta:
                    if qkvw.is_meta or qkvw.ds_tensor.numel(
                    ) < attn_block.attn_qkvw.numel():
                        pass
                    else:
                        with GatheredParameters([qkvw,
                                                 dense_w,
                                                 qkvb,
                                                 dense_b],
                                                modifier_rank=0):
                            qkvw = transpose(qkvw.data)
                            dense_w = transpose(dense_w.data)
                            qkvb = qkvb.data
                            dense_b = dense_b.data
                else:
                    qkvw.data = transpose(qkvw.data)
                    dense_w.data = transpose(dense_w.data)

            def _transpose(x):
                num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size
                attention_head_size = x.shape[-1] // num_attention_heads_per_partition
                new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                               attention_head_size)
                x_1 = x.view(*new_x_shape)
                (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
                if len(q.shape) > 2:
                    return torch.cat((q.reshape(q.shape[0],
                                                -1),
                                      k.reshape(q.shape[0],
                                                -1),
                                      v.reshape(q.shape[0],
                                                -1)),
                                     dim=-1).reshape(x.shape)
                else:
                    return torch.cat((q.reshape(-1),
                                      k.reshape(-1),
                                      v.reshape(-1)),
                                     dim=-1).reshape(x.shape)

            if megatron_v2:
                new_module.config.rotate_half = True
                new_module.config.rotate_every_two = False

                # Note: this part needs to be added for BLOOM architecture
                qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous())
                qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous())

            # NOTE: This part caused instability in the multi-GPU inference!
            # TODO: This needs to be incorporated in the kernels.
            #dense_b = dense_b if dense_b is None else dense_b * (
            #    transformer_config.training_mp_size / transformer_config.mp_size)
            #_4hh_b = _4hh_b * (transformer_config.training_mp_size /
            #                   transformer_config.mp_size)

            if mlp_linear_layer:
                if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta):
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _4hh_b,
                                                 _h4h_b],
                                                modifier_rank=0):
                            _h4h_w = transpose(_h4h_w.data)
                            _4hh_w = transpose(_4hh_w.data)
                            _h4h_b = _h4h_b.data
                            _4hh_b = _4hh_b.data
                else:
                    _h4h_w = [transpose(moe_w1.data)
                              for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data)
                    _4hh_w = [transpose(moe_w1.data)
                              for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data)

            if moe and moe_type == 'residual':
                _res_h4h_w.data = transpose(_res_h4h_w.data)
                _res_4hh_w.data = transpose(_res_4hh_w.data)
                _res_coef.data = transpose(_res_coef.data)

            # When ZeRO-3 keeps a parameter partitioned (ds_tensor) or on the meta device,
            # gather it on rank 0 before copying; mp_replace then slices the full weight
            # into this rank's tensor-parallel shard (qkv_copy handles the fused QKV layout).
            if qkvw.is_meta or qkvw.numel() == 0:
                if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
                    pass
                else:
                    with GatheredParameters([
                            attn_block.attn_qkvw,
                            attn_block.attn_qkvb,
                            attn_block.attn_ow,
                            attn_block.attn_ob
                    ],
                                            modifier_rank=0):
                        attn_block.attn_qkvw = mp_replace.copy(
                            attn_block.attn_qkvw,
                            qkvw)
                        attn_block.attn_qkvb = mp_replace.copy(
                            attn_block.attn_qkvb,
                            qkvb)

                        attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                        attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
            else:
                if bigscience_bloom:
                    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
                    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
                else:
                    attn_block.attn_qkvw = mp_replace.qkv_copy(
                        attn_block.attn_qkvw,
                        qkvw)
                    attn_block.attn_qkvb = mp_replace.qkv_copy(
                        attn_block.attn_qkvb,
                        qkvb)

                attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

            if moe:
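                # Each rank hosts `local_ep_size` experts; copy their fused MLP weights
                # (and the shared attention layer norm) onto the current device.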
                gpu_index = 0
                for ep_index in range(local_ep_size):
                    mpl_block[ep_index].inter_w.data = _h4h_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].inter_b.data = _h4h_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_w.data = _4hh_w[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                    mpl_block[ep_index].output_b.data = _4hh_b[
                        gpu_index * local_ep_size + ep_index].to(
                            torch.cuda.current_device())
                new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device())
                new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device())
                if moe_type == 'residual':
                    new_module.res_mlp.inter_w.data = _res_h4h_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.inter_b.data = _res_h4h_b.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_w.data = _res_4hh_w.to(
                        torch.cuda.current_device())
                    new_module.res_mlp.output_b.data = _res_4hh_b.to(
                        torch.cuda.current_device())
                    new_module.res_coef.data = _res_coef.to(torch.cuda.current_device())
            else:

                if _4hh_w.numel() == 0 or _4hh_w.is_meta:
                    if _4hh_w.is_meta or _4hh_w.ds_tensor.numel(
                    ) < mpl_block.inter_w.numel():
                        pass
                    else:
                        with GatheredParameters([_h4h_w,
                                                 _4hh_w,
                                                 _4hh_b,
                                                 _h4h_b],
                                                modifier_rank=0):
                            mpl_block.inter_w = mp_replace.copy(
                                mpl_block.inter_w,
                                _h4h_w)
                            mpl_block.inter_b = mp_replace.copy(
                                mpl_block.inter_b,
                                _h4h_b)
                            mpl_block.output_w = mp_replace.copy(
                                mpl_block.output_w,
                                _4hh_w)
                            mpl_block.output_b = mp_replace.copy(
                                mpl_block.output_b,
                                _4hh_b)
                else:
                    mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
                    mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
                    mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
                    mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)

                if attn_nw is None:
                    new_module.mlp.attn_nw = attn_nw
                    new_module.mlp.attn_nb = attn_nb
                else:
                    if attn_nw.is_meta or attn_nw.numel() == 0:
                        if attn_nw.is_meta or attn_nw.ds_tensor.numel(
                        ) < new_module.mlp.attn_nw.numel():
                            pass
                        else:
                            with GatheredParameters([attn_nw, attn_nb], modifier_rank=0):
                                new_module.mlp.attn_nw.data.copy_(
                                    attn_nw.to(torch.cuda.current_device()))
                                new_module.mlp.attn_nb.data.copy_(
                                    attn_nb.to(torch.cuda.current_device()))
                    else:
                        new_module.mlp.attn_nw.data.copy_(
                            attn_nw.to(torch.cuda.current_device()))
                        new_module.mlp.attn_nb.data.copy_(
                            attn_nb.to(torch.cuda.current_device()))

            if input_nw.is_meta or input_nw.numel() == 0:
                if input_nw.is_meta or input_nw.ds_tensor.numel(
                ) < new_module.norm_w.numel():
                    pass
                else:
                    with GatheredParameters([input_nw, input_nb], modifier_rank=0):
                        new_module.norm_w.data.copy_(
                            input_nw.to(torch.cuda.current_device()))
                        new_module.norm_b.data.copy_(
                            input_nb.to(torch.cuda.current_device()))
            else:
                new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device()))
                new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device()))
        else:
            transformer_config = deepspeed.DeepSpeedTransformerConfig(
                batch_size=micro_batch_size if micro_batch_size > 0 else 1,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                layer_norm_eps=config.layer_norm_eps if hasattr(
                    config,
                    'layer_norm_eps') else 1e-12,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=policy.pre_attn_norm,
                return_tuple=return_tuple,
                local_rank=local_rank,
                stochastic_mode=stochastic_mode,
                normalize_invertible=True,
                training=training)
            new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = dense_w
            new_module.attn_ob.data = dense_b

            new_module.attn_nw.data = attn_nw
            new_module.attn_nb.data = attn_nb
            new_module.norm_w.data = input_nw
            new_module.norm_b.data = input_nb

            new_module.inter_w.data = _h4h_w
            new_module.inter_b.data = _h4h_b
            new_module.output_w.data = _4hh_w
            new_module.output_b.data = _4hh_b
        return new_module

    def replace_wo_policy(module, all_reduce_linears):
        def _replace(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            z_inference = (len(list(child.parameters())) > 0) and (list(
                child.parameters())[0].numel() == 0)
            if z_inference:
                weight_shape = child.weight.ds_shape
            else:
                weight_shape = child.weight.shape
            if name in all_reduce_linears:
                new_weight = torch.empty((
                    weight_shape[1] if conv_linear_layer else weight_shape[0],
                    (weight_shape[0] if conv_linear_layer else weight_shape[1]) //
                    mp_size,
                ),
                                         device=child.weight.device,
                                         dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.weight,
                                                           modifier_rank=0):
                        data = child.weight.data.to(new_weight.device)
                        if conv_linear_layer:
                            data = data.transpose(-1, -2).contiguous()
                        data = mp_replace.copy(new_weight, data)
                    child.weight.ds_tensor = torch.empty(1)
                else:
                    if conv_linear_layer:
                        child.weight.data = child.weight.data.transpose(-1,
                                                                        -2).contiguous()
                    data = mp_replace.copy(new_weight, child.weight.data)
                new_bias = torch.empty((weight_shape[0]),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0):
                        new_bias.data.copy_(child.bias.data)
                elif child.bias is not None:
                    new_bias.data.copy_(child.bias.data)
                return LinearAllreduce(data, child.bias if child.bias is None else \
                            torch.nn.parameter.Parameter(new_bias.to(torch.cuda.current_device())), mp_group)
            else:
                new_weight = torch.empty((
                    (weight_shape[1] if conv_linear_layer else weight_shape[0]) //
                    mp_size,
                    weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1],
                ),
                                         device=child.weight.device,
                                         dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.weight,
                                                           modifier_rank=0):
                        data = child.weight.data.to(new_weight.device)
                        if conv_linear_layer:
                            data = data.transpose(-1, -2).contiguous()
                        data = mp_replace.copy(new_weight, data)
                    child.weight.ds_tensor = torch.empty(1)
                else:
                    if conv_linear_layer:
                        child.weight.data = child.weight.data.transpose(-1,
                                                                        -2).contiguous()
                    data = mp_replace.copy(new_weight, child.weight.data)

                new_bias = torch.empty((weight_shape[0] // mp_size),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                if z_inference:
                    with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0):
                        bias_data = None if child.bias is None else mp_replace.copy(
                            new_bias,
                            child.bias.data).to(torch.cuda.current_device())
                else:
                    bias_data = None if child.bias is None else mp_replace.copy(
                        new_bias,
                        child.bias.data).to(torch.cuda.current_device())
                return LinearLayer(weight=data.to(torch.cuda.current_device()),
                                   bias=bias_data)

        def _slice_embedding(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            new_weight = torch.empty((child.weight.shape[0],
                                      child.weight.shape[1] // mp_size),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            data = mp_replace.copy(new_weight,
                                   child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
                                   child.weight.data)
            new_embedding = nn.Embedding(child.weight.shape[0],
                                         child.weight.shape[1] // mp_size)
            new_embedding.weight.data.copy_(data)
            return new_embedding

        def update_mp_params(child):
            if hasattr(child, 'n_heads'):
                child.n_heads = child.n_heads // mp_size
            if hasattr(child, 'inner_dim'):
                child.inner_dim = child.inner_dim // mp_size
            if hasattr(child, 'num_heads'):
                child.num_heads = child.num_heads // mp_size
            if hasattr(child, 'num_attention_heads'):
                child.num_attention_heads = child.num_attention_heads // mp_size
            if hasattr(child, 'all_head_size'):
                child.all_head_size = child.all_head_size // mp_size
            if hasattr(child, 'embed_dim'):
                child.embed_dim = child.embed_dim // mp_size
            if hasattr(child, 'hidden_size'):
                child.hidden_size = child.hidden_size // mp_size

        conv_linear_layer = False
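        # Map module classes to their replacement functions. GPT-2 style models use
        # transformers' Conv1D, whose weight is stored transposed relative to nn.Linear,
        # so conv_linear_layer is flagged to transpose it back before slicing.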
        if linear_layer_setting is not None:
            linear_policies = {linear_layer_setting[0]: _replace}
            if len(linear_layer_setting) == 2:
                linear_policies.update({linear_layer_setting[1]: _slice_embedding})
        else:
            if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
                try:
                    import transformers
                    conv_linear_layer = True
                    linear_policies = {transformers.model_utils.Conv1D: _replace}
                except ImportError:
                    linear_policies = {nn.Linear: _replace}
            else:
                linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}

        def _replace_module(r_module, prev_name=''):
            for name, child in r_module.named_children():
                if child.__class__ in linear_policies:
                    setattr(
                        r_module,
                        name,
                        linear_policies[child.__class__](child,
                                                         prev_name + '.' + name,
                                                         conv_linear_layer))
                else:
                    update_mp_params(child)
                    _replace_module(child, name)
            return r_module

        return _replace_module(module)

    def replace_fn(child, _policy, layer_id=0):
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, triangular_masking)

        else:
            # copy relevant state from child -> new module
            if replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy)

        return new_module

    replaced_module = replace_module(model=model,
                                     orig_class=orig_layer_impl,
                                     replace_fn=replace_fn,
                                     _replace_policy=policy)

    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
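    # Load pre-sharded checkpoints into the replaced module: 'pp'-style checkpoints are
    # loaded in full on every rank, while 'tp'-style checkpoints are indexed so that each
    # rank reads only the shards matching its tensor-parallel position.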
    if checkpoint_dict is not None:
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
        base_dir = checkpoint_dict.get('base_dir', '')

        if ckpt_type == 'pp':
            pbar = tqdm.tqdm(total=len(checkpoint),
                             desc=f"Loading {len(checkpoint)} checkpoint shards")
            for i in range(len(checkpoint)):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)
                sd = torch.load(checkpoint[i], map_location='cpu')
                load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type)
        else:
            num_checkpoints = len(checkpoint) // ckpt_mp_size
            assert world_size >= ckpt_mp_size,\
                "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!"
            checkpoint_stride = world_size // ckpt_mp_size
            if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                pbar = tqdm.tqdm(total=num_checkpoints,
                                 desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                    pbar.update(1)

                ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride)
                ckpt_file = os.path.join(
                    base_dir,
                    checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index]
                sd = torch.load(ckpt_file, map_location='cpu')
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           rank % (world_size // ckpt_mp_size))
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
        ckpt_files = [non_tp_ckpt_name] * world_size
        os.makedirs(save_mp_checkpoint_path, exist_ok=True)
        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({
                    k: v
                    for k,
                    v in dict(replaced_module.state_dict()).items()
                    if transformer_name not in k
                }),
                f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}')
            ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)]
            config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{save_mp_checkpoint_path}',
                'checkpoints': ckpt_files,
                'version': 1.0,
                'parallelization': 'tp',
                'mp_size': world_size
            })
            with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json",
                      "w") as cfg:
                cfg.write(config)
        torch.save(
            OrderedDict({
                k: v
                for k,
                v in dict(replaced_module.state_dict()).items() if transformer_name in k
            }),
            f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt')

    return replaced_module
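
A minimal usage sketch for the injection path above, assuming a Hugging Face causal LM and a DeepSpeed version contemporaneous with this example: deepspeed.init_inference drives replace_transformer_layer internally when replace_with_kernel_inject=True. The model name and prompt are placeholders, and the script would be launched with the deepspeed launcher so that mp_size matches the number of processes.

import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Replace the transformer blocks with DeepSpeed's fused inference modules
# and shard them across the tensor-parallel (mp_size) group.
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.float16,
                                  replace_with_kernel_inject=True)

inputs = tokenizer("DeepSpeed inference", return_tensors="pt").to("cuda")
outputs = engine.module.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))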
Example #23
0
 def test_world_size_1(self):
     assert dist.get_world_size() == 1
Example #24
0
 def test_two(self, number, color="purple"):
     assert dist.get_world_size() == 2
     assert number == 1138
     assert color == "purple"
Example #25
0
 def test_one(self, number):
     assert dist.get_world_size() == 2
     assert number == 1138
Example #26
0
 def test(self):
     assert dist.is_initialized()
     assert dist.get_world_size() == 3
     assert dist.get_rank() < 3
Example #27
0
# Imports assumed for this standalone snippet
import argparse
import os

import torch

import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.comm.nccl import NcclBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean

timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend='nccl')
args.local_rank = int(os.environ['LOCAL_RANK'])

torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

size = dist.get_world_size()
rank = dist.get_rank()

backend = NcclBackend()
local_rank = args.local_rank

# Setting tensor_size (BERT-Large)
tensor_size = 300 * 2**20
server_size = int(tensor_size / size)
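# Round the tensor length up to a multiple of 8 * world_size so it splits evenly across
# ranks; the factor of 8 presumably matches the sign-bit packing used by the compressed
# allreduce.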
if tensor_size % (8 * size) != 0:
    right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size)))
else:
    right_tensor_size = tensor_size
right_server_size = right_tensor_size // size

# Adding bias to the initialization of the gradient we are communicating
Example #28
0
 def _is_grid_valid(self):
     ranks = 1
     for ax in self._topo.get_axis_names():
         ranks *= self._topo.get_dim(ax)
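     # e.g. a topology with pipe=2 and data=4 gives ranks == 8, which must match the launch world size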
     return ranks == dist.get_world_size()
Example #29
0
    def test(self, tmpdir):
        from deepspeed.runtime.comm.nccl import NcclBackend

        size = dist.get_world_size()
        rank = dist.get_rank()
        backend = NcclBackend()
        local_rank = dist.get_rank()
        device = torch.device("cuda", dist.get_rank())

        # A simulated compression function using deepspeed.comm
        def torch_sim(a):
            a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
            scale = a.norm() / np.sqrt(a.numel())
            a_compressed = scale * a_sign
            a_sign = None
            worker_error = a - a_compressed
            dist.all_reduce(a_compressed)
            a_compressed.mul_(1 / dist.get_world_size())
            a_server_sign = (a_compressed.sign().add_(1).bool().float().add_(
                -0.5).mul_(2.0))
            a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
            server_scale = [
                chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list
            ]
            a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
            a_server_compressed = torch.cat([
                server_scale[i] * a_sign_list[i]
                for i in range(dist.get_world_size())
            ])
            rank = dist.get_rank()
            server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
            torch.cuda.synchronize()
            dist.barrier()
            return a_server_compressed, worker_error, server_error

        tensor_size = 300 * 2**20
        server_size = int(tensor_size / size)
        if tensor_size % (8 * size) != 0:
            right_tensor_size = tensor_size + (8 * size - (tensor_size %
                                                           (8 * size)))
        else:
            right_tensor_size = tensor_size
        right_server_size = right_tensor_size // size

        # Adding bias to the initialization of the gradient we are communicating
        # In order to get rid of the case where some elements in the gradient are too small
        a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank

        worker_error = torch.zeros(right_tensor_size, device=device)
        server_error = torch.zeros(right_server_size, device=device)

        a_torch, worker_error_torch, server_error_torch = torch_sim(a)
        torch.cuda.empty_cache()

        a_after = backend.compressed_allreduce(a, worker_error, server_error,
                                               local_rank)

        threshold = 1e-6
        magnitude_threshold = 1e-6
        diff_mask = (a_after - a_torch) > threshold
        diff_server_mask = torch.chunk(diff_mask, size)[rank]
        mpi_server = torch.chunk(a_after, size)[rank] + server_error
        torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch

        # If a value in compensated_server_m is too small (e.g. 1e-8), calling sign() on it can be problematic.
        # The test therefore skips positions whose magnitude falls below magnitude_threshold.
        check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold
        if torch.sum(check_mag_mask) != 0:
            print("Fails at {} of positions".format(torch.sum(check_mag_mask)))
        assert torch.sum(diff_server_mask) == 0 or torch.sum(
            check_mag_mask) == 0
Example #30
0
    def forward(ctx, input, input_mask, head_mask, layer_past, get_present,
                encoder_hidden_states, encoder_attention_mask,
                output_attentions, norm_w, norm_b, config, attn_qkvw,
                attn_qkvb, num_attention_heads_per_partition, norm_factor,
                hidden_size_per_partition, attn_ow, attn_ob, mp_group,
                q_scales, q_groups, merge_count, qkv_merging,
                score_context_func, alibi):
        def _transpose_for_scores(x, key=False, reshape=False):
            attention_head_size = x.shape[
                -1] // num_attention_heads_per_partition
            new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition,
                                           attention_head_size)
            x_1 = x.view(*new_x_shape)
            if key:
                x_1 = x_1.permute(0, 2, 3, 1)
            else:
                x_1 = x_1.permute(0, 2, 1, 3)
            if reshape:
                return x_1.reshape(x.shape)
            return x_1.contiguous()

        def _transpose_for_context(x):
            x = x.permute(0, 2, 1, 3).contiguous()
            new_x_layer_shape = x.size()[:-2] + \
                                      (hidden_size_per_partition,)
            return x.view(*new_x_layer_shape).contiguous()

        ########### This part is taken/modified from the HF modeling_bloom.py ################
        # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py

        def split_tensor_along_last_dim(tensor,
                                        num_partitions,
                                        contiguous_split_chunks=True):
            """Split a tensor along its last dimension.

            Args:
                tensor: ([`torch.tensor`], *required*):
                    input tensor to split
                num_partitions ([`int`], *required*):
                    number of partitions to split the tensor
                contiguous_split_chunks ([`bool`], *optional*, default=`True`):
                    If True, make each chunk contiguous in memory.
            """
            # Get the size and dimension.
            last_dim = tensor.dim() - 1
            numerator, denominator = tensor.size()[last_dim], num_partitions
            if not (numerator % denominator == 0):
                raise ValueError(
                    f"{numerator} is not divisible by {denominator}")
            last_dim_size = numerator // denominator
            # Split.
            tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
            # Note: torch.split does not create contiguous tensors by default.
            if contiguous_split_chunks:
                return tuple(chunk.contiguous() for chunk in tensor_list)

            return tensor_list

        def backup_attention(mixed_x_layer, layer_past, alibi, input_mask,
                             norm_factor):
            alibi = alibi.to(torch.cuda.current_device())
            head_dim = hidden_size_per_partition // num_attention_heads_per_partition
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                num_attention_heads_per_partition, 3 * head_dim)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            (query_layer, key_layer,
             value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

            if layer_past is not None:
                past_key, past_value = layer_past
                # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
                key_layer = torch.cat((past_key.type_as(key_layer), key_layer),
                                      dim=1)
                value_layer = torch.cat(
                    (past_value.type_as(value_layer), value_layer), dim=1)

            presents = (key_layer, value_layer)

            # [batch_size, head_dim, q_length, k_length]
            output_size = (query_layer.size(0), query_layer.size(2),
                           query_layer.size(1), key_layer.size(1))
            # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
            query_layer = query_layer.transpose(1, 0).reshape(
                output_size[2], output_size[0] * output_size[1], -1)
            # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
            key_layer = key_layer.transpose(1, 0).reshape(
                output_size[3], output_size[0] * output_size[1], -1)

            # Raw attention scores. [batch_size * num_heads, q_length, k_length]
            matmul_result = torch.matmul(
                query_layer.transpose(1, 0),
                key_layer.transpose(1, 0).transpose(1, 2))
            # change view to [batch_size, num_heads, q_length, k_length]
            attention_scores = matmul_result.view(*output_size)

            offset = dist.get_rank(
            ) * num_attention_heads_per_partition if dist.is_initialized(
            ) else 0
            attention_probs = inference_cuda_module.softmax_fp16(
                attention_scores,
                ((1 - input_mask).half() *
                 minus_inf) if input_mask.dtype == torch.int64 else input_mask,
                alibi, (config.triangular_masking and
                        (attention_scores.shape[-2] > 1)), False, False, 1,
                False, 1 / (norm_factor * norm_factor), offset, config.mp_size)
            # change view [batch_size x num_heads, q_length, k_length]
            attention_probs_reshaped = attention_probs.view(
                *matmul_result.shape)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = torch.bmm(
                attention_probs_reshaped,
                value_layer.transpose(1, 2).reshape(-1, value_layer.size(1),
                                                    value_layer.size(3)))

            # change view [batch_size, num_heads, q_length, head_dim]
            context_layer = context_layer.view(
                context_layer.size(0) // num_attention_heads_per_partition,
                num_attention_heads_per_partition, context_layer.size(1),
                context_layer.shape[-1])

            context_layer = _transpose_for_context(context_layer)

            return context_layer, presents

        ###################### End of HF modeling_bloom addition ########################

        def compute_attention(qkv_out, input_mask):
            no_masking = input_mask is None

            head_size = (qkv_out.shape[-1] // 3 //
                         num_attention_heads_per_partition)
            if no_masking:
                input_mask = torch.empty(1)
            if merge_count > 0 and config.q_int8:
                split_dim = (qkv_out.dim() - 1)
                qkv_split = torch.split(qkv_out, (qkv_out.shape[-1] //
                                                  (2**merge_count)),
                                        dim=split_dim)
                qkv_split = [
                    torch.split(s, (s.shape[-1] // 3), dim=split_dim)
                    for s in qkv_split
                ]
                (mixed_query, key_layer, value_layer) = [
                    torch.cat([s[i] for s in qkv_split], axis=-1)
                    for i in range(len(qkv_split[0]))
                ]

                if config.rotary_dim > 0:
                    mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb(
                        mixed_query, key_layer, config.rotary_dim,
                        0 if layer_past is None else layer_past[0].shape[-2],
                        num_attention_heads_per_partition, config.rotate_half,
                        config.rotate_every_two)
                if layer_past is not None:
                    past_key, past_value = layer_past
                    key_layer = torch.cat(
                        (past_key.type_as(key_layer), key_layer), dim=-2)
                    value_layer = torch.cat(
                        (past_value.type_as(value_layer), value_layer), dim=-2)
                presents = (key_layer, value_layer)
                mixed_query = _transpose_for_scores(mixed_query, False, True)
                key_layer = _transpose_for_scores(key_layer, True, True) / (
                    norm_factor if config.scale_attention else 1.0)
                value_layer = _transpose_for_scores(value_layer, False, True)
                if layer_past is None:
                    attn_key_value = score_context_func(
                        mixed_query,
                        key_layer,
                        torch.empty(1),
                        ((1 - input_mask).half() * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        value_layer,
                        torch.empty(1),
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        (not unfused_mode),  # noqa: F821
                        config.triangular_masking,
                        config.local_attention,
                        config.window_size,
                        no_masking)
                else:
                    attn_key_value = score_context_func(
                        mixed_query,
                        (key_layer if unfused_mode else
                         past_key.type_as(key_layer)),  # noqa: F821
                        key_layer,
                        ((1 - input_mask).half() * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        (value_layer if unfused_mode else
                         past_value.type_as(value_layer)),  # noqa: F821
                        value_layer,
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        (not unfused_mode),  # noqa: F821
                        config.triangular_masking,
                        config.local_attention,
                        config.window_size,
                        no_masking)
                if unfused_mode:  # noqa: F821
                    context_layer, _, _ = attn_key_value
                else:
                    context_layer, key_layer, value_layer = attn_key_value

                # Transpose Context
                context_layer = _transpose_for_context(context_layer)

                return context_layer, presents[0], presents[
                    1]  # atten_output, key_layer, value_layer
            else:
                # Note: This modification is added for the BLOOM-176B model and will be removed later!
                if config.bigscience_bloom:
                    context_layer, presents = backup_attention(
                        qkv_out, layer_past, alibi, input_mask, norm_factor)
                    return context_layer, presents[0], presents[
                        1]  #key_layer, value_layer
                else:
                    if alibi is not None:
                        batch_heads = qkv_out.shape[
                            0] * num_attention_heads_per_partition
                        offset = dist.get_rank(
                        ) * batch_heads if dist.is_initialized() else 0
                        sliced_alibi = alibi[offset:batch_heads + offset, :, :]

                    attn_key_value = score_context_func(
                        qkv_out,
                        ((1 - input_mask).to(qkv_out.dtype) * minus_inf)
                        if input_mask.dtype == torch.int64 else input_mask,
                        config.rotary_dim, config.rotate_half,
                        config.rotate_every_two,
                        num_attention_heads_per_partition,
                        (1 / norm_factor if config.scale_attention else 1.0),
                        config.triangular_masking, config.local_attention,
                        config.window_size, no_masking, config.layer_id,
                        DeepSpeedTransformerInference.layer_id,
                        sliced_alibi if alibi is not None else torch.empty(1))
                    context_layer, key_layer, value_layer = attn_key_value
                    return context_layer, key_layer, value_layer

        def selfAttention_fp():
            vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \
                                    inference_cuda_module.vector_matmul_fp32
            if not config.pre_layer_norm:
                linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \
                                    inference_cuda_module.linear_layer_fp32

                qkv_out = linear_func(input, attn_qkvw, attn_qkvb,
                                      DeepSpeedTransformerInference.layer_id)
            else:
                qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \
                                    inference_cuda_module.qkv_gemm_fp32
                qkv_out = qkv_func(
                    input, attn_qkvw,
                    (attn_qkvb if attn_qkvb is not None else norm_b), norm_w,
                    norm_b, config.epsilon, (attn_qkvb is not None),
                    1 if config.bigscience_bloom else
                    DeepSpeedTransformerInference.layer_id)
            context_layer, key_layer, value_layer = compute_attention(
                qkv_out[0] if isinstance(qkv_out, list) else qkv_out,
                input_mask)
            output = vector_matmul_func(context_layer, attn_ow, False)

            return output, key_layer, value_layer, context_layer, qkv_out[-1]

        def selfAttention_int8():
            if not config.pre_layer_norm:
                qkv_out = inference_cuda_module.linear_layer_int8(
                    input, attn_qkvw, attn_qkvb, q_scales[0],
                    (q_groups * (3 if qkv_merging else 1) * (2**merge_count)))

            else:
                qkv_out = inference_cuda_module.qkv_gemm_int8(
                    input, attn_qkvw, attn_qkvb, norm_w, norm_b,
                    config.epsilon, q_scales[0],
                    (q_groups * (3 if qkv_merging else 1) * (2**merge_count)),
                    (attn_qkvb is not None))
            context_layer, key_layer, value_layer = compute_attention(qkv_out, input_mask)
            output = inference_cuda_module.vector_matmul_int8(
                context_layer, attn_ow, q_scales[1], q_groups, (merge_count))
            return output, key_layer, value_layer, context_layer

        if config.q_int8:
            output, key_layer, value_layer, context_layer = selfAttention_int8(
            )
        else:
            output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp(
            )
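        # With tensor parallelism each rank holds a partial attention output; on the
        # mlp_after_attn path it is all-reduced across mp_group here (otherwise the
        # reduction is deferred to the fused MLP path).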
        if config.mlp_after_attn and mp_group is not None and dist.get_world_size(
                group=mp_group) > 1:
            dist.all_reduce(output, group=mp_group)

        return (output, key_layer, value_layer, context_layer, inp_norm)