def _create_client():
    """Create a stats client according to env configuration"""
    global _enabled  #pylint: disable=global-statement
    enabled_str = os.getenv("FL_STATS_ENABLED")
    url = os.getenv("FL_STATS_URL")
    try:
        _enabled = strtobool(enabled_str) if enabled_str else bool(url)
    except ValueError:
        fl_logging.warning(
            "[Stats] env FL_STATS_ENABLED=%s "
            "is an invalid truth value. Disable stats", enabled_str)
        _enabled = False

    if not _enabled:
        fl_logging.info("[Stats] stats not enabled")
        return NoneClient()

    if not url:
        fl_logging.warning(
            "[Stats] FL_STATS_URL not found, redirect to stderr")
        url = "stderr://"

    try:
        return Client(url)
    except Exception as e:  #pylint: disable=broad-except
        fl_logging.error("[Stats] new client error: %s, redirect to stderr",
                         e)
        return Client("stderr://")
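# NoneClient is referenced above but not defined in this section. A minimal
# no-op sketch of it, under the assumption that the real client exposes a
# statsd-style gauge/pipeline interface (see _trigger_fn below), could be:
class NoneClient:
    """Stats client that silently drops every metric."""

    def gauge(self, name, value, **kwargs):
        pass

    def pipeline(self):
        # Reuse self as a no-op context manager so callers can keep
        # writing `with client.pipeline() as pipe: pipe.gauge(...)`.
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False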
def WorkerRegister(self, request, context):
    with self._lock:
        # for compatibility, for more information see:
        #   protocols/fedlearner/common/trainer_master_service.proto
        if self._worker0_cluster_def is None and request.worker_rank == 0:
            self._worker0_cluster_def = request.cluster_def

        if self._status in (tm_pb.MasterStatus.WORKER_COMPLETED,
                            tm_pb.MasterStatus.COMPLETED):
            return tm_pb.WorkerRegisterResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_DATA_FINISHED))

        if self._status != tm_pb.MasterStatus.RUNNING:
            return tm_pb.WorkerRegisterResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_WAIT_FOR_SYNCING_CHECKPOINT))

        if request.worker_rank in self._running_workers:
            fl_logging.warning("worker_%d:%s repeated registration",
                               request.worker_rank, request.hostname)
        else:
            fl_logging.info("worker_%d:%s registered",
                            request.worker_rank, request.hostname)
            self._running_workers.add(request.worker_rank)

        if request.worker_rank in self._completed_workers:
            self._completed_workers.remove(request.worker_rank)

        return tm_pb.WorkerRegisterResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_SUCCESS))
def evaluate(self, input_fn):
    with tf.Graph().as_default() as g, \
        g.device(self._cluster_server.device_setter):

        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, tf.estimator.ModeKeys.EVAL)
        spec, model = self._get_model_spec(features, labels,
                                           tf.estimator.ModeKeys.EVAL)

        # Track the average loss by default
        eval_metric_ops = spec.eval_metric_ops or {}
        if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
            loss_metric = tf.metrics.mean(spec.loss)
            eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

        # Create the real eval op
        update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
        update_ops.extend(model._train_ops)
        eval_op = tf.group(*update_ops)

        # Also track the global step
        if tf.GraphKeys.GLOBAL_STEP in eval_dict:
            raise ValueError(
                'Metric with name `global_step` is not allowed, because '
                'Estimator already defines a default metric with the '
                'same name.')
        eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
            tf.train.get_or_create_global_step()

        # Prepare hooks
        all_hooks = []
        if spec.evaluation_hooks:
            all_hooks.extend(spec.evaluation_hooks)
        final_ops_hook = tf.train.FinalOpsHook(eval_dict)
        all_hooks.append(final_ops_hook)

        session_creator = tf.train.WorkerSessionCreator(
            master=self._cluster_server.target,
            config=self._cluster_server.cluster_config)

        # Evaluate over the dataset
        self._bridge.connect()
        with tf.train.MonitoredSession(session_creator=session_creator,
                                       hooks=all_hooks) as sess:
            while not sess.should_stop():
                start_time = time.time()
                self._bridge.start()
                sess.run(eval_op)
                self._bridge.commit()
                use_time = time.time() - start_time
                fl_logging.debug("after session run. time: %f sec",
                                 use_time)
        self._bridge.terminate()

        # Print result
        fl_logging.info('Metrics for evaluate: %s',
                        _dict_to_str(final_ops_hook.final_ops_values))
        return self
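# _extract_metric_update_ops and _dict_to_str are used above but not defined
# in this section. Sketches under stated assumptions (_extract_metric_update_ops
# splits the (value_op, update_op) pairs of eval_metric_ops, and _dict_to_str
# mirrors the "k = v" rendering used by tf.estimator's eval logging):
def _extract_metric_update_ops(eval_metric_ops):
    update_ops = []
    value_ops = {}
    for name in sorted(eval_metric_ops):
        value_op, update_op = eval_metric_ops[name]
        value_ops[name] = value_op
        update_ops.append(update_op)
    return update_ops, value_ops

def _dict_to_str(dictionary):
    # Renders {"loss": 0.1, "step": 7} as "loss = 0.1, step = 7".
    return ', '.join('%s = %s' % (k, v)
                     for k, v in sorted(dictionary.items()))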
def before_checkpoint_save(self, session, global_step_value):
    data = self._visitor.dump()
    fl_logging.info("DataVisitor save checkpoint for global step %d, "
                    "size: %d", global_step_value, len(data))
    session.run(
        self._save_op,
        {self._ckpt_plhd: data},
    )
def wait_master_complete(self):
    request = tm_pb.IsCompletedRequest()
    while True:
        response = _grpc_with_retry(
            lambda: self._client.IsCompleted(request))
        if response.completed:
            fl_logging.info("master completed")
            return
        fl_logging.info("waiting for master to complete...")
        time.sleep(2)
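# _grpc_with_retry is used here and in the client methods below but is not
# defined in this section. A minimal sketch of what it might do, assuming the
# retryable failure is grpc.RpcError and the retry interval is fixed (the
# real helper may also recreate the channel):
def _grpc_with_retry(call, interval=2):
    while True:
        try:
            return call()
        except grpc.RpcError as e:
            fl_logging.warning("[RPC] grpc call error: %s, retry in %ds",
                               e, interval)
            time.sleep(interval)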
def _run_grpc_server(self, address):
    self._grpc_server = grpc.server(
        futures.ThreadPoolExecutor(
            max_workers=8,
            thread_name_prefix="TrainerMasterServerThreadPoolExecutor"))
    tm_grpc.add_TrainerMasterServiceServicer_to_server(
        self, self._grpc_server)
    self._grpc_server.add_insecure_port(address)
    self._grpc_server.start()
    fl_logging.info('Trainer Master Server started on address: %s', address)
def _run(self):
    fl_logging.info("create estimator")
    estimator = self._create_estimator()
    fl_logging.info("start session_run")
    self._session_run(estimator)
    fl_logging.info("session_run done")
    fl_logging.info("start export_model")
    self._export_model(estimator)
    fl_logging.info("export_model done")
    self._transfer_status(tm_pb.MasterStatus.WORKER_COMPLETED,
                          tm_pb.MasterStatus.COMPLETED)
def load_data_block(self, count, block_id):
    req = tws2_pb.LoadDataBlockRequest(count=count, block_id=block_id)
    fl_logging.debug("[Bridge] sending LoadDataBlockRequest with id %s",
                     block_id)
    resp = self._client.LoadDataBlock(req)
    if resp.status.code == common_pb.STATUS_SUCCESS:
        fl_logging.info("[Bridge] remote succeeded to load data block %s",
                        block_id)
        return True
    fl_logging.warning("[Bridge] remote failed to load data block %s. "
                       "code: %d", block_id, resp.status.code)
    return False
def _data_block_handler(self, request):
    assert self._data_block_handler_fn is not None, \
        "[Bridge] received DataBlockMessage but no handler registered"
    if self._data_block_handler_fn(request):
        fl_logging.info("[Bridge] succeeded to load data block %s",
                        request.block_id)
        return tws2_pb.LoadDataBlockResponse(status=common_pb.Status(
            code=common_pb.STATUS_SUCCESS))
    fl_logging.info("[Bridge] failed to load data block %s",
                    request.block_id)
    return tws2_pb.LoadDataBlockResponse(status=common_pb.Status(
        code=common_pb.STATUS_INVALID_DATA_BLOCK))
def _transfer_status(self, frm, to):
    if self._status != frm:
        raise RuntimeError(
            "Trainer Master status transfer failed, "
            "want from %s to %s, but current status: %s" % \
            (tm_pb.MasterStatus.Name(frm),
             tm_pb.MasterStatus.Name(to),
             tm_pb.MasterStatus.Name(self._status))
        )
    self._status = to
    fl_logging.info("Trainer Master status transfer, from %s to %s",
                    tm_pb.MasterStatus.Name(frm),
                    tm_pb.MasterStatus.Name(to))
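# The transitions exercised in this section trace a single lifecycle:
#   CREATED -> INITIALING -> RUNNING -> WORKER_COMPLETED -> COMPLETED
# (run_forever, _run and WorkerComplete use the CREATED, WORKER_COMPLETED
# and RUNNING edges; INITIALING -> RUNNING presumably happens in code not
# shown here). A call that does not start from the expected state raises,
# so out-of-order transitions fail loudly. For example:
#
#     self._transfer_status(tm_pb.MasterStatus.CREATED,
#                           tm_pb.MasterStatus.INITIALING)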
def _supervise_fn(self):
    check_handlers = []
    if self._supervise_iteration_timeout > 0:
        fl_logging.info("enable supervise iteration timeout: %f",
                        self._supervise_iteration_timeout)
        check_handlers.append(self._check_iteration_timeout)
    if not check_handlers:
        return

    while True:
        with self._condition:
            if self._terminated:
                return
        for handler in check_handlers:
            handler()
        time.sleep(self._supervise_interval)
def _data_block_handler(self, msg):
    if self._count > msg.count:
        fl_logging.warning('DataBlock: ignore repeated datablock "%s" at %d',
                           msg.block_id, msg.count)
        return True
    fl_logging.info('DataBlock: recv "%s" at %d', msg.block_id, msg.count)
    assert self._count == msg.count
    if not msg.block_id:
        block = None
    else:
        block = self._trainer_master.request_data_block(msg.block_id)
        if block is None:
            return False
    self._count += 1
    self._block_queue.put(block)
    return True
def _create_tf_server(self, cluster_spec):
    self._tf_config = tf.ConfigProto()
    self._tf_config.inter_op_parallelism_threads = 64
    self._tf_config.intra_op_parallelism_threads = 64
    self._tf_config.experimental \
        .share_session_state_in_clusterspec_propagation = True
    self._tf_config.rpc_options.compression_algorithm = "gzip"
    self._tf_config.rpc_options.cache_rpc_response = True
    self._tf_config.rpc_options.disable_session_connection_sharing = True

    try:
        address = cluster_spec.task_address(self._job_name,
                                            self._task_index)
        self._tf_server = \
            tf.distribute.Server({"server": {self._task_index: address}},
                                 protocol="grpc",
                                 config=self._tf_config)
        self._tf_target = "grpc://" + address
    except ValueError:
        self._tf_server = \
            tf.distribute.Server(
                {"server": {self._task_index: "localhost:0"}},
                protocol="grpc",
                config=self._tf_config)
        self._tf_target = self._tf_server.target

    # modify cluster_spec
    cluster_dict = dict()
    cluster_dict[self._job_name] = {
        self._task_index: self._tf_target[len("grpc://"):]
    }
    for job_name in cluster_spec.jobs:
        if job_name == self._job_name:
            continue
        if job_name in self._extra_reserve_jobs:
            cluster_dict[job_name] = cluster_spec.job_tasks(job_name)
    self._tf_cluster_spec = tf.train.ClusterSpec(cluster_dict)
    self._tf_config.cluster_def.CopyFrom(
        self._tf_cluster_spec.as_cluster_def())

    fl_logging.info("cluster server target: %s\nconfig: \n%s",
                    self._tf_target, self._tf_config)
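# Illustration of the cluster_spec rewrite above, with hypothetical values:
# given job_name="worker", task_index=1 and _extra_reserve_jobs={"ps"},
# an input spec of
#   {"worker": {0: "w0:2222", 1: "w1:2222"}, "ps": {0: "ps0:2222"}}
# is reduced to this task's own (possibly re-bound) address plus the
# reserved jobs:
#   {"worker": {1: "w1:2222"}, "ps": {0: "ps0:2222"}}
# so the TensorFlow server only knows about itself and the jobs it must
# still be able to reach.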
def _emit_event(self, event, error=None):
    with self._lock:
        next_state = self._next_state(self._state, event)
        if self._state != next_state:
            fl_logging.info("[Channel] state changed from %s to %s, "
                            "event: %s", self._state.name,
                            next_state.name, event.name)
            self._state = next_state
            if self._state == Channel.State.ERROR:
                assert error is not None
                fl_logging.error("[Channel] error occurred: %s", str(error))
                # store the error before waking waiters on _error_event
                self._error = error
                self._error_event.set()
            elif self._state == Channel.State.READY:
                self._ready_event.set()
            self._event_callback(event)
        self._condition.notify_all()
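# _next_state is not defined in this section. A plausible shape for it (an
# assumption; the concrete states and events beyond READY, ERROR, DONE,
# CONNECTED and CLOSED are not shown here) is a (state, event) -> state
# table that falls back to the current state, which is what makes the
# `self._state != next_state` check above meaningful:
#
#     def _next_state(self, state, event):
#         # Hypothetical transition-table lookup.
#         return self._NEXT_STATES.get((state, event), state)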
def _request_data_block(self, request):
    data_block = self._data_visitor.get_datablock_by_id(request.block_id)
    if data_block:
        fl_logging.info("allocated worker_%d with block: %s",
                        request.worker_rank, data_block.id)
        response = tm_pb.DataBlockResponse(
            status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS),
            block_id=data_block.id,
            data_path=data_block.data_path,
        )
    else:
        fl_logging.error("invalid data block id: %s", request.block_id)
        response = tm_pb.DataBlockResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK,
            error_message="invalid data block"))
    return response
def _trigger_fn(self, global_step):
    now = time.time()
    if self._last_global_step >= 0:
        speed = (global_step - self._last_global_step) \
                / (now - self._last_trigger_time)
        total_datablock = self._data_visitor.datablock_size
        fl_logging.info("global_step: %d, speed: %0.2f step/sec, "
                        "datablock size: %d, "
                        "worker: %d/%d(running/completed)",
                        global_step, speed, total_datablock,
                        len(self._running_workers),
                        len(self._completed_workers))
        with _gctx.stats_client.pipeline() as pipe:
            pipe.gauge("trainer.global_step", global_step)
            pipe.gauge("trainer.datablock_total", total_datablock)
            pipe.gauge("trainer.speed", speed)
    self._last_trigger_time = now
    self._last_global_step = global_step
def WorkerComplete(self, request, context):
    with self._lock:
        if request.worker_rank not in self._running_workers:
            return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_INVALID_REQUEST,
                error_message="unregistered worker"))
        fl_logging.info("worker_%d completed", request.worker_rank)
        self._completed_workers.add(request.worker_rank)
        if request.worker_rank == 0:
            self._worker0_terminated_at = request.timestamp

        if len(self._running_workers) == len(self._completed_workers) \
            and 0 in self._running_workers:
            # worker 0 completed and all data blocks have been consumed
            self._transfer_status(tm_pb.MasterStatus.RUNNING,
                                  tm_pb.MasterStatus.WORKER_COMPLETED)

        return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_SUCCESS))
def request_data_block(self, block_id):
    request = tm_pb.DataBlockRequest(worker_rank=self._worker_rank,
                                     block_id=block_id)
    response = _grpc_with_retry(
        lambda: self._client.RequestDataBlock(request))
    if response.status.code == common_pb.StatusCode.STATUS_SUCCESS:
        fl_logging.debug(
            "succeeded to get datablock, id: %s, data_path: %s",
            response.block_id, response.data_path)
        return response
    if response.status.code == \
        common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK:
        fl_logging.error("invalid data block id: %s", request.block_id)
        return None
    if response.status.code == common_pb.StatusCode.STATUS_DATA_FINISHED:
        fl_logging.info("data block finished")
        return None
    raise RuntimeError("RequestDataBlock error, code: %s, msg: %s" % \
        (common_pb.StatusCode.Name(response.status.code),
         response.status.error_message))
def worker_register(self, cluster_def=None):
    request = tm_pb.WorkerRegisterRequest(worker_rank=self._worker_rank,
                                          hostname=os.uname().nodename,
                                          cluster_def=cluster_def)
    while True:
        response = _grpc_with_retry(
            lambda: self._client.WorkerRegister(request))
        if response.status.code == common_pb.StatusCode.STATUS_SUCCESS:
            return True
        if response.status.code == \
            common_pb.StatusCode.STATUS_WAIT_FOR_SYNCING_CHECKPOINT:
            fl_logging.info("waiting for master to be ready...")
            time.sleep(1)
            continue
        if response.status.code == \
            common_pb.StatusCode.STATUS_DATA_FINISHED:
            fl_logging.info("master completed, ignore worker register")
            return False
        raise RuntimeError("WorkerRegister error, code: %s, msg: %s" % \
            (common_pb.StatusCode.Name(response.status.code),
             response.status.error_message))
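# Taken together, worker_register, request_data_block and
# wait_master_complete suggest a worker-side driver roughly like the sketch
# below. This is an assumption for illustration only: `master` is a client
# instance, `process_block` is a hypothetical callback, the empty block_id
# meaning "next block in order" is inferred from the `if not msg.block_id`
# check above, and a worker_complete(timestamp=...) method is inferred from
# WorkerComplete reading request.timestamp.
def worker_main(master, cluster_def):
    if not master.worker_register(cluster_def):
        return  # master already completed; nothing left to do
    while True:
        block = master.request_data_block(block_id="")
        if block is None:
            break  # invalid block or STATUS_DATA_FINISHED
        process_block(block.block_id, block.data_path)  # hypothetical
    master.worker_complete(timestamp=int(time.time()))
    master.wait_master_complete()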
def run_forever(self, listen_port=None):
    with self._lock:
        self._transfer_status(tm_pb.MasterStatus.CREATED,
                              tm_pb.MasterStatus.INITIALING)

    if listen_port:
        self._run_grpc_server(listen_port)

    while self._cluster_server is None:
        # waiting to receive cluster_def from worker_0
        with self._lock:
            if self._worker0_cluster_def:
                fl_logging.info("received worker_0 cluster_def: %s",
                                self._worker0_cluster_def)
                self._cluster_server = ClusterServer(
                    tf.train.ClusterSpec(self._worker0_cluster_def),
                    "master")
                break
        fl_logging.info("still waiting to receive cluster_def "
                        "from worker_0")
        time.sleep(2)

    self._run()

    sig = signal.sigwait([signal.SIGHUP, signal.SIGINT, signal.SIGTERM])
    fl_logging.info("Server shutdown by signal: %s",
                    signal.Signals(sig).name)
def _request_data_block(self, request):
    try:
        data_block = next(self._data_visitor)
    except StopIteration:
        data_block = None

    if data_block:
        fl_logging.info("allocated worker_%d with block: %s",
                        request.worker_rank, data_block.id)
        response = tm_pb.DataBlockResponse(
            status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS),
            block_id=data_block.id,
            data_path=data_block.data_path,
        )
    else:
        response = tm_pb.DataBlockResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_DATA_FINISHED,
            error_message="data block finished"))
    return response
def _state_fn(self):
    fl_logging.debug("[Channel] thread _state_fn start")

    self._server.start()
    #self._channel.subscribe(self._channel_callback)

    self._lock.acquire()
    while True:
        now = time.time()
        saved_state = self._state
        wait_timeout = 10

        self._stats_client.gauge("channel.status", self._state.value)

        if self._state in (Channel.State.DONE, Channel.State.ERROR):
            break

        # check disconnected
        if self._state not in Channel._CONNECTING_STATES:
            if now >= self._heartbeat_timeout_at:
                self._emit_error(ChannelError(
                    "disconnected by heartbeat timeout: {}s".format(
                        self._heartbeat_timeout)))
                continue
            wait_timeout = min(wait_timeout,
                               self._heartbeat_timeout_at - now)

        # check peer disconnected
        if self._state not in Channel._PEER_UNCONNECTED_STATES:
            if now >= self._peer_heartbeat_timeout_at:
                self._emit_error(ChannelError(
                    "peer disconnected by heartbeat timeout: {}s".format(
                        self._heartbeat_timeout)))
                continue
            wait_timeout = min(wait_timeout,
                               self._peer_heartbeat_timeout_at - now)

        if now >= self._next_retry_at:
            self._next_retry_at = 0
            if self._state in Channel._CONNECTING_STATES:
                if self._call_locked(channel_pb2.CallType.CONNECT):
                    self._emit_event(Channel.Event.CONNECTED)
                    self._refresh_heartbeat_timeout()
                else:
                    self._next_retry_at = time.time() + self._retry_interval
                continue
            if self._state in Channel._CLOSING_STATES:
                if self._call_locked(channel_pb2.CallType.CLOSE):
                    self._emit_event(Channel.Event.CLOSED)
                    self._refresh_heartbeat_timeout()
                else:
                    self._next_retry_at = 0  # fast retry
                continue
            if now >= self._next_heartbeat_at:
                if self._call_locked(channel_pb2.CallType.HEARTBEAT):
                    fl_logging.debug("[Channel] call heartbeat OK")
                    self._refresh_heartbeat_timeout()
                else:
                    fl_logging.warning("[Channel] call heartbeat failed")
                    interval = min(self._heartbeat_interval,
                                   self._retry_interval)
                    self._next_retry_at = time.time() + interval
                continue
            wait_timeout = min(wait_timeout,
                               self._next_heartbeat_at - now)
        else:
            wait_timeout = min(wait_timeout,
                               self._next_retry_at - now)

        if saved_state != self._state:
            continue

        self._condition.wait(wait_timeout)

    # done
    self._lock.release()

    self._channel.close()
    time_wait = 2 * self._retry_interval \
                + self._peer_closed_at - time.time()
    if time_wait > 0:
        fl_logging.info("[Channel] wait %0.2f sec "
                        "to reduce peer close failures", time_wait)
        time.sleep(time_wait)
    self._server.stop(10)
    self._server.wait_for_termination()

    self._ready_event.set()
    self._closed_event.set()
    fl_logging.debug("[Channel] thread _state_fn stop")
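# _refresh_heartbeat_timeout is called after every successful CONNECT,
# CLOSE and HEARTBEAT above. A minimal sketch, assuming it simply pushes
# the local send deadline and the disconnect deadline forward from "now"
# (the peer deadline is presumably refreshed where peer calls arrive):
#
#     def _refresh_heartbeat_timeout(self):
#         now = time.time()
#         self._next_heartbeat_at = now + self._heartbeat_interval
#         self._heartbeat_timeout_at = now + self._heartbeat_timeout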