def _data_block_handler(self, request): assert self._connected, "Cannot load data before connect" if not self._data_block_handler_fn: raise RuntimeError("Received DataBlockMessage but" \ " no handler registered") metrics.emit_counter('load_data_block_counter', 1) if self._data_block_handler_fn(request): return common_pb.Status(code=common_pb.STATUS_SUCCESS) metrics.emit_counter('load_data_block_fail_counter', 1) return common_pb.Status(code=common_pb.STATUS_INVALID_DATA_BLOCK)
def iterator():
    """Generate outgoing messages: pending resends first, then new ones.

    A snapshot of ``resend_list`` is taken under ``lock`` and each message in
    it is re-yielded (logged as a warning and counted on ``resend_counter``).
    After that the generator blocks on ``self._transmit_queue.get()`` forever.
    It appends every new message to ``resend_list`` (under ``lock``) before
    yielding it, so the message can be resent if it is never acknowledged.

    Presumably used as the request iterator passed to
    ``client.StreamTransmit``.

    NOTE(review): at module level ``self``, ``lock`` and ``resend_list`` are
    unresolved free names. This body is identical to the generator nested
    inside ``_client_daemon_fn`` and only works when defined inside such an
    enclosing scope.
    """
    # Copy under the lock, but yield outside it: a consumer holding the
    # generator suspended must not keep the lock held.
    with lock:
        resend_msgs = list(resend_list)
    for item in resend_msgs:
        logging.warning("Streaming resend message seq_num=%d",
                        item.seq_num)
        metrics.emit_counter("resend_counter", 1)
        yield item
    while True:
        item = self._transmit_queue.get()
        # Record before sending so an unacknowledged message can be resent.
        with lock:
            resend_list.append(item)
        logging.debug("Streaming send message seq_num=%d", item.seq_num)
        yield item
def _transmit(self, msg): assert self._connected, "Cannot transmit before connect" metrics.emit_counter('send_counter', 1) with self._transmit_send_lock: msg.seq_num = self._next_send_seq_num self._next_send_seq_num += 1 if self._streaming_mode: self._transmit_queue.put(msg) return def sender(): rsp = self._client.Transmit(msg) assert rsp.status.code == common_pb.STATUS_SUCCESS, \ "Transmit error with code %d."%rsp.status.code self._rpc_with_retry(sender, "Bridge transmit failed")
def _rpc_with_retry(self, sender, err_log): while True: with self._client_lock: try: return sender() except Exception as e: # pylint: disable=broad-except logging.warning( "%s: %s. Retry in 1s...", err_log, repr(e)) metrics.emit_counter('reconnect_counter', 1) self._channel.close() time.sleep(1) self._channel = make_insecure_channel( self._remote_address, ChannelType.REMOTE, options=self._grpc_options, compression=self._compression) self._client = make_ready_client(self._channel) self._check_remote_heartbeat(self._client)
def _transmit_handler(self, request): assert self._connected, "Cannot transmit before connect" metrics.emit_counter('receive_counter', 1) with self._transmit_receive_lock: logging.debug("Received message seq_num=%d." " Wanted seq_num=%d.", request.seq_num, self._next_receive_seq_num) if request.seq_num > self._next_receive_seq_num: return tws_pb.TrainerWorkerResponse( status=common_pb.Status( code=common_pb.STATUS_MESSAGE_MISSING), next_seq_num=self._next_receive_seq_num) if request.seq_num < self._next_receive_seq_num: return tws_pb.TrainerWorkerResponse( status=common_pb.Status( code=common_pb.STATUS_MESSAGE_DUPLICATED), next_seq_num=self._next_receive_seq_num) # request.seq_num == self._next_receive_seq_num self._next_receive_seq_num += 1 if request.HasField('start'): with self._condition: self._received_data[request.start.iter_id] = {} elif request.HasField('commit'): pass elif request.HasField('data'): with self._condition: assert request.data.iter_id in self._received_data self._received_data[ request.data.iter_id][ request.data.name] = request.data self._condition.notifyAll() elif request.HasField('prefetch'): for func in self._prefetch_handlers: func(request.prefetch) else: return tws_pb.TrainerWorkerResponse( status=common_pb.Status( code=common_pb.STATUS_INVALID_REQUEST), next_seq_num=self._next_receive_seq_num) return tws_pb.TrainerWorkerResponse( next_seq_num=self._next_receive_seq_num)
def _client_daemon_fn(self):
    """Keep a ``StreamTransmit`` stream to the remote alive until shutdown.

    Outgoing messages are read from ``self._transmit_queue`` and kept in a
    local resend list until a response's ``next_seq_num`` shows the remote
    has received them. When the stream breaks, the channel is closed and,
    after a one-second pause, rebuilt; the new stream sends the
    unacknowledged messages first. The loop ends once the function installed
    as ``self._client_daemon_shutdown_fn`` has been called.
    """
    stop_event = threading.Event()
    # Read by shutdown_fn through the closure, so it cancels the current
    # stream. Stays None until the first StreamTransmit call returns.
    generator = None
    channel = make_insecure_channel(
        self._remote_address, ChannelType.REMOTE,
        options=self._grpc_options, compression=self._compression)
    client = make_ready_client(channel, stop_event)

    # Guards resend_list: the request iterator appends to it while the
    # response loop below trims acknowledged messages from its head.
    lock = threading.Lock()
    resend_list = collections.deque()

    def shutdown_fn():
        # Wait until nothing is queued or unacknowledged, dropping the lock
        # while sleeping so the daemon can keep trimming resend_list.
        with lock:
            while len(resend_list) > 0 or not self._transmit_queue.empty():
                logging.debug(
                    "Waiting for resend queue's being cleaned. "
                    "Resend queue size: %d", len(resend_list))
                lock.release()
                time.sleep(1)
                lock.acquire()
        stop_event.set()
        if generator is not None:
            generator.cancel()

    self._client_daemon_shutdown_fn = shutdown_fn

    while not stop_event.is_set():
        try:
            def iterator():
                # Snapshot under the lock, yield outside it.
                with lock:
                    resend_msgs = list(resend_list)
                for item in resend_msgs:
                    logging.warning("Streaming resend message seq_num=%d",
                                    item.seq_num)
                    metrics.emit_counter("resend_counter", 1)
                    yield item
                while True:
                    item = self._transmit_queue.get()
                    with lock:
                        resend_list.append(item)
                    logging.debug("Streaming send message seq_num=%d",
                                  item.seq_num)
                    yield item

            generator = client.StreamTransmit(iterator())
            for response in generator:
                if response.status.code == common_pb.STATUS_SUCCESS:
                    logging.debug("Message with seq_num=%d is "
                                  "confirmed", response.next_seq_num-1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_DUPLICATED:
                    # next_seq_num itself is still unreceived; the last
                    # confirmed one is next_seq_num-1 (as for SUCCESS).
                    logging.debug("Resent Message with seq_num=%d is "
                                  "confirmed", response.next_seq_num-1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_MISSING:
                    raise RuntimeError("Message with seq_num=%d is "
                                       "missing!" % (response.next_seq_num))
                else:
                    raise RuntimeError("Trainsmit failed with %d"
                                       % response.status.code)
                # Everything below next_seq_num has reached the remote.
                with lock:
                    while resend_list and \
                            resend_list[0].seq_num < response.next_seq_num:
                        resend_list.popleft()
                    min_seq_num_to_resend = resend_list[0].seq_num \
                        if resend_list else "NaN"
                    logging.debug(
                        "Resend queue size: %d, starting from seq_num=%s",
                        len(resend_list), min_seq_num_to_resend)
        except Exception as e:  # pylint: disable=broad-except
            if not stop_event.is_set():
                logging.warning("Bridge streaming broken: %s.", repr(e))
                metrics.emit_counter('reconnect_counter', 1)
        finally:
            # Guard against None: an AttributeError here would escape the
            # loop and kill the daemon thread.
            if generator is not None:
                generator.cancel()
            channel.close()

        # On shutdown, do not build a fresh channel that would be leaked.
        if stop_event.is_set():
            break

        time.sleep(1)
        # Conditional expression rather than `a and b or c`: seq_num 0 is
        # falsy and must not be reported as "NaN".
        restart_seq_num = resend_list[0].seq_num if resend_list else "NaN"
        logging.warning(
            "Restarting streaming: resend queue size: %d, "
            "starting from seq_num=%s", len(resend_list), restart_seq_num)
        channel = make_insecure_channel(
            self._remote_address, ChannelType.REMOTE,
            options=self._grpc_options, compression=self._compression)
        client = make_ready_client(channel, stop_event)
        self._check_remote_heartbeat(client)
def _client_daemon_fn(self):
    """Keep a ``StreamTransmit`` stream to the remote alive until shutdown.

    Messages are pulled from ``self._transmit_queue`` and streamed to the
    remote. Each response's ``next_seq_num`` is passed to the queue's
    ``confirm`` (SUCCESS / DUPLICATED) or ``resend`` (MISSING) method. When
    the stream breaks, the channel is closed and, after a one-second pause,
    ``resend(-1)`` is called and the channel rebuilt. The loop ends once the
    function installed as ``self._client_daemon_shutdown_fn`` has been
    called.

    NOTE(review): the exact semantics of the queue's ``confirm``/``resend``
    methods are defined elsewhere; confirm against that implementation.
    """
    stop_event = threading.Event()
    # Read by shutdown_fn through the closure, so it cancels the current
    # stream. Stays None until the first StreamTransmit call returns.
    generator = None
    channel = make_insecure_channel(self._remote_address,
                                    ChannelType.REMOTE,
                                    options=self._grpc_options,
                                    compression=self._compression)
    client = make_ready_client(channel, stop_event)

    def shutdown_fn():
        # Poll until the transmit queue reports size 0, then stop the daemon
        # and cancel the live stream.
        while self._transmit_queue.size():
            logging.debug(
                "Waiting for message queue's being cleaned. "
                "Queue size: %d", self._transmit_queue.size())
            time.sleep(1)
        stop_event.set()
        if generator is not None:
            generator.cancel()

    self._client_daemon_shutdown_fn = shutdown_fn

    while not stop_event.is_set():
        try:
            def iterator():
                while True:
                    item = self._transmit_queue.get()
                    logging.debug("Streaming send message seq_num=%d",
                                  item.seq_num)
                    yield item

            generator = client.StreamTransmit(iterator())
            for response in generator:
                if response.status.code == common_pb.STATUS_SUCCESS:
                    self._transmit_queue.confirm(response.next_seq_num)
                    logging.debug(
                        "Message with seq_num=%d is "
                        "confirmed", response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_DUPLICATED:
                    self._transmit_queue.confirm(response.next_seq_num)
                    logging.debug(
                        "Resent Message with seq_num=%d is "
                        "confirmed", response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_MISSING:
                    self._transmit_queue.resend(response.next_seq_num)
                else:
                    raise RuntimeError("Trainsmit failed with %d" %
                                       response.status.code)
        except Exception as e:  # pylint: disable=broad-except
            if not stop_event.is_set():
                logging.warning("Bridge streaming broken: %s.", repr(e))
                metrics.emit_counter('reconnect_counter', 1)
        finally:
            # Guard against None: an AttributeError here would escape the
            # loop and kill the daemon thread.
            if generator is not None:
                generator.cancel()
            channel.close()

        # On shutdown, do not build a fresh channel that would be leaked.
        if stop_event.is_set():
            break

        time.sleep(1)
        self._transmit_queue.resend(-1)
        channel = make_insecure_channel(self._remote_address,
                                        ChannelType.REMOTE,
                                        options=self._grpc_options,
                                        compression=self._compression)
        client = make_ready_client(channel, stop_event)
        self._check_remote_heartbeat(client)