Esempio n. 1
0
def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):
    """Check the tensor-model-parallel source rank reported by ``parallel_state``.

    Model parallelism is initialized with the requested size (capped at the
    distributed world size).  The test then asserts that the reported source
    rank equals the global rank minus the tensor-model-parallel rank and
    that the pipeline split rank is ``None``, before destroying the
    model-parallel state again.

    Arguments:
        tensor_model_parallel_size_: requested tensor-model-parallel size.
    """
    banner = '> testing get_tensor_model_parallel_src_rank with size {} ...'
    if torch.distributed.get_rank() == 0:
        print(banner.format(tensor_model_parallel_size_))

    # Never ask for more tensor-parallel ranks than there are processes.
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)

    assert not parallel_state.model_parallel_is_initialized()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size)
    assert parallel_state.model_parallel_is_initialized()

    # Checks
    global_rank = torch.distributed.get_rank()
    tensor_rank = parallel_state.get_tensor_model_parallel_rank()
    expected_src_rank = global_rank - tensor_rank
    assert parallel_state.get_tensor_model_parallel_src_rank() == expected_src_rank
    assert parallel_state.get_pipeline_model_parallel_split_rank() is None

    # Reset groups
    parallel_state.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
Esempio n. 2
0
File: data.py Progetto: kexinyu/apex
def _build_key_size_numel_dictionaries(keys, data):
    """Share the shapes of ``data[key]`` across the tensor-model-parallel group.

    Tensor-model-parallel rank 0 packs the size of every tensor into a flat
    list with ``_MAX_DATA_DIM`` slots per key; unused slots stay 0.  The list
    is broadcast from the tensor-model-parallel source rank and every rank
    decodes it again, reading each key's slots until the first non-positive
    entry.

    Arguments:
        keys: keys of ``data`` whose sizes are exchanged.
        data: mapping from key to tensor; only indexed on
              tensor-model-parallel rank 0.

    Returns:
        A tuple ``(key_size, key_numel, total_numel)``: per-key list of
        dimension sizes, per-key element count, and the sum of all element
        counts.
    """
    slots_per_key = _MAX_DATA_DIM
    sizes = [0 for _ in keys for _ in range(slots_per_key)]

    # Only rank zero of the group fills in the real sizes.
    if get_tensor_model_parallel_rank() == 0:
        for position, key in enumerate(keys):
            base = position * slots_per_key
            # Strictly fewer dims than slots, so a 0 terminator always remains.
            assert data[key].dim() < slots_per_key, \
                "you should increase MAX_DATA_DIM"
            shape = data[key].size()
            for axis, extent in enumerate(shape):
                sizes[base + axis] = extent

    # Ship the packed sizes through the GPU to the rest of the group.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(
        sizes_cuda,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Decode on the host.
    sizes_cpu = sizes_cuda.cpu()
    key_size = {}
    key_numel = {}
    total_numel = 0
    for position, key in enumerate(keys):
        cursor = position * slots_per_key
        shape = []
        count = 1
        # A non-positive entry marks the end of this key's shape.
        while sizes_cpu[cursor] > 0:
            extent = sizes_cpu[cursor]
            shape.append(extent)
            count *= extent
            cursor += 1
        key_size[key] = shape
        key_numel[key] = count
        total_numel += count

    return key_size, key_numel, total_numel
Esempio n. 3
0
File: data.py Progetto: kexinyu/apex
def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    The tensors selected by ``keys`` are flattened into a single 1-D buffer
    on tensor-model-parallel rank 0, broadcast from the group's source rank,
    and then sliced back into per-key tensors on every rank.

    Arguments:
        keys: list of keys in the data dictionary to be broadcast.
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.

    Returns:
        A dictionary mapping each key to a view of the broadcast buffer
        with that key's size.
    """
    # Sizes and element counts per key, plus the overall element count,
    # agreed on by all ranks of the group.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(
        keys, data)

    if get_tensor_model_parallel_rank() != 0:
        # Receivers only need an uninitialized buffer of the right length.
        flatten_data = torch.empty(
            total_numel,
            device=torch.cuda.current_device(),
            dtype=datatype,
        )
    else:
        # The sender validates dtypes, then packs everything into one buffer.
        _check_data_types(keys, data, datatype)
        pieces = [data[key].contiguous().view(-1) for key in keys]
        flatten_data = torch.cat(pieces, dim=0).cuda()

    # Broadcast
    torch.distributed.broadcast(
        flatten_data,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )

    # Unpack: carve consecutive slices out of the flat buffer, in key order.
    output = {}
    start = 0
    for key in keys:
        count = key_numel[key]
        chunk = flatten_data.narrow(0, start, count)
        output[key] = chunk.view(key_size[key])
        start += count

    return output
Esempio n. 4
0
    def test_initialize_model_parallel(self) -> None:
        """Initialize and destroy model parallelism for each valid tensor size.

        Every tensor-model-parallel size from 1 to ``self.world_size`` that
        divides the world size evenly is tried, with the pipeline size set to
        the quotient.  For each configuration the reported tensor-parallel
        world size, rank and source rank are checked, after which the
        model-parallel state is destroyed and verified to be uninitialized.
        """
        self.assertFalse(parallel_state.model_parallel_is_initialized())

        for tp_size in range(1, self.world_size + 1):
            with self.subTest(tensor_model_parallel_world_size=tp_size):
                # Sizes that do not divide the world size are skipped.
                if self.world_size % tp_size != 0:
                    continue

                pp_size = self.world_size // tp_size
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=tp_size,
                    pipeline_model_parallel_size_=pp_size,
                )

                self.assertEqual(
                    tp_size,
                    parallel_state.get_tensor_model_parallel_world_size(),
                )

                expected_rank = calc_expected_tensor_model_paralell_rank(
                    self.rank, tp_size)
                self.assertEqual(
                    expected_rank,
                    parallel_state.get_tensor_model_parallel_rank(),
                )

                # Source rank: this rank rounded down to a multiple of the
                # tensor-parallel size.
                expected_src_rank = (self.rank // tp_size) * tp_size
                self.assertEqual(
                    expected_src_rank,
                    parallel_state.get_tensor_model_parallel_src_rank(),
                )

                parallel_state.destroy_model_parallel()
                self.assertFalse(
                    parallel_state.model_parallel_is_initialized())