Ejemplo n.º 1
0
 def send_proto(self, iter_id, name, proto):
     """Pack ``proto`` into a ``google.protobuf.Any`` and transmit it.

     The packed message is carried in the ``any_data`` field of a
     ``DataMessage`` tagged with ``iter_id`` and ``name``. After
     transmission, a debug line including ``seq_num`` is logged. The
     ``seq_num`` is presumably filled in by ``self._transmit``; that is
     not visible here.

     Args:
         iter_id: Iteration the data belongs to (logged with ``%d``).
         name: Label identifying the payload.
         proto: Any protobuf message; it is wrapped with ``Any.Pack``.
     """
     packed = google.protobuf.any_pb2.Any()
     packed.Pack(proto)
     payload = tws_pb.DataMessage(
         iter_id=iter_id, name=name, any_data=packed)
     outgoing = tws_pb.TrainerWorkerMessage(data=payload)
     self._transmit(outgoing)
     logging.debug('Data: send protobuf %s for iter %d. seq_num=%d.', name,
                   iter_id, outgoing.seq_num)
Ejemplo n.º 2
0
    def start(self, iter_id):
        """Mark ``iter_id`` as the current iteration and transmit a start message.

        Args:
            iter_id: Identifier of the iteration being started.

        Raises:
            AssertionError: if a previous iteration is still recorded as
                current, i.e. it was never committed. This is an ``assert``,
                so it only fires when assertions are enabled.
        """
        assert self._current_iter_id is None, "Last iter not finished"
        self._current_iter_id = iter_id

        start_msg = tws_pb.StartMessage(iter_id=iter_id)
        self._transmit(tws_pb.TrainerWorkerMessage(start=start_msg))
        logging.debug("Starting iter %d", iter_id)
Ejemplo n.º 3
0
    def commit(self):
        """Finish the current iteration and transmit a commit message.

        Under ``self._condition``, this clears ``self._current_iter_id`` and
        discards any entry for that iteration in ``self._received_data``.
        It then transmits a ``CommitMessage`` for the finished iteration.

        Raises:
            AssertionError: if no iteration is in progress. This is an
                ``assert``, so it only fires when assertions are enabled.
        """
        with self._condition:
            # Validate and read the iteration id inside the same critical
            # section that clears it, so the check cannot go stale before use.
            assert self._current_iter_id is not None, "Not started yet"
            last_iter_id = self._current_iter_id
            self._current_iter_id = None
            # pop() with a default replaces the membership test + del pair
            # (single lookup, no KeyError when nothing was buffered).
            self._received_data.pop(last_iter_id, None)

        msg = tws_pb.TrainerWorkerMessage(commit=tws_pb.CommitMessage(
            iter_id=last_iter_id))
        self._transmit(msg)
        logging.debug("iter %d committed", last_iter_id)
Ejemplo n.º 4
0
 def _keepalive_daemon_fn(self):
     """Enqueue a keep-alive message about once per second until shut down.

     Installs ``self._keepalive_daemon_shutdown_fn``. Calling it sets an
     internal event that ends the loop. Presumably this runs as the target
     of a background thread; confirm with the caller.

     Raises:
         AssertionError: if not in streaming mode or not connected (only
             when assertions are enabled).
     """
     stop_event = threading.Event()

     def shutdown_fn():
         stop_event.set()

     self._keepalive_daemon_shutdown_fn = shutdown_fn
     assert self._streaming_mode and self._connected
     while not stop_event.is_set():
         msg = tws_pb.TrainerWorkerMessage(
             seq_num=-1,  # not used for keep alive message
             keepalive=tws_pb.KeepAliveMessage(),
         )
         # Enqueue under the send lock, as before, so the keep-alive is
         # serialized with other users of that lock.
         with self._transmit_send_lock:
             self._transmit_queue.put(msg)
         # Event.wait returns immediately once shutdown_fn() sets the event,
         # unlike time.sleep(1), which could delay shutdown by up to 1s.
         stop_event.wait(1)
Ejemplo n.º 5
0
 def send(self, iter_id, name, x):
     """Convert ``x`` to a ``TensorProto`` and transmit it as data.

     The tensor goes in the ``tensor`` field of a ``DataMessage`` tagged
     with ``iter_id`` and ``name``. A debug line with the message's
     ``seq_num`` is logged after transmission. That ``seq_num`` is
     presumably assigned by ``self._transmit``; this is not visible here.

     Args:
         iter_id: Iteration the data belongs to (logged with ``%d``).
         name: Label identifying the payload.
         x: Any value accepted by ``tf.make_tensor_proto``.
     """
     tensor_proto = tf.make_tensor_proto(x)
     payload = tws_pb.DataMessage(
         iter_id=iter_id, name=name, tensor=tensor_proto)
     outgoing = tws_pb.TrainerWorkerMessage(data=payload)
     self._transmit(outgoing)
     logging.debug('Data: send %s for iter %d. seq_num=%d.', name, iter_id,
                   outgoing.seq_num)
Ejemplo n.º 6
0
 def prefetch(self, iter_id, sample_ids):
     """Transmit a prefetch request for ``sample_ids`` of iteration ``iter_id``.

     Args:
         iter_id: Iteration the samples belong to.
         sample_ids: Presumably a sequence of sample identifiers accepted
             by the repeated ``PrefetchMessage.sample_ids`` field. Confirm
             the element type against the proto definition.
     """
     request = tws_pb.PrefetchMessage(iter_id=iter_id, sample_ids=sample_ids)
     self._transmit(tws_pb.TrainerWorkerMessage(prefetch=request))
Ejemplo n.º 7
0
def fake_start_message(seq_num, iter_id):
    """Build a ``TrainerWorkerMessage`` carrying a ``StartMessage``.

    The message's ``seq_num`` is set explicitly from the argument. Judging
    by the name, this fabricates start messages (e.g. for tests); confirm
    with callers.

    Args:
        seq_num: Sequence number to stamp on the message.
        iter_id: Iteration id placed in the embedded ``StartMessage``.

    Returns:
        The constructed ``tws_pb.TrainerWorkerMessage``.
    """
    start_msg = tws_pb.StartMessage(iter_id=iter_id)
    return tws_pb.TrainerWorkerMessage(seq_num=seq_num, start=start_msg)