def bench_vanilla_all_gather(nelem, device, repeat=100, warm_up=10):
    torch.manual_seed(1)
    nranks = dist.get_world_size()
    rank = dist.get_rank()

    partition_size = _divup(nelem, nranks)
    nelem = nranks * partition_size  # ensure nelem is divisible by nranks
    data_tensor = torch.rand(nelem, device=device, dtype=torch.half)

    all_gather_output = data_tensor
    start = rank * partition_size
    all_gather_input = data_tensor.narrow(0, start, partition_size)

    time_costs = []
    for i in range(repeat + warm_up):
        _start_time = time.time()

        dist._all_gather_base(all_gather_output, all_gather_input, group=None)
        torch.cuda.synchronize()
        _end_time = time.time()

        if i >= warm_up:
            time_costs.append((_end_time - _start_time) * 1e3)
    if dist.get_rank() == 0:
        print(
            f'all-gather across all ranks: average cost {np.mean(time_costs)} ms, std {np.std(time_costs)}'
        )
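The benchmark above relies on a `_divup` helper and an already-initialized process group that are not shown. A minimal driver sketch, assuming an NCCL backend and a one-process-per-GPU launch (e.g. via torchrun); the `_divup` definition is an assumption consistent with how it is used above:

import os
import time

import numpy as np
import torch
import torch.distributed as dist


def _divup(a, b):
    # Assumed helper, matching its usage above: ceiling division.
    return (a + b - 1) // b


if __name__ == "__main__":
    # Assumes the env vars set by torchrun (RANK, WORLD_SIZE, LOCAL_RANK, ...).
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    bench_vanilla_all_gather(1 << 26, device)  # ~64M fp16 elements (~128 MB)

    dist.destroy_process_group()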
Example No. 2
    def _rebuild_full_params(self) -> None:
        """
        Gather all shards of params.
        """
        def update_p_data(output_tensor: torch.Tensor) -> None:
            """
            Helper function to update p.data pointer.
            Args:
                output_tensor (torch.Tensor): this tensor contains the data we just gathered.
            """
            p.data = output_tensor
            # Trim any padding and reshape to match original size.
            p.data = p.data[:p._orig_size.numel()].view(
                p._orig_size)  # type: ignore[attr-defined]

        with torch.cuda.stream(self._streams["all_gather"]):
            for p in self.params:
                # e.g., when world_size == 1
                if not p._is_sharded:  # type: ignore[attr-defined]
                    continue
                # If full param has been rebuilt or has not been freed, no need to call all gather
                elif (p._full_param_padded.storage().size()  # type: ignore[attr-defined]
                      == p._full_param_padded.size().numel()):  # type: ignore[attr-defined]
                    update_p_data(
                        p._full_param_padded)  # type: ignore[attr-defined]
                    continue
                else:
                    # If full param has not been rebuilt or has been freed, call all gather
                    # Move params in CPU to CUDA for the all-gather.
                    p_data = p.data  # type: ignore[attr-defined]
                    p_full_size = p._full_param_padded.size(
                    )  # type: ignore[attr-defined]
                    assert (
                        p_full_size.numel() == p_data.numel() * self.world_size
                    ), "Param full size should equal its shard size multiplied by world_size."
                    assert (
                        p._full_param_padded.storage().size() == 0  # type: ignore[attr-defined]
                    ), "Full param's storage should have been freed earlier if an all-gather is needed."  # type: ignore[attr-defined]
                    # Allocate based on full size from all shards.
                    _alloc_storage(
                        p._full_param_padded,
                        size=p_full_size)  # type: ignore[attr-defined]
                    output_tensor = p._full_param_padded  # type: ignore[attr-defined]

                    # Fill output_tensor with (p.data for each shard in self.world_size)
                    dist._all_gather_base(output_tensor,
                                          p_data,
                                          group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
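The FSDP example above depends on an `_alloc_storage` utility that is not shown. A hedged sketch of such a helper, consistent with the `storage().size()` checks in the snippet (the real FSDP utility may differ):

import torch


def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    # Sketch: (re)allocate the tensor's backing storage to hold size.numel()
    # elements. Mirrors the assertions above: either the full storage is
    # already allocated, or it was freed (size 0) and must be resized.
    if tensor.storage().size() == size.numel():
        return  # already rebuilt
    assert tensor.storage().size() == 0, "expected freed storage before allocating"
    tensor.storage().resize_(size.numel())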
Example No. 3
def timed_all_gather(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                output_tensors = list(
                    torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        # use all_gather_base if available
        if args.dist == 'torch':
            if hasattr(torch.distributed, "_all_gather_base"):
                dist._all_gather_base(output, input, group=None, async_op=args.async_op)
            else:
                output_tensors = list(
                    torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
        elif args.dist == 'deepspeed':
            dist.allgather_fn(output, input, group=None, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_gather', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(
        f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
Example No. 4
def bench_hierarchy_all_gather(nelem,
                               device,
                               intra_shard_group,
                               inter_shard_group,
                               repeat=100,
                               warm_up=10):
    torch.manual_seed(1)
    nranks = dist.get_world_size()
    rank = dist.get_rank()
    local_size = torch.cuda.device_count()
    local_rank = rank % local_size

    partition_size = _divup(nelem, nranks)
    nelem = nranks * partition_size  # ensure nelem is divisible by nranks

    data_tensor = torch.rand(nelem, dtype=torch.half, device=device)

    inter_node_size = dist.get_world_size(group=inter_shard_group)
    inter_node_part_size = inter_node_size * partition_size
    start = inter_node_part_size * local_rank
    inter_node_output = data_tensor.narrow(0, start, inter_node_part_size)
    inter_node_shard_rank = dist.get_rank(group=inter_shard_group)
    start = inter_node_shard_rank * partition_size
    inter_node_input = inter_node_output.narrow(0, start, partition_size)

    # the input of the intra-node all-gather is the output of the inter-node all-gather
    intra_node_input = inter_node_output
    # the intra-node output produces the complete parameter tensor
    intra_node_output = data_tensor

    time_costs = []
    for i in range(repeat + warm_up):
        _start_time = time.time()
        # inter-node
        dist._all_gather_base(inter_node_output,
                              inter_node_input,
                              group=inter_shard_group)

        # intra-node
        dist._all_gather_base(intra_node_output,
                              intra_node_input,
                              group=intra_shard_group)
        # device sync
        torch.cuda.synchronize()
        _end_time = time.time()
        if i >= warm_up:
            time_costs.append((_end_time - _start_time) * 1e3)
    if dist.get_rank() == 0:
        print(
            f'hierarchical all-gather: average cost {np.mean(time_costs)} ms, std {np.std(time_costs)}'
        )
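The hierarchical benchmark takes `intra_shard_group` and `inter_shard_group` as arguments without showing how they are built. A hedged sketch of constructing them with `dist.new_group`, assuming ranks are numbered node-major (ranks 0..L-1 on node 0, L..2L-1 on node 1, and so on, with L GPUs per node); note that every rank must call `new_group` for every group:

import torch
import torch.distributed as dist


def build_shard_groups():
    world_size = dist.get_world_size()
    local_size = torch.cuda.device_count()
    num_nodes = world_size // local_size
    rank = dist.get_rank()

    # Intra-node groups: one group per node, containing that node's ranks.
    intra_shard_group = None
    for node in range(num_nodes):
        ranks = list(range(node * local_size, (node + 1) * local_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            intra_shard_group = group

    # Inter-node groups: one group per local rank, spanning all nodes.
    inter_shard_group = None
    for local_rank in range(local_size):
        ranks = list(range(local_rank, world_size, local_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            inter_shard_group = group

    return intra_shard_group, inter_shard_group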
Example No. 5
def _all_gather_sharded_tensor(
        sharded_tensor: ShardedTensor,
        pg: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    local_tensor = shards[0].tensor.flatten()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(
        dim_0_size / world_size) * tensor_numel // dim_0_size
    num_padding = chunk_size - local_tensor.numel()
    if num_padding > 0:
        local_tensor = F.pad(local_tensor, [0, num_padding])
    tensor = torch.empty(chunk_size * world_size,
                         dtype=local_tensor.dtype).cuda()
    dist._all_gather_base(tensor, local_tensor, group=pg)
    return tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
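For intuition on the chunk_size arithmetic above: with a sharded tensor of shape (10, 6) and world_size = 4, dim_0_size = 10 and tensor_numel = 60, so chunk_size = ceil(10 / 4) * 60 // 10 = 3 * 6 = 18 elements per rank. A rank whose local shard holds only one row contributes 6 elements and pads up to 18; the flat all-gather then produces 72 elements, and the final narrow(0, 0, 60) plus reshape drops the padding and restores the original shape.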
Example No. 6
    def all_gather_base(self, collectiveArgs, retFlag=False, pair=False):
        retObj = dist._all_gather_base(
            output_tensor=collectiveArgs.opTensor,
            input_tensor=collectiveArgs.ipTensor,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )  # synchronicity is maintained in runColl

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(retObj)

        if retFlag:
            return retObj
Example No. 7
def _all_gather_sharded_tensor(
        sharded_tensor: ShardedTensor,
        pg: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(
        dim_0_size / world_size) * tensor_numel // dim_0_size
    cuda_device = torch.device("cuda", torch.cuda.current_device())
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if not local_tensor.is_cuda:
            move_to_cpu = torch.ones(1, device=cuda_device)
            local_tensor = local_tensor.cuda()
        else:
            move_to_cpu = torch.zeros(1, device=cuda_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(chunk_size,
                                   dtype=sharded_tensor.dtype,
                                   device=cuda_device)
        move_to_cpu = torch.zeros(1, device=cuda_device)

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=cuda_device,
    )
    dist._all_gather_base(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor
Example No. 8
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        if not input.is_contiguous(memory_format=torch.channels_last):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # for empty input, set stats and the count to zero. The stats with
            # zero count will be filtered out later when computing the global mean
            # & invstd, but they still need to participate in the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use allgather instead of allreduce because count could be different across
        # ranks; a simple all-reduce op cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        # remove stats from empty inputs
        mask = count_all.squeeze(-1) >= 1
        count_all = count_all[mask]
        mean_all = mean_all[mask]
        invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)
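The NCCL branch above relies on packing each rank's statistics into a flat (2C + 1) vector before the base all-gather. A small standalone illustration of that layout (no process group needed; the numbers are made up):

import torch

C, world_size = 3, 2
per_rank = [
    torch.cat([
        torch.full((C,), r + 0.1),          # per-channel mean on rank r
        torch.full((C,), r + 0.2),          # per-channel invstd on rank r
        torch.tensor([float(100 + r)]),     # element count on rank r
    ])
    for r in range(world_size)
]
combined_flat = torch.cat(per_rank)          # what _all_gather_base would produce
combined = combined_flat.reshape(world_size, 2 * C + 1)
mean_all, invstd_all, count_all = torch.split(combined, C, dim=1)
# mean_all, invstd_all: (world_size, C); count_all: (world_size, 1), since the
# last split chunk keeps whatever remains after two chunks of size C.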
Example No. 9
def _communicate_optim_state(
    fsdp_module,
    flat_param: FlatParameter,
    flat_param_state: Dict[str, Any],
    to_save: bool,
) -> ConsolidatedOptimState:
    """
    Communicates the optimizer state for a flattened parameter ``flat_param``
    across ranks so that the target rank holds the entire non-sharded optimizer
    state.

    If ``N`` is the number of tensor optimizer states in the optimizer state
    dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1``
    otherwise (where the plus 1 comes from all-gathering the padding per rank).

    Args:
        flat_param (FlatParameter): The flattened parameter.
        flat_param_state (Dict[str, Any]): The entry in the "state" part of the
            optimizer state dict corresponding to the flattened parameter.
        to_save (bool): Whether to save the state on this rank.

    Returns:
        ConsolidatedOptimState: Consolidated optimizer state for
        ``flat_param``; the state is not populated for non-target ranks.
    """
    param_index = -1
    for i, param in enumerate(fsdp_module.params):
        if param is flat_param:
            param_index = i
            break
    assert param_index >= 0, "`fsdp_module` must own `flat_param`"

    state = ConsolidatedOptimState()
    tensor_state, zero_dim_tensor_state, non_tensor_state = \
        state.tensor_state, state.zero_dim_tensor_state, state.non_tensor_state
    process_group = fsdp_module.process_group

    tensor_buffer = None  # initialize lazily in case it is not needed
    for state_name, value in flat_param_state.items():
        # Positive-dimension tensor state: communicate across ranks
        if torch.is_tensor(value) and value.dim() > 0:
            # If the parameter is not sharded (e.g. world size of 1), then
            # neither is the positive-dimension tensor state, so no need to
            # communicate it -- we take the target rank's value
            if not flat_param._is_sharded:
                tensor_state[state_name] = value.cpu()
                continue
            if tensor_buffer is None:
                # Assume that positive-dimension tensor optimizer state
                # has the same shape as the sharded flattened parameter
                buffer_size = flat_param._full_param_padded.size()  # type: ignore[attr-defined]
                tensor_buffer = value.new_zeros(*buffer_size)
            dist._all_gather_base(tensor_buffer, value, group=process_group)
            if to_save:
                assert hasattr(flat_param, "_orig_size"), \
                    "Sharded flattened parameter should have `_orig_size` set"
                unpadded_numel = flat_param._orig_size.numel()  # type: ignore[attr-defined]
                tensor_state[state_name] = tensor_buffer[:unpadded_numel].cpu()
        # Zero-dimension tensor state and non-tensor state: take this rank's
        # value directly
        elif to_save:
            if _is_zero_dim_tensor(value):
                zero_dim_tensor_state[state_name] = value.cpu()
            else:
                non_tensor_state[state_name] = value
    return state
Example No. 10
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor
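The snippet shows only the forward of an all-gather autograd function. A hedged sketch of what a complete function could look like, with a backward that reduce-scatters the incoming gradient back to the local shard; the class name and the backward are not part of the original example, and the sketch assumes flat, equally sized shards:

import torch
import torch.distributed as dist


class _AllGatherBase(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps only the slice of the gradient matching its shard.
        world_size = dist.get_world_size(group=ctx.group)
        grad_input = torch.empty(
            grad_output.numel() // world_size,
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        dist._reduce_scatter_base(grad_input, grad_output.contiguous(), group=ctx.group)
        # No gradient for the preallocated output buffer or the group handle.
        return None, grad_input, None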
Example No. 11
    def _rebuild_full_params(self) -> None:
        """
        Gather all shards of params.
        """
        def update_p_data(output_tensor: torch.Tensor) -> None:
            """
            Helper function to update p.data pointer.
            Args:
                output_tensor (torch.Tensor): this tensor contains the data we just gathered.
            """
            p.data = output_tensor
            # Trim any padding and reshape to match original size.
            p.data = p.data[:p._orig_size.numel()].view(
                p._orig_size)  # type: ignore[attr-defined]

        with torch.cuda.stream(self._streams["all_gather"]):
            for p in self.params:
                if self.cpu_offload.offload_params:
                    # Move params to GPU if needed. Note that we don't use
                    # self._full_param_padded.device here because the attr is
                    # not set always, i.e. when world_size=1 and
                    # p._is_sharded = False. However when it is set, the
                    # device is always self.compute_device.
                    p.data = p.data.to(self.compute_device, non_blocking=True)
                # e.g., when world_size == 1
                if not p._is_sharded:  # type: ignore[attr-defined]
                    continue
                # If full param has been rebuilt or has not been freed, no need to call all gather
                elif (p._full_param_padded.storage().size()  # type: ignore[attr-defined]
                      == p._full_param_padded.size().numel()):  # type: ignore[attr-defined]
                    update_p_data(
                        p._full_param_padded)  # type: ignore[attr-defined]
                    continue
                else:
                    # If full param has not been rebuilt or has been freed, call all gather
                    p_data = p.data  # type: ignore[attr-defined]
                    p_full_size = p._full_param_padded.size(
                    )  # type: ignore[attr-defined]
                    assert (
                        p_full_size.numel() == p_data.numel() * self.world_size
                    ), "Param full size should equal its shard size multiplied by world_size."
                    assert (
                        p._full_param_padded.storage().size() == 0  # type: ignore[attr-defined]
                    ), "Full param's storage should have been freed earlier if an all-gather is needed."  # type: ignore[attr-defined]
                    # Allocate based on full size from all shards.
                    _alloc_storage(
                        p._full_param_padded,
                        size=p_full_size)  # type: ignore[attr-defined]
                    output_tensor = p._full_param_padded  # type: ignore[attr-defined]

                    # Fill output_tensor with (p.data for each shard in self.world_size)
                    dist._all_gather_base(output_tensor,
                                          p_data,
                                          group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
Example No. 12
    def forward(self, input, weight, bias, running_mean, running_var, eps,
                momentum, process_group, world_size):
        if not input.is_contiguous(memory_format=torch.channels_last):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(
                'Expected more than 1 value per channel when training, got input size {}'
                .format(size))

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count = torch.full((1, ),
                           input.numel() // input.size(1),
                           dtype=mean.dtype,
                           device=mean.device)

        num_channels = input.shape[1]
        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
        # Use allgather instead of allreduce because count could be different across
        # ranks; a simple all-reduce op cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist._all_gather_base(combined_flat,
                                  combined,
                                  process_group,
                                  async_op=False)
            combined = torch.reshape(combined_flat,
                                     (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined,
                                                          num_channels,
                                                          dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for k in range(world_size)
            ]
            dist.all_gather(combined_list,
                            combined,
                            process_group,
                            async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined,
                                                          num_channels,
                                                          dim=1)

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input, mean_all, invstd_all, running_mean, running_var, momentum,
            eps, count_all.view(-1))

        self.save_for_backward(input, weight, mean, invstd,
                               count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out