Example 1
    def test_cat_errors(self):
        with self.assertRaisesRegex(
            RuntimeError, 'All inputs need to be an instance of _PartialTensor'
        ):
            torch.cat([_PartialTensor(torch.rand(10)), torch.rand(10)])

        with self.assertRaisesRegex(
            RuntimeError, 'reduce_ops need to be the same'
        ):
            torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10), reduce_op=dist.ReduceOp.MAX)])

        with self.assertRaisesRegex(
            RuntimeError, '"out" kwarg is not supported'
        ):
            torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10))], out=torch.rand(10))
Example 2
def _handle_row_wise_sharding_sharded_tensor(input, world_size, weight,
                                             local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear when the input is a sharded tensor. (Detailed explanations
    of the logic can be found in the comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard_t: row-wise sharded local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    results = []
    local_shard = input.local_shards()[0].tensor
    if input.sharding_spec().dim not in (-1, len(input.size()) - 1):
        raise NotImplementedError(
            "The case when the input does not come from col-wise sharded "
            "linear is not supported for row-wise sharded linear.")

    for tensor in torch.tensor_split(local_shard, world_size):
        results.append(
            tensor.matmul(local_shard_t) +
            _BiasTensorPartial.apply(world_size, bias))

    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
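
The partial-sum structure above can be checked without any process group: when the in_features dimension of both the input and the weight is sharded row-wise, the per-shard matmuls are partial sums of the full linear output, which is exactly what a _PartialTensor with ReduceOp.SUM reduces to. The sketch below is illustrative only; it assumes the bias contribution is added as bias / world_size per shard (mirroring what _BiasTensorPartial appears to do) and uses no PyTorch-internal helpers.

import torch
import torch.nn.functional as F

# Minimal single-process sketch: split in_features of the input and the
# weight into world_size shards, compute the per-shard partial matmuls,
# and verify that their sum equals the full linear output. The bias is
# scaled by 1 / world_size per shard (assumption) so the sum restores it.
torch.manual_seed(0)
world_size = 4
inp = torch.rand(8, 16)        # (batch, in_features)
weight = torch.rand(32, 16)    # (out_features, in_features)
bias = torch.rand(32)

inp_shards = torch.tensor_split(inp, world_size, dim=1)
weight_shards = torch.tensor_split(weight, world_size, dim=1)

partial_results = [
    inp_shards[r].matmul(weight_shards[r].t()) + bias / world_size
    for r in range(world_size)
]
# The SUM reduction a _PartialTensor would perform across ranks.
full_result = torch.stack(partial_results).sum(dim=0)

assert torch.allclose(full_result, F.linear(inp, weight, bias), atol=1e-5)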
Example 3
    def _run_partial_tensor_n_reshard(
        self, reshard_spec, input_size, world_size, reduce_op, dtype=torch.float
    ):
        results_compare = []
        local_result = []
        pg = dist.distributed_c10d._get_default_group()
        for rank in range(pg.size()):
            torch.manual_seed(rank)
            results = []
            for _ in range(world_size):
                tensor = torch.rand(*input_size, dtype=dtype).cuda(self.rank)
                results.append(tensor)
                if self.rank == rank:
                    local_result.append(tensor.clone().detach())
            results_compare.append(torch.cat(results))
        partial_tensor = _PartialTensor(
            torch.cat(local_result), pg, reduce_op=reduce_op
        )
        local_sharded_result = partial_tensor.reshard(reshard_spec)
        local_shards = local_sharded_result.local_shards()
        results_compare = torch.stack(results_compare)
        if reduce_op == dist.ReduceOp.SUM:
            results_compare = torch.sum(results_compare, dim=0)
        else:
            results_compare = torch.max(results_compare, dim=0).values
        rank_idx = None
        for idx, placement in enumerate(reshard_spec.placements):
            if placement.rank() == self.rank:
                rank_idx = idx
        local_result_compare = results_compare.chunk(pg.size())[rank_idx]
        self.assertEqual(1, len(local_shards))
        self.assertEqual(local_shards[0].tensor, local_result_compare)
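
The reference computation in this test reduces the per-rank contributions and then slices out the shard the current rank should own after the reshard. Below is a CPU-only sketch of that pattern with no process group and no _PartialTensor, assuming one shard per rank placed in rank order (a simplification of the reshard_spec handling above).

import torch

world_size = 4
# One contribution per rank; dim 0 is divisible by world_size so the final
# chunk() produces equally sized shards.
per_rank = [torch.rand(world_size * 2, 5) for _ in range(world_size)]

stacked = torch.stack(per_rank)             # (world_size, world_size * 2, 5)
reduced_sum = stacked.sum(dim=0)            # what ReduceOp.SUM would produce
reduced_max = stacked.max(dim=0).values     # what ReduceOp.MAX would produce

# Expected local shard for a given rank, assuming shards are placed in
# rank order (placement.rank() == shard index).
rank = 2
expected_shard = reduced_sum.chunk(world_size)[rank]
print(expected_shard.shape)                 # torch.Size([2, 5])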
Example 4
    def test_cat(self):
        t1 = torch.rand(5, 10)
        t2 = torch.rand(3, 10)
        t3 = torch.rand(4, 10)
        partial_tensors = [
            _PartialTensor(t1),
            _PartialTensor(t2),
            _PartialTensor(t3)
        ]
        partial_concat = torch.cat(partial_tensors)
        local_concat = torch.cat([t1, t2, t3])
        self.assertEqual(local_concat.size(), partial_concat.size())

        # Test dim kwarg
        t1 = torch.rand(5, 10)
        t2 = torch.rand(5, 12)
        t3 = torch.rand(5, 11)
        partial_tensors = [
            _PartialTensor(t1),
            _PartialTensor(t2),
            _PartialTensor(t3)
        ]
        partial_concat = torch.cat(partial_tensors, dim=1)
        local_concat = torch.cat([t1, t2, t3], dim=1)
        self.assertEqual(local_concat.size(), partial_concat.size())
Example 5
def _handle_row_wise_sharding_tensor(input, world_size, weight, rank,
                                     local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current process.
        local_shard_t: row-wise sharded local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.transpose(0, -1).contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input_size = [input_split_sizes[rank] * world_size] + list(
        input_t_size[1:])
    gathered_input = torch.empty(gathered_input_size, device=input_t.device)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input,
                      input_t,
                      input_split_sizes=input_split_sizes,
                      group=pg)
    gathered_input = gathered_input.transpose(0, -1)

    # Perform local matmuls for all shards
    results = []
    shard_size = local_shard_t.size()[0]
    for r in range(world_size):
        inp = torch.narrow(gathered_input, -1, r * shard_size, shard_size)
        results.append(
            inp.matmul(local_shard_t) +
            _BiasTensorPartial.apply(world_size, bias))

    # Return the partial local result.
    return _PartialTensor(torch.cat(results), pg)
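
The split-size bookkeeping before the all-to-all is easiest to see with concrete numbers. The two helpers below are local stand-ins for the internal get_split_size / get_chunked_dim_size utilities (an assumption about their behaviour based on the comment above: the first N - 1 chunks get the ceiling size, the last gets the remainder), and the placements list is a hypothetical out-of-rank-order sharding that triggers the row rearrangement.

import torch

def split_size(dim_size, world_size):
    # Ceiling division: the first N - 1 chunks get this size.
    return (dim_size + world_size - 1) // world_size

def chunked_dim_size(dim_size, split, chunk_idx):
    # Size of chunk `chunk_idx`, clamped at the end of the dimension.
    return max(min(dim_size, split * (chunk_idx + 1)) - split * chunk_idx, 0)

world_size = 4
dim0 = 13                                   # e.g. input_t.size(0)
split = split_size(dim0, world_size)        # 4
shard_sizes = [chunked_dim_size(dim0, split, i) for i in range(world_size)]
print(shard_sizes)                          # [4, 4, 4, 1]

# Hypothetical placements: shard i lives on rank placements[i], not rank i,
# so the rows of input_t must be regrouped before all_to_all_single.
placements = [2, 0, 1, 3]
input_split_sizes = [0] * world_size
for shard_idx, rank in enumerate(placements):
    input_split_sizes[rank] = shard_sizes[shard_idx]

indices = [[] for _ in range(world_size)]
for shard_idx, rank in enumerate(placements):
    start = shard_idx * split
    indices[rank] = list(range(start, start + input_split_sizes[rank]))
indices_flat = [i for block in indices for i in block]

input_t = torch.arange(dim0 * 2, dtype=torch.float).reshape(dim0, 2)
rearranged = input_t.index_select(0, torch.tensor(indices_flat))
print(input_split_sizes)                    # [4, 4, 4, 1]
print(indices_flat)                         # rows grouped by destination rank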
Example 6
    def test_transpose(self):
        partial_tensor = _PartialTensor(torch.rand(5, 10))
        partial_tensor = partial_tensor.transpose(0, 1)
        self.assertEqual(partial_tensor.size(), torch.Size((10, 5)))