def __init__(
    self,
    batch_sampler: torch.utils.data.BatchSampler,
    num_workers: int,
    rank: int,
) -> None:
    """
    Store the wrapped batch sampler together with this worker's position.

    Args:
        batch_sampler: The batch sampler to keep on the instance.
        num_workers: Total number of workers; checked to be greater than 0.
        rank: Index of this worker; checked to be greater than -1 and less
            than ``num_workers``.

    The arguments are validated with the ``check`` helpers before any
    attribute is assigned.
    """
    # Validation order is significant: it decides which message is reported
    # first when more than one constraint is violated.
    check.gt(rank, -1, "rank must be non-negative")
    check.gt(num_workers, 0, "num_workers must be positive")
    check.lt(rank, num_workers, "rank must be less than num_workers")

    self.rank = rank
    self.num_workers = num_workers
    self.batch_sampler = batch_sampler
def receive_non_blocking(
    self, send_rank: int, deadline: Optional[float] = None
) -> Tuple[bool, Any]:
    """
    Poll the socket for ``send_rank`` and return its next message, if any.

    Args:
        send_rank: Index into ``self.sockets`` of the socket to read from.
        deadline: Absolute time on the ``time.time()`` clock until which to
            wait. ``None`` means wait for up to one second.

    Returns:
        ``(True, message)`` if the poll reported a pending message,
        otherwise ``(False, None)``.
    """
    check.lt(send_rank, len(self.sockets))

    if deadline is None:
        timeout_ms = 1000
    else:
        # Scale to milliseconds *before* truncating. Truncating the seconds
        # first would drop the sub-second part, so e.g. 0.9 s remaining
        # would collapse to 0 ms.
        timeout_ms = int((deadline - time.time()) * 1000)

    # Always wait at least 100 ms, even when the deadline is very close or
    # has already passed.
    timeout_ms = max(timeout_ms, 100)

    sock = self.sockets[send_rank]
    if sock.poll(timeout_ms) == 0:  # type: ignore
        return False, None

    # NOTE(review): pyzmq documents recv_pyobj as unpickling the payload;
    # confirm that only trusted peers can connect to these sockets.
    message = sock.recv_pyobj()  # type: ignore
    return True, message
def _bind_to_random_ports(self, port_range: Tuple[int, int], num_connections: int) -> None:
    """
    Create ``num_connections`` REP sockets, each bound to a random port.

    Every successfully bound socket is appended to ``self.sockets`` and its
    port to ``self.ports``, in matching order.

    Args:
        port_range: ``(min_port, max_port)`` forwarded to
            ``bind_to_random_port``.
        num_connections: Number of sockets to create; checked to be less
            than ``port_range[1] - port_range[0]``.

    Raises:
        det.errors.InternalException: If a socket cannot be bound
            (``ZMQBindError``). Sockets bound by earlier iterations remain
            in ``self.sockets``.
    """
    min_port, max_port = port_range[0], port_range[1]
    check.lt(num_connections, max_port - min_port)

    for _ in range(num_connections):
        sock = self.context.socket(zmq.REP)
        try:
            selected_port = sock.bind_to_random_port(
                addr="tcp://*", min_port=min_port, max_port=max_port
            )
        except ZMQBindError as e:
            # The socket has not been recorded in self.sockets yet, so
            # nothing else would close it; release it before propagating.
            sock.close()
            raise det.errors.InternalException(
                f"Failed to bind to port range {port_range}."
            ) from e

        self.ports.append(selected_port)
        self.sockets.append(sock)
def receive_blocking(self, send_rank: int) -> Any:
    """
    Receive one Python object from the socket for ``send_rank``.

    Args:
        send_rank: Index into ``self.sockets`` of the socket to read from;
            checked to be less than ``len(self.sockets)``.

    Returns:
        The object returned by the socket's ``recv_pyobj()``.
    """
    check.lt(send_rank, len(self.sockets))
    sender_socket = self.sockets[send_rank]
    return sender_socket.recv_pyobj()  # type: ignore