def _transmit_handler(self, request):
    """Handle one incoming transmit request and build the response.

    A request whose ``seq_num`` is below ``self._next_receive_seq_num`` is
    not processed again; it is simply answered with the current
    ``next_seq_num``. A request at the expected sequence number advances
    the counter and is dispatched on whichever of the ``start``,
    ``commit``, ``data`` or ``prefetch`` fields is set.

    Args:
        request: message exposing ``seq_num`` and ``HasField``.

    Returns:
        A ``tws_pb.TrainerWorkerResponse`` carrying ``next_seq_num``. When
        none of the known fields is set, the response also carries a
        ``STATUS_INVALID_REQUEST`` status.

    Raises:
        AssertionError: if the bridge is not connected, if ``seq_num`` is
            ahead of the expected one, or if a ``data`` message refers to
            an ``iter_id`` that was never started.
    """
    assert self._connected, "Cannot transmit before connect"

    if request.seq_num >= self._next_receive_seq_num:
        assert request.seq_num == self._next_receive_seq_num, \
            "Invalid request"
        # The counter is advanced before dispatch, so even a request that
        # turns out to be invalid consumes its sequence number.
        self._next_receive_seq_num += 1

        if request.HasField('start'):
            with self._condition:
                # Open an empty slot for this iteration's data.
                self._received_data[request.start.iter_id] = {}
        elif request.HasField('commit'):
            pass
        elif request.HasField('data'):
            with self._condition:
                assert request.data.iter_id in self._received_data
                iter_data = self._received_data[request.data.iter_id]
                iter_data[request.data.name] = \
                    tf.make_ndarray(request.data.tensor)
                # notify_all() replaces the deprecated notifyAll() alias.
                # NOTE(review): assumes self._condition is a
                # threading.Condition -- confirm.
                self._condition.notify_all()
        elif request.HasField('prefetch'):
            for func in self._prefetch_handlers:
                func(request.prefetch)
        else:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_INVALID_REQUEST),
                next_seq_num=self._next_receive_seq_num)

    return tws_pb.TrainerWorkerResponse(
        next_seq_num=self._next_receive_seq_num)
def _transmit_handler(self, request):
    """Handle one incoming transmit request and build the response.

    All work is done while holding ``self._transmit_receive_lock``.
    Keepalive messages are answered immediately with ``next_seq_num=-1``
    and do not touch the sequence counter. Otherwise the request's
    ``seq_num`` is compared with ``self._next_receive_seq_num``:

    * greater: answered with ``STATUS_MESSAGE_MISSING``;
    * smaller: answered with ``STATUS_MESSAGE_DUPLICATED``;
    * equal: the counter is advanced and the request is dispatched on
      whichever of ``start``, ``commit``, ``data`` or ``prefetch`` is set.

    Args:
        request: message exposing ``seq_num`` and ``HasField``.

    Returns:
        A ``tws_pb.TrainerWorkerResponse`` carrying ``next_seq_num`` and,
        for rejected requests, a status code
        (``STATUS_MESSAGE_MISSING``, ``STATUS_MESSAGE_DUPLICATED`` or
        ``STATUS_INVALID_REQUEST``).

    Raises:
        AssertionError: if the bridge is not connected, or if a ``data``
            message refers to an ``iter_id`` that was never started.
    """
    assert self._connected, "Cannot transmit before connect"

    with self._transmit_receive_lock:
        if request.HasField('keepalive'):
            # keep alive message, do nothing
            return tws_pb.TrainerWorkerResponse(next_seq_num=-1)

        logging.debug("Received message seq_num=%d."
                      " Wanted seq_num=%d.",
                      request.seq_num, self._next_receive_seq_num)

        if request.seq_num > self._next_receive_seq_num:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_MESSAGE_MISSING),
                next_seq_num=self._next_receive_seq_num)
        if request.seq_num < self._next_receive_seq_num:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_MESSAGE_DUPLICATED),
                next_seq_num=self._next_receive_seq_num)

        # request.seq_num == self._next_receive_seq_num
        # The counter is advanced before dispatch, so even a request that
        # turns out to be invalid consumes its sequence number.
        self._next_receive_seq_num += 1

        if request.HasField('start'):
            with self._condition:
                # Open an empty slot for this iteration's data.
                self._received_data[request.start.iter_id] = {}
        elif request.HasField('commit'):
            pass
        elif request.HasField('data'):
            with self._condition:
                assert request.data.iter_id in self._received_data
                iter_data = self._received_data[request.data.iter_id]
                # The raw data message is stored as-is.
                iter_data[request.data.name] = request.data
                # notify_all() replaces the deprecated notifyAll() alias.
                # NOTE(review): assumes self._condition is a
                # threading.Condition -- confirm.
                self._condition.notify_all()
        elif request.HasField('prefetch'):
            for func in self._prefetch_handlers:
                func(request.prefetch)
        else:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_INVALID_REQUEST),
                next_seq_num=self._next_receive_seq_num)

        return tws_pb.TrainerWorkerResponse(
            next_seq_num=self._next_receive_seq_num)