Exemplo n.º 1
0
    def _transmit_handler(self, request):
        """Apply one message from the peer's transmit stream, in order.

        Each message carries ``seq_num``. Messages are handled as follows:

        * ``seq_num`` below ``self._next_receive_seq_num``: treated as an
          already-applied duplicate and acknowledged without being applied
          again.
        * ``seq_num`` above the expected value: rejected.
        * ``seq_num`` equal to the expected value: advances the counter and
          is dispatched on its payload field.

        Payload fields of the expected message:

        * ``start``    -- open an empty buffer for ``start.iter_id``.
        * ``commit``   -- no-op in this handler.
        * ``data``     -- decode ``data.tensor`` with ``tf.make_ndarray``,
          store it under ``(data.iter_id, data.name)``, and wake threads
          waiting on ``self._condition``.
        * ``prefetch`` -- pass ``request.prefetch`` to every callback in
          ``self._prefetch_handlers``.

        Args:
            request: incoming protobuf request message (presumably the
                trainer/worker request type -- confirm against the caller).

        Returns:
            A ``tws_pb.TrainerWorkerResponse`` whose ``next_seq_num`` is
            the next sequence number this side expects. Its status is
            ``STATUS_INVALID_REQUEST`` when no recognized payload field
            is set.

        Raises:
            AssertionError: if called before connecting, if ``seq_num``
                skips ahead of the expected value, or if a ``data`` message
                references an ``iter_id`` with no preceding ``start``.
        """
        # Explicit checks instead of ``assert`` statements so validation is
        # not stripped under ``python -O``; AssertionError is kept so callers
        # observe the same exception type as before.
        if not self._connected:
            raise AssertionError("Cannot transmit before connect")

        # NOTE(review): the sequence counter is read and updated without a
        # lock; this assumes the caller serializes invocations -- confirm.
        if request.seq_num < self._next_receive_seq_num:
            # Already applied (presumably a retransmission): acknowledge only.
            return tws_pb.TrainerWorkerResponse(
                next_seq_num=self._next_receive_seq_num)
        if request.seq_num != self._next_receive_seq_num:
            # A gap in the stream: an earlier message was never applied.
            raise AssertionError("Invalid request")

        # The counter is consumed before the payload is validated.
        self._next_receive_seq_num += 1

        if request.HasField('start'):
            with self._condition:
                self._received_data[request.start.iter_id] = {}
        elif request.HasField('commit'):
            pass
        elif request.HasField('data'):
            data = request.data
            with self._condition:
                if data.iter_id not in self._received_data:
                    raise AssertionError(
                        "Data for unknown iter_id {}".format(data.iter_id))
                self._received_data[data.iter_id][data.name] = \
                    tf.make_ndarray(data.tensor)
                # notifyAll is a deprecated alias of notify_all.
                self._condition.notify_all()
        elif request.HasField('prefetch'):
            for func in self._prefetch_handlers:
                func(request.prefetch)
        else:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_INVALID_REQUEST),
                next_seq_num=self._next_receive_seq_num)

        return tws_pb.TrainerWorkerResponse(
            next_seq_num=self._next_receive_seq_num)
Exemplo n.º 2
0
    def _transmit_handler(self, request):
        """Apply one message from the peer's transmit stream, in order.

        The whole exchange runs under ``self._transmit_receive_lock``.
        Keepalive messages are acknowledged with ``next_seq_num=-1`` and do
        not touch the sequence counter. Every other message must carry
        ``seq_num == self._next_receive_seq_num``:

        * a larger value is answered with ``STATUS_MESSAGE_MISSING``;
        * a smaller value is answered with ``STATUS_MESSAGE_DUPLICATED``.

        Both rejections report the expected sequence number (presumably so
        the sender can resynchronize).

        The expected message advances the counter and is dispatched on its
        payload field:

        * ``start``    -- open an empty buffer for ``start.iter_id``.
        * ``commit``   -- no-op in this handler.
        * ``data``     -- store the ``request.data`` message itself under
          ``(data.iter_id, data.name)`` and wake threads waiting on
          ``self._condition``.
        * ``prefetch`` -- pass ``request.prefetch`` to every callback in
          ``self._prefetch_handlers``.

        Args:
            request: incoming protobuf request message (presumably the
                trainer/worker request type -- confirm against the caller).

        Returns:
            A ``tws_pb.TrainerWorkerResponse`` carrying the next expected
            sequence number, or ``-1`` for keepalives. Its status is
            ``STATUS_INVALID_REQUEST`` when no recognized payload field
            is set.

        Raises:
            AssertionError: if called before connecting, or if a ``data``
                message references an ``iter_id`` with no preceding
                ``start``. In the latter case the sequence number has
                already been consumed.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # AssertionError is kept so callers see the same exception type.
        if not self._connected:
            raise AssertionError("Cannot transmit before connect")

        with self._transmit_receive_lock:
            if request.HasField('keepalive'):
                # keep alive message, do nothing
                return tws_pb.TrainerWorkerResponse(next_seq_num=-1)

            logging.debug("Received message seq_num=%d."
                          " Wanted seq_num=%d.",
                          request.seq_num, self._next_receive_seq_num)
            if request.seq_num > self._next_receive_seq_num:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_MESSAGE_MISSING),
                    next_seq_num=self._next_receive_seq_num)
            if request.seq_num < self._next_receive_seq_num:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_MESSAGE_DUPLICATED),
                    next_seq_num=self._next_receive_seq_num)

            # request.seq_num == self._next_receive_seq_num
            self._next_receive_seq_num += 1

            if request.HasField('start'):
                with self._condition:
                    self._received_data[request.start.iter_id] = {}
            elif request.HasField('commit'):
                pass
            elif request.HasField('data'):
                data = request.data
                with self._condition:
                    # Explicit check: under -O the old assert degraded into
                    # a KeyError from the dict lookup below.
                    if data.iter_id not in self._received_data:
                        raise AssertionError(
                            "Data for unknown iter_id {}".format(
                                data.iter_id))
                    self._received_data[data.iter_id][data.name] = data
                    # notifyAll is a deprecated alias of notify_all.
                    self._condition.notify_all()
            elif request.HasField('prefetch'):
                for func in self._prefetch_handlers:
                    func(request.prefetch)
            else:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_INVALID_REQUEST),
                    next_seq_num=self._next_receive_seq_num)

            return tws_pb.TrainerWorkerResponse(
                next_seq_num=self._next_receive_seq_num)