Example 1
    def _test_sequence_num_incremented(self, process_group, ranks):
        # verify initial sequence numbers. Use a distinct process group for
        # verification to keep counts as expected with respect to process_group.
        verify_pg = dist.new_group(
            ranks=ranks,
            backend="gloo",
        )
        assert dist.get_world_size(process_group) == dist.get_world_size(verify_pg)

        initial_num = (
            self._verify_sequence_number_across_pg(
                pg=process_group, verify_pg=verify_pg
            )
            if not c10d._rank_not_in_group(process_group)
            else -1
        )

        # Verify sequence numbers are appropriately incremented
        for i in range(10):
            t = torch.ones(1, device=torch.cuda.current_device())
            dist.all_reduce(t, group=process_group)
            if not c10d._rank_not_in_group(process_group):
                seq_num = self._verify_sequence_number_across_pg(
                    pg=process_group,
                    verify_pg=verify_pg,
                )
                self.assertEqual(initial_num + i + 1, seq_num)

        if dist.get_world_size(process_group) > 2:
            # Test when certain ranks don't call collectives
            if dist.get_rank(process_group) not in [0, 2]:
                dist.all_reduce(t, group=process_group, async_op=True)
            # Now ranks 0 and 2 should be lagging by 1.
            if not c10d._rank_not_in_group(process_group):
                seq_num = process_group._get_sequence_number_for_group()
                rank = dist.get_rank(process_group)
                obj_list = [None for _ in range(dist.get_world_size(verify_pg))]
                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
                rank_to_seq_num = {rank: num for (rank, num) in obj_list}
                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
                expected_same = {
                    rank_to_seq_num[i]
                    for i in rank_to_seq_num.keys()
                    if i not in [0, 2]
                }
                self.assertEqual(len(expected_same), 1)
                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
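As a side note, the per-group counter queried above can also be read directly; a minimal sketch follows, assuming torch.distributed has already been initialized with a backend that tracks sequence numbers (e.g. Gloo) and that each collective bumps the counter by one, as the test asserts.

import torch
import torch.distributed as dist

# Minimal sketch, not part of the test above: read the default group's
# sequence number before and after a collective (assumes an initialized
# process group whose backend tracks sequence numbers).
pg = dist.distributed_c10d._get_default_group()
before = pg._get_sequence_number_for_group()
dist.all_reduce(torch.ones(1))
after = pg._get_sequence_number_for_group()
assert after == before + 1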
Example 2
def _parse_and_validate_remote_device(pg, remote_device):
    """Resolve ``remote_device`` into a ``(rank_or_worker_id, device)`` pair,
    validating the rank against ``pg`` when the current rank belongs to it."""
    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate the rank; skip validation if the current rank is not part of the process group.
    if not distributed_c10d._rank_not_in_group(pg):
        if rank is not None and (rank < 0 or
                                 rank >= distributed_c10d.get_world_size(pg)):
            raise ValueError(f'Invalid rank: {rank}')

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(
                f'RPC framework needs to be initialized for using worker names: {worker_name}'
            )

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f'Invalid worker name: {worker_name}')

    return rank, device
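For context, here is a hedged sketch of how the helper above might be invoked; the import path for _remote_device and the device-string formats are assumptions inferred from the accessor methods used above (worker_name(), rank(), device()), not something this example confirms.

import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed.remote_device import _remote_device  # assumed import path

# Assumes an initialized default process group; "rank:0/cpu" and "trainer0"
# are illustrative device strings / worker names, not values from the example.
pg = distributed_c10d._get_default_group()
rank, device = _parse_and_validate_remote_device(pg, _remote_device("rank:0/cpu"))
# A worker-name form would resolve to a worker id instead, but only if the
# RPC framework has been initialized:
# worker_id, device = _parse_and_validate_remote_device(pg, _remote_device("trainer0/cpu"))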
Example 3
    def irecv(self, tensor, src=None, tag=0):
        # pylint: disable=protected-access

        # The original irecv doesn't support receiving from any source, but the
        # original recv does. They are essentially the same, except that recv
        # has a wait() call.
        dist_c10d._check_single_tensor(tensor, "tensor")
        if dist_c10d._rank_not_in_group(self.group):
            return -1

        if self.group == dist_c10d.GroupMember.WORLD:
            dist_c10d._check_default_pg()
            pg = dist_c10d._default_pg
        else:
            pg = self.group

        if src is None:
            work = pg.recv_anysource([tensor], tag)
            src_rank = work.source_rank()
            if self.group == dist_c10d.GroupMember.WORLD:
                return src_rank
            else:
                return dist_c10d._get_global_rank(pg, src_rank)
        else:
            if self.group == dist_c10d.GroupMember.WORLD:
                pg.recv([tensor], src, tag).wait()
            else:
                group_src_rank = dist_c10d._get_group_rank(pg, src)
                pg.recv([tensor], group_src_rank, tag).wait()
            return src
Example 4
def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
    if len(args) != 2:
        raise ValueError(
            f'Expected two arguments for torch.{cmp_fun.__name__}')

    st1 = args[0]
    st2 = args[1]
    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
        raise TypeError(
            f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor'
        )

    # Verify same PG
    if st1._process_group != st2._process_group:
        return False

    if distributed_c10d._rank_not_in_group(
            st1._process_group) or distributed_c10d._rank_not_in_group(
                st2._process_group):
        return distributed_c10d._rank_not_in_group(
            st1._process_group) == distributed_c10d._rank_not_in_group(
                st2._process_group)

    # Verify metadata
    if st1.metadata() != st2.metadata():
        return _communicate_result(False, st1._process_group)

    # Verify number of local shards
    st1_local_shards = st1.local_shards()
    st2_local_shards = st2.local_shards()
    if len(st1_local_shards) != len(st2_local_shards):
        return _communicate_result(False, st1._process_group)

    # Normalize kwargs so it can be unpacked below.
    if kwargs is None:
        kwargs = {}
    # Verify each local shard
    for idx in range(len(st1_local_shards)):
        if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata:
            return _communicate_result(False, st1._process_group)
        if not cmp_fun(st1_local_shards[idx].tensor,
                       st2_local_shards[idx].tensor, **kwargs):
            return _communicate_result(False, st1._process_group)

    return _communicate_result(True, st1._process_group)
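The _communicate_result helper used above is not shown in this example; a plausible sketch (an assumption, not the actual helper) would all-reduce a success flag across the group so that every rank returns the same verdict.

import torch
from torch.distributed import distributed_c10d

# Hedged sketch of what _communicate_result could look like; treat it as an
# illustration of the pattern, not the library's implementation.
def _communicate_result(result, pg):
    # Encode the local verdict as 1.0 (match) or 0.0 (mismatch).
    result_tensor = torch.ones(1) if result else torch.zeros(1)
    # Sum the verdicts across the group (assumes a backend that accepts CPU tensors).
    distributed_c10d.all_reduce(result_tensor, group=pg)
    # The comparison holds globally only if every rank reported a match.
    return bool(result_tensor.item() == distributed_c10d.get_world_size(pg))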
Example 5
    def irecv(self, tensor, src=None, tag=0):  # pragma: no cover
        """
        Returns:
            An object you can call .wait() on, .wait()
            will return the source rank.
        """
        # pylint: disable=protected-access

        # The original irecv doesn't support receiving from any source, but the
        # original recv does. They are essentially the same, except that recv
        # has a wait() call.
        dist_c10d._check_single_tensor(tensor, "tensor")
        if dist_c10d._rank_not_in_group(self.group):
            class Waiter:
                def wait(self):
                    return -1

            return Waiter()

        if self.group == dist_c10d.GroupMember.WORLD:
            dist_c10d._check_default_pg()
            pg = dist_c10d._default_pg
        else:
            pg = self.group

        if src is None:
            work = pg.recv_anysource([tensor], tag)
            if self.group == dist_c10d.GroupMember.WORLD:
                class Waiter:
                    def wait(self):
                        nonlocal work
                        work.wait()
                        return work.source_rank()

                return Waiter()
            else:
                class Waiter:
                    def wait(self):
                        nonlocal work, pg
                        work.wait()
                        src_rank = work.source_rank()
                        return dist_c10d._get_global_rank(pg, src_rank)

                return Waiter()
        else:
            if self.group == dist_c10d.GroupMember.WORLD:
                work = pg.recv([tensor], src, tag)
            else:
                group_src_rank = dist_c10d._get_group_rank(pg, src)
                work = pg.recv([tensor], group_src_rank, tag)

            class Waiter:
                def wait(self):
                    nonlocal src
                    work.wait()
                    return src

            return Waiter()
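On the caller's side, the returned object behaves like a small future; a minimal usage sketch follows, where comm stands for an instance of the class defining this irecv and is only illustrative.

import torch

# Minimal sketch of the caller side; `comm` is an assumed instance of the
# class that defines irecv above.
buf = torch.zeros(4)
waiter = comm.irecv(buf)      # returns immediately with a Waiter
# ... overlap other work while the receive is in flight ...
src_rank = waiter.wait()      # blocks until buf is filled; returns the sender's rank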
Example 6
def equal(types, args=(), kwargs=None, process_group=None):
    if len(args) != 2:
        raise ValueError('Expected two arguments for torch.equal')

    st1 = args[0]
    st2 = args[1]
    if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)):
        raise TypeError(
            'Both arguments to torch.equal need to be of type ShardedTensor')

    # Verify same PG
    if st1._process_group != st2._process_group:
        return False

    if distributed_c10d._rank_not_in_group(
            st1._process_group) or distributed_c10d._rank_not_in_group(
                st2._process_group):
        return distributed_c10d._rank_not_in_group(
            st1._process_group) == distributed_c10d._rank_not_in_group(
                st2._process_group)

    # Verify metadata
    if st1.metadata() != st2.metadata():
        return _communicate_result(False, st1._process_group)

    # Verify number of local shards
    st1_local_shards = st1.local_shards()
    st2_local_shards = st2.local_shards()
    if len(st1_local_shards) != len(st2_local_shards):
        return _communicate_result(False, st1._process_group)

    # Verify each local shard
    for idx in range(len(st1_local_shards)):
        if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata:
            return _communicate_result(False, st1._process_group)
        if not torch.equal(st1_local_shards[idx].tensor,
                           st2_local_shards[idx].tensor):
            return _communicate_result(False, st1._process_group)

    return _communicate_result(True, st1._process_group)
Example 7
    def _test_sequence_num_set_new_group(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

        subgroup = dist.new_group([0, 1])

        if not c10d._rank_not_in_group(subgroup):
            subgroup_seq = subgroup._get_sequence_number_for_group()
            obj_list = [None for _ in range(dist.get_world_size(subgroup))]
            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
            self.assertEqual(len(set(obj_list)), 1)
Example 8
    def _prepare_init(self, process_group=None):
        self._rpc_initialized = False
        self._sharded_tensor_id = None
        if rpc._is_current_rpc_agent_set():
            # Validate PG and RPC ranks match.
            pg_rank = dist.get_rank()
            rpc_rank = rpc.get_worker_info().id
            if pg_rank != rpc_rank:
                raise ValueError(
                    f'Default ProcessGroup and RPC ranks must be '
                    f'the same for ShardedTensor, found process group rank: '
                    f'{pg_rank} and RPC rank: {rpc_rank}')

        self._process_group = (process_group if process_group is not None else
                               distributed_c10d._get_default_group())

        if distributed_c10d._rank_not_in_group(self._process_group):
            raise ValueError(
                f'Global rank: {dist.get_rank()} not part of process group')

        self._local_shards: List[Shard] = []
        self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
        self._sharding_metadata: List[ShardMetadata] = []
Example 9
    def _parse_and_validate_remote_device(self, device):
        """Resolve ``device`` into a ``(rank_or_worker_id, local_device)`` pair,
        validating the rank against this tensor's process group when applicable."""
        on, local_device = _parse_remote_device(device)

        # Validate the rank; skip validation if the current rank is not part of the process group.
        if not distributed_c10d._rank_not_in_group(self._process_group):
            if isinstance(on, int) and (
                    on < 0 or on >= dist.get_world_size(self._process_group)):
                raise ValueError(f'Invalid rank: {on}')

        if isinstance(on, str):
            if not rpc._is_current_rpc_agent_set():
                raise RuntimeError(
                    f'RPC framework needs to be initialized for using worker names: {on}'
                )

            workers = rpc._get_current_rpc_agent().get_worker_infos()
            for worker in workers:
                if worker.name == on:
                    return worker.id, local_device

            raise ValueError(f'Invalid worker name: {on}')

        return on, local_device
Example 10
    def __init__(
        self,
        sharding_spec: ShardingSpec,
        *size,
        dtype=None,
        layout=torch.strided,
        requires_grad=False,
        pin_memory=False,
        memory_format=torch.contiguous_format,
        process_group=None,
    ):
        self._rpc_initialized = False
        self._sharded_tensor_id = None
        if rpc._is_current_rpc_agent_set():
            # Validate PG and RPC ranks match.
            pg_rank = dist.get_rank()
            rpc_rank = rpc.get_worker_info().id
            if pg_rank != rpc_rank:
                raise ValueError(
                    f'Default ProcessGroup and RPC ranks must be '
                    f'the same for ShardedTensor, found process group rank: '
                    f'{pg_rank} and RPC rank: {rpc_rank}'
                )

        if layout != torch.strided:
            raise ValueError('Only torch.strided layout is currently supported')

        if memory_format != torch.contiguous_format:
            raise ValueError('Only torch.contiguous_format memory_format is currently supported')

        self._sharding_spec = sharding_spec
        self._dims = list(size)
        self._process_group = (
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )

        if distributed_c10d._rank_not_in_group(self._process_group):
            raise ValueError(f'Global rank: {dist.get_rank()} not part of process group')

        self._local_shards: List[Shard] = []
        self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
        self._sharding_metadata: List[ShardMetadata] = []
        if isinstance(self._sharding_spec, ChunkShardingSpec):
            self._init_chunked(
                dtype,
                layout,
                requires_grad,
                pin_memory,
                memory_format,
            )
        elif isinstance(self._sharding_spec, EnumerableShardingSpec):
            self._init_enumerable(
                dtype,
                layout,
                requires_grad,
                pin_memory,
                memory_format,
            )
        else:
            raise ValueError(f'Unsupported sharding_spec: {self._sharding_spec}')

        with _sharded_tensor_lock:
            global _sharded_tensor_current_id, _sharded_tensor_map
            self._sharded_tensor_id = _sharded_tensor_current_id
            _sharded_tensor_map[self._sharded_tensor_id] = self
            _sharded_tensor_current_id += 1

        # Initialize RPC if available.
        if rpc._is_current_rpc_agent_set():
            self._init_rpc()