Example #1
0
 def _data_block_handler(self, request):
     """Pass an incoming data block request to the registered handler.

     Returns a ``common_pb.Status`` whose code is ``STATUS_SUCCESS`` when
     the handler returns a truthy value, and ``STATUS_INVALID_DATA_BLOCK``
     otherwise.

     Raises:
         AssertionError: if ``self._connected`` is falsy.
         RuntimeError: if no data block handler has been registered.
     """
     assert self._connected, "Cannot load data before connect"
     handler = self._data_block_handler_fn
     if not handler:
         raise RuntimeError(
             "Received DataBlockMessage but no handler registered")

     # Every request is counted; failures get an additional counter.
     metrics.emit_counter('load_data_block_counter', 1)
     if not handler(request):
         metrics.emit_counter('load_data_block_fail_counter', 1)
         return common_pb.Status(
             code=common_pb.STATUS_INVALID_DATA_BLOCK)
     return common_pb.Status(code=common_pb.STATUS_SUCCESS)
Example #2
0
 def iterator():
     """Generate outgoing messages for a stream.

     First replays a snapshot of ``resend_list``, then forwards new
     messages taken from ``self._transmit_queue`` indefinitely. Each new
     message is appended to ``resend_list`` before it is yielded.
     """
     # Copy under the lock so the replay iterates a stable snapshot even
     # if the shared list is modified concurrently.
     with lock:
         pending = list(resend_list)

     for old_msg in pending:
         logging.warning("Streaming resend message seq_num=%d",
                         old_msg.seq_num)
         metrics.emit_counter("resend_counter", 1)
         yield old_msg

     while True:
         new_msg = self._transmit_queue.get()
         # Record the message before handing it out.
         with lock:
             resend_list.append(new_msg)
         logging.debug("Streaming send message seq_num=%d",
                       new_msg.seq_num)
         yield new_msg
Example #3
0
    def _transmit(self, msg):
        """Stamp ``msg`` with the next send sequence number and send it.

        In streaming mode the message is only put on
        ``self._transmit_queue``. Otherwise it is sent with a unary
        ``Transmit`` call through ``self._rpc_with_retry``.

        The sequence number is assigned and the message is queued or sent
        while holding ``self._transmit_send_lock``.

        Args:
            msg: message object with a writable ``seq_num`` attribute.

        Raises:
            AssertionError: if ``self._connected`` is falsy.
        """
        assert self._connected, "Cannot transmit before connect"
        metrics.emit_counter('send_counter', 1)
        with self._transmit_send_lock:
            msg.seq_num = self._next_send_seq_num
            self._next_send_seq_num += 1

            if self._streaming_mode:
                self._transmit_queue.put(msg)
                return

            def sender():
                rsp = self._client.Transmit(msg)
                # Raise explicitly instead of using ``assert``: assert
                # statements are removed under ``python -O``, which would
                # make a failed transmit look like a success.
                if rsp.status.code != common_pb.STATUS_SUCCESS:
                    raise RuntimeError(
                        "Transmit error with code %d." % rsp.status.code)
            self._rpc_with_retry(sender, "Bridge transmit failed")
Example #4
0
 def _rpc_with_retry(self, sender, err_log):
     """Call ``sender()`` until it completes without raising.

     Each attempt runs while holding ``self._client_lock``. When an
     attempt raises any ``Exception``, a warning prefixed with ``err_log``
     is logged, ``reconnect_counter`` is emitted, and the channel and
     client are rebuilt before the next attempt. Exceptions raised while
     rebuilding are not caught here.

     Args:
         sender: zero-argument callable performing the RPC.
         err_log: text used as the prefix of the warning message.

     Returns:
         Whatever ``sender()`` returns.
     """
     def rebuild_connection():
         # Drop the old channel, wait a second, then create a new channel
         # and client and check the remote heartbeat.
         self._channel.close()
         time.sleep(1)
         self._channel = make_insecure_channel(
             self._remote_address,
             ChannelType.REMOTE,
             options=self._grpc_options,
             compression=self._compression)
         self._client = make_ready_client(self._channel)
         self._check_remote_heartbeat(self._client)

     while True:
         with self._client_lock:
             try:
                 result = sender()
             except Exception as err:  # pylint: disable=broad-except
                 logging.warning(
                     "%s: %s. Retry in 1s...", err_log, repr(err))
                 metrics.emit_counter('reconnect_counter', 1)
                 rebuild_connection()
             else:
                 return result
Example #5
0
    def _transmit_handler(self, request):
        """Handle one incoming transmit request, enforcing in-order delivery.

        A request whose ``seq_num`` is ahead of the expected one is rejected
        with ``STATUS_MESSAGE_MISSING``; one that is behind is rejected with
        ``STATUS_MESSAGE_DUPLICATED``. An in-order request advances the
        expected sequence number and is dispatched on its payload field:

        * ``start``: creates an empty entry for ``start.iter_id``.
        * ``commit``: nothing to do.
        * ``data``: stored under ``[data.iter_id][data.name]`` and all
          waiters on ``self._condition`` are woken.
        * ``prefetch``: passed to every registered prefetch handler.
        * anything else: answered with ``STATUS_INVALID_REQUEST``.

        Returns:
            A ``tws_pb.TrainerWorkerResponse`` whose ``next_seq_num`` is the
            next expected sequence number. ``status`` is only set on the
            rejection paths.

        Raises:
            AssertionError: if ``self._connected`` is falsy, or if a ``data``
                payload refers to an ``iter_id`` with no stored entry.
        """
        assert self._connected, "Cannot transmit before connect"
        metrics.emit_counter('receive_counter', 1)
        with self._transmit_receive_lock:
            logging.debug("Received message seq_num=%d."
                          " Wanted seq_num=%d.",
                          request.seq_num, self._next_receive_seq_num)
            if request.seq_num > self._next_receive_seq_num:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_MESSAGE_MISSING),
                    next_seq_num=self._next_receive_seq_num)
            if request.seq_num < self._next_receive_seq_num:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_MESSAGE_DUPLICATED),
                    next_seq_num=self._next_receive_seq_num)

            # request.seq_num == self._next_receive_seq_num
            self._next_receive_seq_num += 1

            if request.HasField('start'):
                with self._condition:
                    self._received_data[request.start.iter_id] = {}
            elif request.HasField('commit'):
                pass
            elif request.HasField('data'):
                with self._condition:
                    assert request.data.iter_id in self._received_data
                    self._received_data[
                        request.data.iter_id][
                            request.data.name] = request.data
                    # notify_all() is the supported spelling; the
                    # notifyAll() alias is deprecated since Python 3.10.
                    # NOTE(review): assumes self._condition is a
                    # threading.Condition -- confirm.
                    self._condition.notify_all()
            elif request.HasField('prefetch'):
                for func in self._prefetch_handlers:
                    func(request.prefetch)
            else:
                # The expected sequence number was already advanced above,
                # so the unrecognized request is still consumed.
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_INVALID_REQUEST),
                    next_seq_num=self._next_receive_seq_num)

            return tws_pb.TrainerWorkerResponse(
                next_seq_num=self._next_receive_seq_num)
Example #6
0
    def _client_daemon_fn(self):
        """Run the streaming send loop until a shutdown is requested.

        Opens a ``StreamTransmit`` stream fed from ``self._transmit_queue``
        and keeps every message that has been handed to the stream in a
        local resend list until a response shows it is no longer needed.
        When the stream ends or fails, the channel is closed and rebuilt
        after one second, and the next stream starts by replaying the
        resend list.

        A shutdown callback is published as
        ``self._client_daemon_shutdown_fn``. It blocks until both the
        resend list and the transmit queue are empty, then sets the stop
        event and cancels the current stream.
        """
        stop_event = threading.Event()
        generator = None
        channel = make_insecure_channel(
            self._remote_address, ChannelType.REMOTE,
            options=self._grpc_options, compression=self._compression)
        client = make_ready_client(channel, stop_event)

        # Guards resend_list, which is shared between this loop, the
        # request iterator and shutdown_fn.
        lock = threading.Lock()
        resend_list = collections.deque()

        def shutdown_fn():
            # Poll once per second until nothing is pending. The lock is
            # held only while inspecting state, never across the sleep, so
            # an interruption while sleeping cannot leave it in a bad state.
            while True:
                with lock:
                    pending = len(resend_list)
                    drained = pending == 0 and self._transmit_queue.empty()
                if drained:
                    break
                logging.debug(
                    "Waiting for resend queue's being cleaned. "
                    "Resend queue size: %d", pending)
                time.sleep(1)

            stop_event.set()
            if generator is not None:
                generator.cancel()

        self._client_daemon_shutdown_fn = shutdown_fn

        while not stop_event.is_set():
            try:
                def iterator():
                    # Replay unconfirmed messages first, from a snapshot
                    # taken under the lock.
                    with lock:
                        resend_msgs = list(resend_list)
                    for item in resend_msgs:
                        logging.warning("Streaming resend message seq_num=%d",
                                        item.seq_num)
                        metrics.emit_counter("resend_counter", 1)
                        yield item
                    # Then forward new messages, recording each one before
                    # it is yielded.
                    while True:
                        item = self._transmit_queue.get()
                        with lock:
                            resend_list.append(item)
                        logging.debug("Streaming send message seq_num=%d",
                                      item.seq_num)
                        yield item

                generator = client.StreamTransmit(iterator())
                for response in generator:
                    if response.status.code == common_pb.STATUS_SUCCESS:
                        logging.debug("Message with seq_num=%d is "
                            "confirmed", response.next_seq_num-1)
                    elif response.status.code == \
                        common_pb.STATUS_MESSAGE_DUPLICATED:
                        logging.debug("Resent Message with seq_num=%d is "
                            "confirmed", response.next_seq_num)
                    elif response.status.code == \
                        common_pb.STATUS_MESSAGE_MISSING:
                        raise RuntimeError("Message with seq_num=%d is "
                            "missing!" % (response.next_seq_num))
                    else:
                        raise RuntimeError("Trainsmit failed with %d" %
                                           response.status.code)
                    # Drop every message below the sequence number the
                    # remote side expects next.
                    with lock:
                        while resend_list and \
                                resend_list[0].seq_num < response.next_seq_num:
                            resend_list.popleft()
                        min_seq_num_to_resend = resend_list[0].seq_num \
                            if resend_list else "NaN"
                        logging.debug(
                            "Resend queue size: %d, starting from seq_num=%s",
                            len(resend_list), min_seq_num_to_resend)
            except Exception as e:  # pylint: disable=broad-except
                if not stop_event.is_set():
                    logging.warning("Bridge streaming broken: %s.", repr(e))
                    metrics.emit_counter('reconnect_counter', 1)
            finally:
                # generator is still None if StreamTransmit itself raised
                # before the first stream was created.
                if generator is not None:
                    generator.cancel()
                channel.close()
                time.sleep(1)
                # Read the resend list under the lock, and use a conditional
                # expression so a head seq_num of 0 is reported as 0 rather
                # than falling through to "NaN".
                with lock:
                    resend_size = len(resend_list)
                    first_seq_num = resend_list[0].seq_num \
                        if resend_list else "NaN"
                logging.warning(
                    "Restarting streaming: resend queue size: %d, "
                    "starting from seq_num=%s", resend_size, first_seq_num)
                channel = make_insecure_channel(
                    self._remote_address, ChannelType.REMOTE,
                    options=self._grpc_options, compression=self._compression)
                client = make_ready_client(channel, stop_event)
                self._check_remote_heartbeat(client)
Example #7
0
    def _client_daemon_fn(self):
        """Run the streaming send loop until a shutdown is requested.

        Opens a ``StreamTransmit`` stream fed from ``self._transmit_queue``
        and reports each response back to that queue through ``confirm``
        or ``resend``. When the stream ends or fails, the channel is closed
        and rebuilt after one second, and ``resend(-1)`` is called on the
        queue before the next stream starts.

        A shutdown callback is published as
        ``self._client_daemon_shutdown_fn``. It blocks until
        ``self._transmit_queue.size()`` is zero, then sets the stop event
        and cancels the current stream.
        """
        stop_event = threading.Event()
        generator = None
        channel = make_insecure_channel(self._remote_address,
                                        ChannelType.REMOTE,
                                        options=self._grpc_options,
                                        compression=self._compression)
        client = make_ready_client(channel, stop_event)

        def shutdown_fn():
            # Poll once per second until the queue reports size zero.
            while self._transmit_queue.size():
                logging.debug(
                    "Waiting for message queue's being cleaned. "
                    "Queue size: %d", self._transmit_queue.size())
                time.sleep(1)

            stop_event.set()
            if generator is not None:
                generator.cancel()

        self._client_daemon_shutdown_fn = shutdown_fn

        while not stop_event.is_set():
            try:

                def iterator():
                    # Feed the stream from the transmit queue indefinitely.
                    while True:
                        item = self._transmit_queue.get()
                        logging.debug("Streaming send message seq_num=%d",
                                      item.seq_num)
                        yield item

                generator = client.StreamTransmit(iterator())
                for response in generator:
                    if response.status.code == common_pb.STATUS_SUCCESS:
                        self._transmit_queue.confirm(response.next_seq_num)
                        logging.debug(
                            "Message with seq_num=%d is "
                            "confirmed", response.next_seq_num - 1)
                    elif response.status.code == \
                            common_pb.STATUS_MESSAGE_DUPLICATED:
                        self._transmit_queue.confirm(response.next_seq_num)
                        logging.debug(
                            "Resent Message with seq_num=%d is "
                            "confirmed", response.next_seq_num - 1)
                    elif response.status.code == \
                            common_pb.STATUS_MESSAGE_MISSING:
                        self._transmit_queue.resend(response.next_seq_num)
                    else:
                        raise RuntimeError("Trainsmit failed with %d" %
                                           response.status.code)
            except Exception as e:  # pylint: disable=broad-except
                if not stop_event.is_set():
                    logging.warning("Bridge streaming broken: %s.", repr(e))
                    metrics.emit_counter('reconnect_counter', 1)
            finally:
                # generator is still None if StreamTransmit itself raised
                # before the first stream was created.
                if generator is not None:
                    generator.cancel()
                channel.close()
                time.sleep(1)
                # NOTE(review): presumably rewinds the queue so unconfirmed
                # messages are sent again on the new stream -- confirm
                # against the queue implementation.
                self._transmit_queue.resend(-1)
                channel = make_insecure_channel(self._remote_address,
                                                ChannelType.REMOTE,
                                                options=self._grpc_options,
                                                compression=self._compression)
                client = make_ready_client(channel, stop_event)
                self._check_remote_heartbeat(client)