def send_proto(self, iter_id, name, proto):
    """Pack a protobuf message into an ``Any`` and transmit it as data.

    Args:
        iter_id: Iteration id stored in the outgoing ``DataMessage``.
        name: Name stored in the outgoing ``DataMessage``.
        proto: Protobuf message to wrap via ``Any.Pack``.
    """
    packed = google.protobuf.any_pb2.Any()
    packed.Pack(proto)

    payload = tws_pb.DataMessage(
        iter_id=iter_id,
        name=name,
        any_data=packed,
    )
    envelope = tws_pb.TrainerWorkerMessage(data=payload)
    self._transmit(envelope)

    # seq_num is read only after _transmit returns; presumably _transmit
    # is what assigns it -- confirm against its implementation.
    logging.debug('Data: send protobuf %s for iter %d. seq_num=%d.',
                  name, iter_id, envelope.seq_num)
def start(self, iter_id):
    """Mark ``iter_id`` as the current iteration and send a start message.

    Args:
        iter_id: Id of the iteration being opened.

    Raises:
        AssertionError: If ``self._current_iter_id`` is not ``None``,
            i.e. a previously started iteration has not been cleared.
    """
    assert self._current_iter_id is None, "Last iter not finished"
    self._current_iter_id = iter_id

    start_payload = tws_pb.StartMessage(iter_id=iter_id)
    self._transmit(tws_pb.TrainerWorkerMessage(start=start_payload))
    logging.debug("Starting iter %d", iter_id)
def commit(self):
    """Close the current iteration and send a commit message for it.

    Under ``self._condition`` the current iteration id is reset to
    ``None`` and any entry stored for it in ``self._received_data`` is
    discarded.

    Raises:
        AssertionError: If no iteration is currently open.
    """
    assert self._current_iter_id is not None, "Not started yet"

    with self._condition:
        finished_iter = self._current_iter_id
        self._current_iter_id = None
        # Discard whatever was stored for this iteration, if anything.
        self._received_data.pop(finished_iter, None)

    # The commit is sent after the condition is released, so the lock is
    # not held for the duration of the transmit.
    commit_payload = tws_pb.CommitMessage(iter_id=finished_iter)
    self._transmit(tws_pb.TrainerWorkerMessage(commit=commit_payload))
    logging.debug("iter %d committed", finished_iter)
def _keepalive_daemon_fn(self): stop_event = threading.Event() def shutdown_fn(): stop_event.set() return self._keepalive_daemon_shutdown_fn = shutdown_fn assert self._streaming_mode and self._connected while not stop_event.is_set(): with self._transmit_send_lock: msg = tws_pb.TrainerWorkerMessage( seq_num=-1, # not used for keep alive message keepalive=tws_pb.KeepAliveMessage() ) self._transmit_queue.put(msg) time.sleep(1)
def send(self, iter_id, name, x):
    """Convert ``x`` to a tensor proto and transmit it as a data message.

    Args:
        iter_id: Iteration id stored in the outgoing ``DataMessage``.
        name: Name stored in the outgoing ``DataMessage``.
        x: Value handed to ``tf.make_tensor_proto``.
    """
    tensor_proto = tf.make_tensor_proto(x)
    payload = tws_pb.DataMessage(
        iter_id=iter_id,
        name=name,
        tensor=tensor_proto,
    )
    envelope = tws_pb.TrainerWorkerMessage(data=payload)
    self._transmit(envelope)

    # seq_num is read only after _transmit returns; presumably _transmit
    # is what assigns it -- confirm against its implementation.
    logging.debug('Data: send %s for iter %d. seq_num=%d.',
                  name, iter_id, envelope.seq_num)
def prefetch(self, iter_id, sample_ids):
    """Transmit a prefetch message for the given iteration.

    Args:
        iter_id: Iteration id stored in the ``PrefetchMessage``.
        sample_ids: Sample ids stored in the ``PrefetchMessage``.
    """
    prefetch_payload = tws_pb.PrefetchMessage(
        iter_id=iter_id,
        sample_ids=sample_ids,
    )
    self._transmit(tws_pb.TrainerWorkerMessage(prefetch=prefetch_payload))
def fake_start_message(seq_num, iter_id):
    """Build a ``TrainerWorkerMessage`` wrapping a ``StartMessage``.

    Args:
        seq_num: Sequence number set on the outer message.
        iter_id: Iteration id set on the inner ``StartMessage``.

    Returns:
        The constructed ``tws_pb.TrainerWorkerMessage``.
    """
    start_payload = tws_pb.StartMessage(iter_id=iter_id)
    return tws_pb.TrainerWorkerMessage(seq_num=seq_num, start=start_payload)