def __init__(
    self,
    batch_sampler: torch.utils.data.BatchSampler,
    num_workers: int,
    rank: int,
) -> None:
    """Validate the worker layout and store the constructor arguments.

    Args:
        batch_sampler: Batch sampler kept on the instance unchanged.
        num_workers: Total number of workers; must be greater than zero.
        rank: Must satisfy ``0 <= rank < num_workers`` (presumably this
            worker's index among ``num_workers`` -- confirm with callers).
    """
    # The checks run in this order so the reported message matches the
    # first violated constraint.
    check.gt(rank, -1, "rank must be non-negative")
    check.gt(num_workers, 0, "num_workers must be positive")
    check.lt(rank, num_workers, "rank must be less than num_workers")

    self.batch_sampler = batch_sampler
    self.num_workers = num_workers
    self.rank = rank
Example #2
0
    def receive_non_blocking(
        self, send_rank: int, deadline: Optional[float] = None
    ) -> Tuple[bool, Any]:
        """Poll one socket and return its next message if one arrives in time.

        Args:
            send_rank: Index into ``self.sockets`` of the socket to read from.
            deadline: Optional absolute time, on the ``time.time()`` clock,
                until which to wait. When ``None`` the poll waits up to one
                second.

        Returns:
            ``(True, message)`` when a message was available before the poll
            timed out, otherwise ``(False, None)``.
        """
        check.lt(send_rank, len(self.sockets))

        if deadline is None:
            timeout_ms = 1000
        else:
            # Convert to milliseconds *before* truncating to int; truncating
            # the seconds first would discard any sub-second remainder.
            timeout_ms = int((deadline - time.time()) * 1000)
        # Always poll for at least 100 ms, even if the deadline has passed.
        timeout_ms = max(timeout_ms, 100)

        sock = self.sockets[send_rank]
        if sock.poll(timeout_ms) == 0:  # type: ignore
            return False, None
        # NOTE(review): recv_pyobj deserializes with pickle -- confirm that
        # peers on this socket are trusted.
        return True, sock.recv_pyobj()  # type: ignore
Example #3
0
 def _bind_to_random_ports(self, port_range: Tuple[int, int],
                           num_connections: int) -> None:
     """Create ``num_connections`` REP sockets, each bound to a random port.

     Every successfully bound socket is appended to ``self.sockets`` and its
     port to ``self.ports``.

     Args:
         port_range: ``(min_port, max_port)`` bounds passed to
             ``bind_to_random_port``.
         num_connections: Number of sockets to create; must be smaller than
             the width of ``port_range``.

     Raises:
         det.errors.InternalException: If a socket cannot be bound to a port
             in ``port_range``.
     """
     check.lt(num_connections, port_range[1] - port_range[0])
     for _ in range(num_connections):
         sock = self.context.socket(zmq.REP)
         try:
             selected_port = sock.bind_to_random_port(
                 addr="tcp://*",
                 min_port=port_range[0],
                 max_port=port_range[1])
         except ZMQBindError as e:
             # The socket was never stored on self, so nothing else would
             # close it; release it here before reporting the failure.
             sock.close()
             raise det.errors.InternalException(
                 f"Failed to bind to port range {port_range}.") from e
         self.ports.append(selected_port)
         self.sockets.append(sock)
Example #4
0
 def receive_blocking(self, send_rank: int) -> Any:
     """Receive and return the next Python object from one socket.

     Args:
         send_rank: Index into ``self.sockets`` of the socket to read from;
             checked to be less than ``len(self.sockets)``.

     Returns:
         The object produced by the socket's ``recv_pyobj`` call.
     """
     check.lt(send_rank, len(self.sockets))
     sock = self.sockets[send_rank]
     return sock.recv_pyobj()  # type: ignore