def _check_token(self, token):
    """Return True iff *token* matches this channel's configured token.

    A mismatched token is logged at debug level and rejected.
    """
    if token == self._token:
        return True
    fl_logging.debug(
        "[Channel] peer unauthorized, got token: '%s', "
        "want: '%s'", token, self._token)
    return False
def evaluate(self, input_fn):
    """Evaluate the model over the dataset produced by ``input_fn``.

    Builds an evaluation graph on this worker's cluster device, runs the
    metric update ops in lockstep with the federated peer (one bridge
    iteration per session step), and logs the final metric values.

    Args:
        input_fn: an Estimator-style input function returning
            (features, labels) for EVAL mode.

    Returns:
        self, for call chaining.

    Raises:
        ValueError: if user metrics already define a `global_step` key.
    """
    with tf.Graph().as_default() as g, \
        g.device(self._cluster_server.device_setter):
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, tf.estimator.ModeKeys.EVAL)
        spec, model = self._get_model_spec(
            features, labels, tf.estimator.ModeKeys.EVAL)

        # Track the average loss as a default metric if the user did not
        # provide one under the standard LOSS_METRIC_KEY.
        eval_metric_ops = spec.eval_metric_ops or {}
        if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
            loss_metric = tf.metrics.mean(spec.loss)
            eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

        # Create the real eval op: group all metric update ops.
        update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
        # NOTE(review): evaluation also runs the model's private
        # _train_ops; presumably these carry bridge send/receive side
        # effects required by the federated protocol -- confirm against
        # the model implementation.
        update_ops.extend(model._train_ops)
        eval_op = tf.group(*update_ops)

        # Also track the global step in the final values.
        if tf.GraphKeys.GLOBAL_STEP in eval_dict:
            raise ValueError(
                'Metric with name `global_step` is not allowed, because '
                'Estimator already defines a default metric with the '
                'same name.')
        eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
            tf.train.get_or_create_global_step()

        # Prepare hooks: user evaluation hooks plus a FinalOpsHook that
        # captures the resolved metric values when the session ends.
        all_hooks = []
        if spec.evaluation_hooks:
            all_hooks.extend(spec.evaluation_hooks)
        final_ops_hook = tf.train.FinalOpsHook(eval_dict)
        all_hooks.append(final_ops_hook)

        session_creator = tf.train.WorkerSessionCreator(
            master=self._cluster_server.target,
            config=self._cluster_server.cluster_config)

        # Evaluate over dataset, one bridge iteration per session step.
        self._bridge.connect()
        with tf.train.MonitoredSession(session_creator=session_creator,
                                       hooks=all_hooks) as sess:
            while not sess.should_stop():
                start_time = time.time()
                self._bridge.start()
                sess.run(eval_op)
                self._bridge.commit()
                use_time = time.time() - start_time
                fl_logging.debug("after session run. time: %f sec",
                                 use_time)
        self._bridge.terminate()

        # Print result captured by the FinalOpsHook.
        fl_logging.info('Metrics for evaluate: %s',
                        _dict_to_str(final_ops_hook.final_ops_values))
    return self
def load_data_block(self, count, block_id):
    """Ask the remote peer to load the data block ``block_id``.

    Args:
        count: counter forwarded in the LoadDataBlockRequest (presumably
            the number of blocks requested so far -- TODO confirm with
            the proto definition).
        block_id: identifier of the data block the peer should load.

    Returns:
        True if the peer reported STATUS_SUCCESS, False otherwise.
    """
    req = tws2_pb.LoadDataBlockRequest(count=count, block_id=block_id)
    fl_logging.debug("[Bridge] sending DataBlock with id %s", block_id)
    resp = self._client.LoadDataBlock(req)
    if resp.status.code == common_pb.STATUS_SUCCESS:
        fl_logging.info("[Bridge] remote succeeded to load data block %s",
                        block_id)
        return True
    # Fixed log-message typo: "remoted failed" -> "remote failed".
    fl_logging.warning(
        "[Bridge] remote failed to load data block %s. "
        "code: %d", block_id, resp.status.code)
    return False
def start(self):
    """Open a new bridge iteration and announce it to the peer.

    Requires the previous iteration to have been committed.
    """
    with self._condition:
        self._assert_iter_committed()
        iter_id = self._next_iter_id
        self._next_iter_id += 1
        self._current_iter_id = iter_id
        self._iter_started_at = time.time()
        start_msg = tws2_pb.TransmitRequest.StartMessage(iter_id=iter_id)
        self._transmit(tws2_pb.TransmitRequest(start=start_msg))
        fl_logging.debug("[Bridge] send start iter_id: %d", iter_id)
def _send(self, name, tensor=None, any_data=None):
    """Transmit one named payload (tensor or Any proto) for the current
    iteration to the peer. The iteration must already be started.
    """
    with self._condition:
        self._assert_iter_started()
        iter_id = self._current_iter_id
        payload = tws2_pb.TransmitRequest.DataMessage(
            iter_id=iter_id,
            name=name,
            tensor=tensor,
            any_data=any_data,
        )
        self._transmit(tws2_pb.TransmitRequest(data=payload))
        fl_logging.debug("[Bridge] Data: send iter_id: %d, name: %s",
                         iter_id, name)
def _receive(self, name):
    """Block until the peer has sent the payload *name* for the current
    iteration, then return it.

    Emits a warning every `_waiting_alert_timeout` seconds while waiting
    and records receive latency to the stats client.

    Raises:
        RuntimeError: if the local iteration changes, the peer commits,
            or the peer terminates before the data arrives.
    """
    start_time = time.time()
    alert_count = 0
    with self._condition:
        self._assert_iter_started()
        iter_id = self._current_iter_id
        while (iter_id not in self._received_data
               or name not in self._received_data[iter_id]):
            self._assert_iter_started()
            # The condition was released while waiting; re-check that the
            # world has not moved on without delivering our data.
            if iter_id != self._current_iter_id:
                raise RuntimeError(
                    "[Bridge] iter change while waiting receive data, "
                    "iter_id: {}, name: {}".format(iter_id, name))
            if self._peer_commit_iter_id is not None \
                and iter_id <= self._peer_commit_iter_id:
                raise RuntimeError(
                    "[Bridge] peer committed without sending data "
                    "iter_id: {}, name: {}".format(iter_id, name))
            if self._peer_terminated:
                raise RuntimeError(
                    "[Bridge] peer terminated without sending data "
                    "iter_id: {}, name: {}".format(iter_id, name))
            duration = time.time() - start_time
            # Warn once per alert-timeout interval elapsed.
            if duration >= (alert_count + 1) * self._waiting_alert_timeout:
                alert_count += 1
                fl_logging.warning(
                    "[Bridge] Data: waiting to receive "
                    "iter_id: %d, name: %s timeout. duration: %f sec",
                    iter_id, name, duration)
            # Sleep until the next alert boundary (or until notified).
            wait_timeout = self._waiting_alert_timeout - \
                (duration % self._waiting_alert_timeout)
            self._condition.wait(wait_timeout)
        data = self._received_data[iter_id][name]
        duration = time.time() - start_time
        _gctx.stats_client.timing("trainer.bridge.receive_timing",
                                  duration * 1000,
                                  {"bridge_receive_name": name})
        fl_logging.debug(
            "[Bridge] Data: received iter_id: %d, name: %s "
            "after %f sec", iter_id, name, duration)
        return data
def commit(self):
    """Commit the current iteration: notify the peer, drop the cached
    peer data for this iteration, and record step/timing stats."""
    with self._condition:
        self._assert_iter_started()
        iter_id = self._current_iter_id
        self._transmit(
            tws2_pb.TransmitRequest(
                commit=tws2_pb.TransmitRequest.CommitMessage(
                    iter_id=iter_id)))
        fl_logging.debug("[Bridge] send commit iter_id: %d", iter_id)
        # Data for a committed iteration will never be read again.
        if iter_id in self._received_data:
            del self._received_data[iter_id]
        duration = (time.time() - self._iter_started_at) * 1000
        self._current_iter_id = None
        with _gctx.stats_client.pipeline() as pipe:
            pipe.gauge("trainer.bridge.iterator_step", iter_id)
            pipe.timing("trainer.bridge.iterator_timing", duration)
def request_data_block(self, block_id):
    """Request a data block assignment from the master.

    Returns:
        The response proto on success; None when the block id is invalid
        or the data stream is finished.

    Raises:
        RuntimeError: on any other (unexpected) status code.
    """
    request = tm_pb.DataBlockRequest(worker_rank=self._worker_rank,
                                     block_id=block_id)
    response = _grpc_with_retry(
        lambda: self._client.RequestDataBlock(request))
    code = response.status.code
    if code == common_pb.StatusCode.STATUS_SUCCESS:
        fl_logging.debug(
            "succeeded to get datablock, id:%s, data_path: %s",
            response.block_id, response.data_path)
        return response
    if code == common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK:
        fl_logging.error("invalid data block id: %s", request.block_id)
        return None
    if code == common_pb.StatusCode.STATUS_DATA_FINISHED:
        fl_logging.info("data block finished")
        return None
    raise RuntimeError("RequestDataBlock error, code: %s, msg: %s" % (
        common_pb.StatusCode.Name(code),
        response.status.error_message))
def request_iterator():
    # Generator feeding a gRPC Transmit stream: yields queued messages
    # and ends the stream after IDLE_TIMEOUT seconds of inactivity.
    # NOTE(review): this is a closure fragment -- `self`, `IDLE_TIMEOUT`
    # come from an enclosing scope not visible here (cf. the identical
    # inner generator in _stream_transmit_fn).
    fl_logging.debug("[Bridge] stream transmitting")
    while True:
        with self._stream_condition:
            if len(self._stream_queue) == 0:
                start = time.time()
                while len(self._stream_queue) == 0:
                    duration = time.time() - start
                    if self._stream_terminated:
                        return
                    # Returning closes this request stream; the outer
                    # transmit loop reopens one when new data arrives.
                    if duration >= IDLE_TIMEOUT:
                        fl_logging.debug(
                            "[Bridge] stream transmit "
                            " closed by idle timeout: %f sec",
                            IDLE_TIMEOUT)
                        return
                    self._stream_condition.wait(IDLE_TIMEOUT - duration)
            msg = self._stream_queue.popleft()
            # NOTE(review): presumably wakes producers blocked on a full
            # queue -- confirm against the enqueue side.
            self._stream_condition.notify_all()
            yield msg
def train(self, input_fn):
    """Run the training loop in lockstep with the federated peer.

    Builds the TRAIN graph on this worker's cluster device, attaches
    stats/user hooks, and drives one bridge iteration (start -> run ->
    commit) per session step until the session signals stop.

    Args:
        input_fn: an Estimator-style input function returning
            (features, labels) for TRAIN mode.

    Returns:
        self, for call chaining.
    """
    with tf.Graph().as_default() as g, \
        g.device(self._cluster_server.device_setter):
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, tf.estimator.ModeKeys.TRAIN)
        spec, _ = self._get_model_spec(
            features, labels, tf.estimator.ModeKeys.TRAIN)

        hooks = []
        # stats: periodic trace reporting every 30s
        hooks.append(TraceStatsHook(
            every_secs=30,
            stats_client=_gctx.stats_client))

        # user-defined chief hooks run only on the chief worker
        if spec.training_chief_hooks and self._is_chief:
            hooks.extend(spec.training_chief_hooks)
        if spec.training_hooks:
            hooks.extend(spec.training_hooks)

        session_creator = tf.train.WorkerSessionCreator(
            master=self._cluster_server.target,
            config=self._cluster_server.cluster_config)

        self._bridge.connect()
        with tf.train.MonitoredSession(
            session_creator=session_creator,
            hooks=hooks) as sess:
            while not sess.should_stop():
                start_time = time.time()
                # one bridge iteration per training step
                self._bridge.start()
                sess.run(spec.train_op, feed_dict={})
                self._bridge.commit()
                use_time = time.time() - start_time
                fl_logging.debug("after session run. time: %f sec",
                                 use_time)
        self._bridge.terminate()
    return self
def _stream_transmit_fn(self):
    """Background sender loop: pushes queued messages to the peer over a
    gRPC Transmit stream, reopening the stream whenever new messages
    arrive after an idle close, until the stream is terminated."""
    IDLE_TIMEOUT = 30
    fl_logging.debug("[Bridge] stream transmit started")

    def request_iterator():
        # Generator feeding one Transmit call: yields queued messages
        # and ends the stream after IDLE_TIMEOUT seconds of inactivity.
        fl_logging.debug("[Bridge] stream transmitting")
        while True:
            with self._stream_condition:
                if len(self._stream_queue) == 0:
                    start = time.time()
                    while len(self._stream_queue) == 0:
                        duration = time.time() - start
                        if self._stream_terminated:
                            return
                        # Closing on idle lets the outer loop reopen a
                        # fresh stream when data shows up again.
                        if duration >= IDLE_TIMEOUT:
                            fl_logging.debug(
                                "[Bridge] stream transmit "
                                " closed by idle timeout: %f sec",
                                IDLE_TIMEOUT)
                            return
                        self._stream_condition.wait(
                            IDLE_TIMEOUT - duration)
                msg = self._stream_queue.popleft()
                # NOTE(review): presumably wakes producers blocked on a
                # full queue -- confirm against the enqueue side.
                self._stream_condition.notify_all()
                yield msg

    while True:
        # Wait until there is something to send (or we are terminated)
        # before opening a new stream.
        with self._stream_condition:
            while len(self._stream_queue) == 0:
                if self._stream_terminated:
                    fl_logging.debug("[Bridge] stream transmit closed")
                    return
                self._stream_condition.wait()
        response_iterator = \
            self._client.Transmit(request_iterator())
        # Drain (and discard) responses until the stream closes.
        for _ in response_iterator:
            pass
def _transmit_handler(self, request):
    """Server-side handler for one peer TransmitRequest.

    Dispatches on the oneof field (start / data / commit), updates the
    peer iteration bookkeeping under the condition lock, and wakes any
    local waiters. Out-of-order or repeated messages (e.g. gRPC resends)
    are logged and ignored rather than treated as errors.

    Returns:
        A TransmitResponse with STATUS_SUCCESS (always; suspicious
        messages are only logged -- see the "return error?" notes).
    """
    with self._condition:
        if request.HasField("start"):
            # Peer opens an iteration.
            if self._peer_commit_iter_id is not None \
                and request.start.iter_id <= self._peer_commit_iter_id:
                fl_logging.warning(
                    "[Bridge] received peer start iter_id: %d "
                    "which has been committed. "
                    "maybe caused by resend.(peer_commit_iter_id: %d)",
                    request.start.iter_id, self._peer_commit_iter_id)
            elif self._peer_start_iter_id is not None:
                fl_logging.warning(
                    "[Bridge] received repeated peer start iter_id: %d. "
                    "maybe caused by resend.(peer_start_iter_id: %d)",
                    request.start.iter_id, self._peer_start_iter_id)
            else:
                fl_logging.debug(
                    "[Bridge] received peer start iter_id: %d",
                    request.start.iter_id)
                self._peer_start_iter_id = request.start.iter_id
                self._condition.notify_all()
        elif request.HasField("data"):
            # Peer sends a named payload; only accepted for the
            # currently-started peer iteration.
            if self._peer_start_iter_id is None:
                fl_logging.warning(
                    "[Bridge] received data iter_id: %d without start. "
                    "maybe caused by resend.", request.data.iter_id)
            elif self._peer_start_iter_id != request.data.iter_id:
                fl_logging.warning(
                    "[Bridge] received data iter_id: %d no match start. "
                    "maybe caused by resend.(peer_start_iter_id: %d)",
                    request.data.iter_id, self._peer_start_iter_id)
            else:
                # Our own active (or upcoming) iteration id.
                iter_id = self._current_iter_id \
                    if self._current_iter_id is not None \
                    else self._next_iter_id
                if request.data.iter_id < iter_id:
                    # We already committed past this iteration locally.
                    fl_logging.debug(
                        "[Bridge] received data iter_id: %d, "
                        "name: %s, ignored by our commit."
                        "(current_iter_id: %s, next_iter_id: %d)",
                        request.data.iter_id, request.data.name,
                        self._current_iter_id, self._next_iter_id)
                else:
                    fl_logging.debug(
                        "[Bridge] received data iter_id: %d, "
                        "name: %s",
                        request.data.iter_id, request.data.name)
                    # NOTE(review): _received_data is presumably a
                    # defaultdict(dict) keyed by iter_id -- verify its
                    # construction elsewhere in the class.
                    self._received_data[
                        request.data.iter_id][
                            request.data.name] = request.data
                    self._condition.notify_all()
        elif request.HasField("commit"):
            # Peer closes its iteration.
            if self._peer_commit_iter_id is not None \
                and request.commit.iter_id <= self._peer_commit_iter_id:
                fl_logging.warning(
                    "[Bridge] receive repeated peer commit iter_id: %d. "
                    "maybe caused by resend.(peer_commit_iter_id: %d)",
                    request.commit.iter_id, self._peer_commit_iter_id)
            elif self._peer_start_iter_id is None:
                fl_logging.error(
                    "[Bridge] receive peer commit iter_id: %d "
                    "without start", request.commit.iter_id)
                # return error?
            elif request.commit.iter_id != self._peer_start_iter_id:
                fl_logging.error(
                    "[Bridge] receive peer commit iter_id: %s "
                    "no match start.(peer_start_iter_id: %d)",
                    request.commit.iter_id, self._peer_start_iter_id)
                # return error?
            else:
                fl_logging.debug(
                    "[Bridge] receive peer commit iter_id: %d",
                    request.commit.iter_id)
                self._peer_start_iter_id = None
                self._peer_commit_iter_id = request.commit.iter_id
                self._condition.notify_all()
    return tws2_pb.TransmitResponse(status=common_pb.Status(
        code=common_pb.STATUS_SUCCESS))
def _state_fn(self):
    """Channel state-machine thread.

    Runs until the channel reaches DONE or ERROR: reports status gauges,
    enforces local and peer heartbeat timeouts, retries CONNECT / CLOSE
    calls while in transitional states, and sends periodic heartbeats.
    On exit it tears down the gRPC channel and server and releases any
    waiters on the ready/closed events.
    """
    fl_logging.debug("[Channel] thread _state_fn start")

    self._server.start()
    #self._channel.subscribe(self._channel_callback)

    # NOTE(review): the loop below waits on self._condition while holding
    # self._lock -- presumably the condition is constructed over the same
    # lock; confirm in __init__.
    self._lock.acquire()
    while True:
        now = time.time()
        saved_state = self._state
        wait_timeout = 10  # upper bound on one wait; shrunk by deadlines

        self._stats_client.gauge("channel.status", self._state.value)

        if self._state in (Channel.State.DONE, Channel.State.ERROR):
            break

        # check disconnected: our own heartbeat deadline applies once we
        # are past the connecting states.
        if self._state not in Channel._CONNECTING_STATES:
            if now >= self._heartbeat_timeout_at:
                self._emit_error(
                    ChannelError(
                        "disconnected by heartbeat timeout: {}s".format(
                            self._heartbeat_timeout)))
                continue
            wait_timeout = min(wait_timeout,
                               self._heartbeat_timeout_at - now)

        # check peer disconnected
        if self._state not in Channel._PEER_UNCONNECTED_STATES:
            if now >= self._peer_heartbeat_timeout_at:
                self._emit_error(
                    ChannelError(
                        "peer disconnected by heartbeat timeout: {}s".
                        format(self._heartbeat_timeout)))
                continue
            wait_timeout = min(wait_timeout,
                               self._peer_heartbeat_timeout_at - now)

        if now >= self._next_retry_at:
            self._next_retry_at = 0
            if self._state in Channel._CONNECTING_STATES:
                if self._call_locked(channel_pb2.CallType.CONNECT):
                    self._emit_event(Channel.Event.CONNECTED)
                    self._refresh_heartbeat_timeout()
                else:
                    self._next_retry_at = time.time(
                        ) + self._retry_interval
                continue
            if self._state in Channel._CLOSING_STATES:
                if self._call_locked(channel_pb2.CallType.CLOSE):
                    self._emit_event(Channel.Event.CLOSED)
                    self._refresh_heartbeat_timeout()
                else:
                    self._next_retry_at = 0  # fast retry
                continue
            if now >= self._next_heartbeat_at:
                if self._call_locked(channel_pb2.CallType.HEARTBEAT):
                    fl_logging.debug("[Channel] call heartbeat OK")
                    self._refresh_heartbeat_timeout()
                else:
                    fl_logging.warning("[Channel] call heartbeat failed")
                    interval = min(self._heartbeat_interval,
                                   self._retry_interval)
                    self._next_retry_at = time.time() + interval
                    continue
            # NOTE(review): reconstructed nesting -- this min() uses the
            # (possibly refreshed) next-heartbeat deadline; confirm
            # against upstream source.
            wait_timeout = min(wait_timeout,
                               self._next_heartbeat_at - now)
        else:
            wait_timeout = min(wait_timeout,
                               self._next_retry_at - now)

        # If the state changed while we were working, loop immediately
        # instead of sleeping.
        if saved_state != self._state:
            continue

        self._condition.wait(wait_timeout)

    # done
    self._lock.release()

    self._channel.close()
    # Linger a little past the peer's close time to reduce the chance
    # the peer's close call fails against a vanished server.
    time_wait = 2 * self._retry_interval + \
        self._peer_closed_at - time.time()
    if time_wait > 0:
        fl_logging.info(
            "[Channel] wait %0.2f sec "
            "for reducing peer close failed", time_wait)
        time.sleep(time_wait)

    self._server.stop(10)
    self._server.wait_for_termination()

    # Release anyone blocked on connect()/close().
    self._ready_event.set()
    self._closed_event.set()
    fl_logging.debug("[Channel] thread _state_fn stop")