def _create_client():
    """Build the stats client described by the process environment.

    ``FL_STATS_ENABLED`` is parsed with ``strtobool``; when it is unset or
    empty, stats count as enabled exactly when ``FL_STATS_URL`` is
    non-empty. The decision is stored in the module-level ``_enabled``.

    Returns:
        ``NoneClient()`` when stats are disabled. Otherwise ``Client(url)``,
        where the url falls back to ``"stderr://"`` if ``FL_STATS_URL`` is
        missing, and ``Client("stderr://")`` if building the configured
        client raises.
    """
    global _enabled #pylint: disable=global-statement

    raw_flag = os.environ.get("FL_STATS_ENABLED")
    target = os.environ.get("FL_STATS_URL")

    if raw_flag:
        try:
            _enabled = strtobool(raw_flag)
        except ValueError:
            fl_logging.warning(
                "[Stats] env FL_STATS_ENABLED=%s "
                "is invalid truth value. Disable stats ",
                raw_flag)
            _enabled = False
    else:
        # No explicit switch: the presence of a URL decides.
        _enabled = bool(target)

    if not _enabled:
        fl_logging.info("[Stats] stats not enabled")
        return NoneClient()

    if not target:
        fl_logging.warning(
            "[Stats] FL_STATS_URL not found, redirect to stderr")
        target = "stderr://"

    try:
        return Client(target)
    except Exception as err: #pylint: disable=broad-except
        fl_logging.error(
            "[Stats] new client error: %s, redirect to stderr", err)
    # Reached only when the configured client could not be built.
    return Client("stderr://")
def WorkerRegister(self, request, context):
    """Register a worker with the master and report whether it may proceed.

    Args:
        request: registration request; ``worker_rank``, ``hostname`` and
            ``cluster_def`` are read from it.
        context: RPC context (unused here).

    Returns:
        ``tm_pb.WorkerRegisterResponse`` whose status code is
        STATUS_DATA_FINISHED when the master is in WORKER_COMPLETED or
        COMPLETED, STATUS_WAIT_FOR_SYNCING_CHECKPOINT when the master is
        not RUNNING, and STATUS_SUCCESS once the worker is recorded as
        running.
    """
    def _reply(code):
        # Every exit path answers with a status-only response.
        return tm_pb.WorkerRegisterResponse(
            status=common_pb.Status(code=code))

    rank = request.worker_rank
    with self._lock:
        # for compatibility, more information see:
        # protocal/fedlearner/common/trainer_master_service.proto
        if rank == 0 and self._worker0_cluster_def is None:
            self._worker0_cluster_def = request.cluster_def

        finished_states = (tm_pb.MasterStatus.WORKER_COMPLETED,
                           tm_pb.MasterStatus.COMPLETED)
        if self._status in finished_states:
            return _reply(common_pb.StatusCode.STATUS_DATA_FINISHED)

        if self._status != tm_pb.MasterStatus.RUNNING:
            return _reply(
                common_pb.StatusCode.STATUS_WAIT_FOR_SYNCING_CHECKPOINT)

        if rank in self._running_workers:
            fl_logging.warning("worker_%d:%s repeat registration",
                               rank, request.hostname)
        else:
            fl_logging.info("worker_%d:%s registration",
                            rank, request.hostname)

        self._running_workers.add(rank)
        # A re-registering worker is no longer considered completed.
        if rank in self._completed_workers:
            self._completed_workers.remove(rank)
        return _reply(common_pb.StatusCode.STATUS_SUCCESS)
def _transmit(self, msg):
    """Append ``msg`` to the outgoing stream queue, waiting while it is full.

    Everything happens under ``self._stream_condition``. While the queue
    length equals ``self._stream_queue_size`` a warning is logged and the
    caller waits on the condition; after appending, all waiters are
    notified.

    Args:
        msg: the item to enqueue.

    Raises:
        AssertionError: if the stream has already been terminated.
    """
    with self._stream_condition:
        assert not self._stream_terminated
        while True:
            queued = len(self._stream_queue)
            if queued != self._stream_queue_size:
                break
            fl_logging.warning(
                "[Bridge] transmit stream queue is full, "
                "size: %d", queued)
            self._stream_condition.wait()
        self._stream_queue.append(msg)
        self._stream_condition.notify_all()
def _grpc_with_retry(call, interval=1):
    """Invoke ``call`` repeatedly until it finishes without a gRPC error.

    Every ``grpc.RpcError`` is logged and followed by a sleep of
    ``interval`` seconds before the next attempt; there is no attempt
    limit. Any other exception propagates to the caller.

    Args:
        call: zero-argument callable performing the RPC.
        interval: seconds to sleep between attempts.

    Returns:
        Whatever ``call()`` returns on its first successful attempt.
    """
    while True:
        try:
            outcome = call()
        except grpc.RpcError as err:
            fl_logging.warning(
                "TrainerMasterClient error, status: %s"
                ", details: %s, wait %ds for retry",
                err.code(), err.details(), interval)
            time.sleep(interval)
        else:
            return outcome
def _grpc_error_get_http_status(details):
    """Extract the HTTP status embedded in a gRPC error detail string.

    Only details containing the marker ``"http2 header with status"`` and
    exactly one ``":"`` are considered; the text after the colon is parsed
    as an integer.

    Args:
        details: the error detail text. ``None`` or any non-string value is
            treated as "no status available".

    Returns:
        int status code, or ``None`` when the details carry no parsable
        status.
    """
    # Missing or non-text details simply carry no status; answer without
    # going through the exception path.
    if not isinstance(details, str):
        return None
    if "http2 header with status" not in details:
        return None

    fields = details.split(":")
    if len(fields) != 2:
        return None

    try:
        return int(fields[1])
    except ValueError as e:
        fl_logging.warning(
            "[Channel] grpc_error_get_http_status except: %s, details: %s",
            repr(e), details)
    return None
def load_data_block(self, count, block_id):
    """Ask the remote side to load a data block and report the outcome.

    Args:
        count: value forwarded as ``count`` in the request.
        block_id: identifier of the data block to load.

    Returns:
        bool. True when the response status code equals
        ``common_pb.STATUS_SUCCESS``, False otherwise.
    """
    request = tws2_pb.LoadDataBlockRequest(count=count, block_id=block_id)
    fl_logging.debug("[Bridge] sending DataBlock with id %s", block_id)
    status_code = self._client.LoadDataBlock(request).status.code

    if status_code != common_pb.STATUS_SUCCESS:
        fl_logging.warning(
            "[Bridge] remoted failed to load data block %s. "
            "code: %d", block_id, status_code)
        return False

    fl_logging.info("[Bridge] remote succeeded to load data block %s",
                    block_id)
    return True
def _call_locked(self, call_type):
    """Send one control call to the peer, dropping the lock during the RPC.

    Must be entered with ``self._lock`` held. The lock is released around
    the blocking call and re-acquired before the response is processed, so
    the method also returns with the lock held.

    Args:
        call_type: ``channel_pb2.CallType`` value placed in the request.
            CONNECT and CLOSE additionally record the response timestamp
            and emit the matching channel event on success.

    Returns:
        bool. True when the peer answered ``channel_pb2.Code.OK``; False
        when the RPC raised or the peer answered with any other code.
    """
    self._lock.release()
    try:
        req = channel_pb2.CallRequest(
            type=call_type,
            token=self._token,
            identifier=self._identifier,
            peer_identifier=self._peer_identifier)
        timer = self._stats_client.timer("channel.call_timing").start()
        res = self._channel_call.Call(req,
                                      timeout=self._heartbeat_interval,
                                      wait_for_ready=True)
        # Only completed calls are timed; failures are counted below.
        timer.stop()
    except Exception as e: #pylint: disable=broad-except
        self._stats_client.incr("channel.call_error")
        if isinstance(e, grpc.RpcError):
            fl_logging.warning(
                "[Channel] grpc error, code: %s, "
                "details: %s.(call type: %s)",
                e.code(), e.details(),
                channel_pb2.CallType.Name(call_type))
        else:
            # Closing parenthesis added so this message matches the
            # format of the gRPC warning above.
            fl_logging.error("[Channel] call error: %s.(call type: %s)",
                             e, channel_pb2.CallType.Name(call_type))
        self._lock.acquire()
        return False
    self._lock.acquire()

    if res.code == channel_pb2.Code.OK:
        self._refresh_heartbeat_timeout()
        if call_type == channel_pb2.CallType.CONNECT:
            self._connected_at = res.timestamp
            self._emit_event(Channel.Event.CONNECTED)
        elif call_type == channel_pb2.CallType.CLOSE:
            self._closed_at = res.timestamp
            self._emit_event(Channel.Event.CLOSED)
        return True

    if res.code == channel_pb2.Code.UNAUTHORIZED:
        self._emit_event(Channel.Event.ERROR, ChannelError("unauthorized"))
    elif res.code == channel_pb2.Code.UNIDENTIFIED:
        if not self._peer_identifier:
            # No peer identifier recorded yet: treated as transient, the
            # caller retries later.
            fl_logging.warning("[Channel] unidentified by peer, "
                               "but channel is clean. wait next retry")
        else:
            self._emit_error(ChannelError("unidentified"))
    else:
        msg = "unexcepted code: {} for call type: {}".format(
            channel_pb2.Code.Name(res.code),
            channel_pb2.CallType.Name(call_type))
        self._emit_error(ChannelError(msg))
    return False
def recv(self, name, dtype=tf.float32, require_grad=False, shape=None):
    """Create the op that receives tensor ``name`` from the bridge.

    The receive op is created under a control dependency on
    ``self._example_ids``, appended to ``self._train_ops`` and recorded in
    ``self._recvs`` together with ``name`` and ``require_grad``.

    Args:
        name: name of the tensor to receive.
        dtype: dtype passed to the bridge's ``receive_op``.
        require_grad: flag stored alongside the op in ``self._recvs``.
        shape: expected shape enforced with ``tf.ensure_shape``. ``None``
            (the default) skips the check and logs a warning. Any other
            value, including the scalar shape ``()``, is enforced.

    Returns:
        The receive op (wrapped by ``tf.ensure_shape`` when ``shape`` is
        given).
    """
    with tf.control_dependencies([self._example_ids]):
        receive_op = self._bridge.receive_op(name, dtype)

    # Compare against None rather than truthiness: an explicit scalar
    # shape such as () or [] is falsy but is still a real constraint.
    if shape is not None:
        receive_op = tf.ensure_shape(receive_op, shape)
    else:
        fl_logging.warning(
            'Receiving tensor %s without checking shape. '
            'Consider setting shape at model.recv(shape=(...)). '
            'shape can have None dimensions '
            'which matches to any length.', name)

    self._train_ops.append(receive_op)
    self._recvs.append((name, receive_op, require_grad))
    return receive_op
def _grpc_with_retry(call, interval=1):
    """Start a gRPC call, retrying while its error is recoverable.

    ``call()`` must return an object exposing ``running()`` and
    ``exception()``. If that object is no longer running and holds an
    exception, the exception is raised here. A ``grpc.RpcError`` accepted
    by ``_grpc_error_need_recover`` is logged and the call is retried after
    ``interval`` seconds; any other error propagates.

    Args:
        call: zero-argument callable starting the RPC.
        interval: seconds to sleep before a retry.

    Returns:
        The object returned by ``call()``.
    """
    while True:
        try:
            attempt = call()
            if not attempt.running():
                if attempt.exception() is not None:
                    raise attempt.exception()
            return attempt
        except grpc.RpcError as err:
            if not _grpc_error_need_recover(err):
                raise
            fl_logging.warning(
                "[Channel] grpc error, status: %s, "
                "details: %s, wait %ds for retry",
                err.code(), err.details(), interval)
            time.sleep(interval)
def _receive(self, name):
    """Block until the peer's data ``name`` for the current iteration arrives.

    Waits on ``self._condition`` until ``self._received_data`` holds an
    entry for the current iteration id and ``name``. A warning is logged
    each time another ``self._waiting_alert_timeout`` seconds have elapsed.
    The total wait is reported to the stats client afterwards.

    Args:
        name: key of the expected data within the iteration.

    Returns:
        The object stored under ``self._received_data[iter_id][name]``.

    Raises:
        RuntimeError: if the iteration changes while waiting, the peer has
            already committed this iteration, or the peer terminated, in
            each case before the data was received.
    """
    began_at = time.time()
    alerts_sent = 0
    with self._condition:
        self._assert_iter_started()
        iter_id = self._current_iter_id
        while True:
            if iter_id in self._received_data \
                and name in self._received_data[iter_id]:
                break

            self._assert_iter_started()
            if iter_id != self._current_iter_id:
                raise RuntimeError(
                    "[Bridge] iter change while waiting receive data, "
                    "iter_id: {}, name: {}".format(iter_id, name))
            peer_commit = self._peer_commit_iter_id
            if peer_commit is not None and iter_id <= peer_commit:
                raise RuntimeError(
                    "[Bridge] peer committed without sending data "
                    "iter_id: {}, name: {}".format(iter_id, name))
            if self._peer_terminated:
                raise RuntimeError(
                    "[Bridge] peer terminated without sending data "
                    "iter_id: {}, name: {}".format(iter_id, name))

            waited = time.time() - began_at
            alert_every = self._waiting_alert_timeout
            if waited >= (alerts_sent + 1) * alert_every:
                alerts_sent += 1
                fl_logging.warning(
                    "[Bridge] Data: waiting to receive "
                    "iter_id: %d, name: %s timeout. duration: %f sec",
                    iter_id, name, waited)
            # Sleep only until the next alert boundary.
            self._condition.wait(alert_every - (waited % alert_every))

        data = self._received_data[iter_id][name]

    elapsed = time.time() - began_at
    _gctx.stats_client.timing("trainer.bridge.receive_timing",
                              elapsed * 1000,
                              {"bridge_receive_name": name})
    fl_logging.debug(
        "[Bridge] Data: received iter_id: %d, name: %s "
        "after %f sec",
        iter_id, name, elapsed)
    return data
def response_iterator(init_stream_response):
    """Yield deserialized payloads from a response stream, re-opening it on
    recoverable gRPC errors.

    Relies on ``self``, ``consumer``, ``method_details`` and ``call`` from
    the enclosing scope (not visible in this block). Each response is
    passed to ``self._check_fn``; its payload is yielded only when
    ``consumer.ack(response.ack)`` is truthy. The generator finishes when
    the stream is exhausted.

    Args:
        init_stream_response: the first response stream to consume.

    Raises:
        grpc.RpcError: when ``_grpc_error_need_recover`` rejects the error.
    """
    stream = init_stream_response
    while True:
        try:
            for item in stream:
                self._check_fn(item)
                if not consumer.ack(item.ack):
                    continue
                yield method_details.response_deserializer(item.payload)
            return
        except grpc.RpcError as err:
            if not _grpc_error_need_recover(err):
                raise
            fl_logging.warning(
                "[Channel] grpc error, status: %s, "
                "details: %s, wait %ds for retry",
                err.code(), err.details(), self._retry_interval)
            time.sleep(self._retry_interval)
            # Open a fresh stream and resume consuming from it.
            stream = _grpc_with_retry(call, self._retry_interval)
def convert_to_datetime(value, enable_tz=False):
    """
    Convert a date/time value given in one of several forms to a datetime.

    Args:
        value: datetime object | bytes | str | int | float. Value to be
            converted. Expected to be a numeric in the format of yyyymmdd
            or yyyymmddhhnnss, a timestamp, or a datetime object (which is
            used as is).
        enable_tz: bool. Whether to pass the result through
            `set_timezone` before returning it.

    Returns:
        The converted datetime object:
        1. A datetime input is taken unchanged.
        2. Otherwise the text is parsed with
           `convert_time_string_to_datetime`.
        3. If that fails, the text is interpreted as a timestamp.
        4. If that fails too (not a number, or a number outside the range
           `datetime.fromtimestamp` supports), `INVALID_DATETIME` is used.

    Raises:
        AssertionError: if `value` is of none of the accepted types.
    """
    assert isinstance(value,
                      (datetime.datetime, bytes, str, int, float))
    if isinstance(value, datetime.datetime):
        date_time = value
    else:
        if isinstance(value, bytes):
            value = value.decode()
        elif isinstance(value, (int, float)):
            value = str(value)
        # 1. try to parse datetime from value
        try:
            date_time = convert_time_string_to_datetime(value)
        except ValueError:  # Not fitting any of above patterns
            # 2. try to convert assuming it is a timestamp
            # not in the same `try` block b/c the length of some strings
            # might be equal to 8 or 14 but it does not match any of the
            # patterns
            try:
                date_time = datetime.datetime.fromtimestamp(float(value))
            except (ValueError, OverflowError, OSError):
                # ValueError: non-number str. OverflowError / OSError:
                # a number the platform cannot represent as a timestamp.
                # 3. default to 0
                fl_logging.warning(
                    'Unable to parse time %s to iso format, '
                    'defaults to 0.', value)
                date_time = INVALID_DATETIME
    if enable_tz:
        date_time = set_timezone(date_time)
    return date_time
def _transmit_handler(self, request):
    """Apply one peer transmit request to the bridge's iteration state.

    The request carries one of three fields, handled under
    ``self._condition``:

    * ``start``  - records the peer's started iteration id.
    * ``data``   - stores the payload for the peer's started iteration.
    * ``commit`` - clears the peer's started iteration and records it as
      committed.

    Requests that do not fit the current state (repeats, mismatching
    iteration ids) are only logged and otherwise ignored. Waiters on
    ``self._condition`` are notified whenever state is updated.

    Args:
        request: the transmit request; exactly which message type is not
            visible here, only its ``start`` / ``data`` / ``commit``
            fields are used.

    Returns:
        ``tws2_pb.TransmitResponse`` with STATUS_SUCCESS, in every case.
    """
    with self._condition:
        if request.HasField("start"):
            if self._peer_commit_iter_id is not None \
                and request.start.iter_id <= self._peer_commit_iter_id:
                # Start for an iteration the peer already committed.
                fl_logging.warning(
                    "[Bridge] received peer start iter_id: %d "
                    "which has been committed. "
                    "maybe caused by resend.(peer_commit_iter_id: %d)",
                    request.start.iter_id, self._peer_commit_iter_id)
            elif self._peer_start_iter_id is not None:
                # A peer iteration is already open; ignore the new start.
                fl_logging.warning(
                    "[Bridge] received repeated peer start iter_id: %d. "
                    "maybe caused by resend.(peer_start_iter_id: %d)",
                    request.start.iter_id, self._peer_start_iter_id)
            else:
                fl_logging.debug(
                    "[Bridge] received peer start iter_id: %d",
                    request.start.iter_id)
                self._peer_start_iter_id = request.start.iter_id
                self._condition.notify_all()

        elif request.HasField("data"):
            if self._peer_start_iter_id is None:
                fl_logging.warning(
                    "[Bridge] received data iter_id: %d without start. "
                    "maybe caused by resend.",
                    request.data.iter_id)
            elif self._peer_start_iter_id != request.data.iter_id:
                fl_logging.warning(
                    "[Bridge] received data iter_id: %d no match start. "
                    "maybe caused by resend.(peer_start_iter_id: %d)",
                    request.data.iter_id, self._peer_start_iter_id)
            else:
                # Compare against our own position: the running iteration
                # if there is one, otherwise the next one to start.
                iter_id = self._current_iter_id \
                    if self._current_iter_id is not None \
                        else self._next_iter_id
                if request.data.iter_id < iter_id:
                    # Data for an iteration we have already moved past.
                    fl_logging.debug(
                        "[Bridge] received data iter_id: %d, "
                        "name: %s, ignored by our commit."
                        "(current_iter_id: %s, next_iter_id: %d)",
                        request.data.iter_id, request.data.name,
                        self._current_iter_id, self._next_iter_id)
                else:
                    fl_logging.debug(
                        "[Bridge] received data iter_id: %d, "
                        "name: %s",
                        request.data.iter_id, request.data.name)
                    # NOTE(review): indexing by iter_id without creating
                    # the inner mapping first; presumably _received_data
                    # is a defaultdict - confirm.
                    self._received_data[
                        request.data.iter_id][
                            request.data.name] = request.data
                    self._condition.notify_all()

        elif request.HasField("commit"):
            if self._peer_commit_iter_id is not None \
                and request.commit.iter_id <= self._peer_commit_iter_id:
                fl_logging.warning(
                    "[Bridge] receive repeated peer commit iter_id: %d. "
                    "maybe caused by resend.(peer_commit_iter_id: %d)",
                    request.commit.iter_id, self._peer_commit_iter_id)
            elif self._peer_start_iter_id is None:
                fl_logging.error(
                    "[Bridge] receive peer commit iter_id: %d "
                    "without start",
                    request.commit.iter_id)
                # return error?
            elif request.commit.iter_id != self._peer_start_iter_id:
                fl_logging.error(
                    "[Bridge] receive peer commit iter_id: %s "
                    "no match start.(peer_start_iter_id: %d)",
                    request.commit.iter_id, self._peer_start_iter_id)
                # return error?
            else:
                fl_logging.debug(
                    "[Bridge] receive peer commit iter_id: %d",
                    request.commit.iter_id)
                # Close the peer's open iteration and remember the commit.
                self._peer_start_iter_id = None
                self._peer_commit_iter_id = request.commit.iter_id
                self._condition.notify_all()

    # Protocol mismatches above are logged, not reported to the peer.
    return tws2_pb.TransmitResponse(status=common_pb.Status(
        code=common_pb.STATUS_SUCCESS))
def _state_fn(self):
    """Thread body that drives the channel until it is DONE or in ERROR.

    Starts the server, then loops with ``self._lock`` held. Each pass:

    1. publishes the current state as a gauge and leaves the loop when the
       state is DONE or ERROR;
    2. emits an error if our own or the peer's heartbeat deadline passed;
    3. once the retry time is reached, issues a CONNECT, CLOSE or
       HEARTBEAT call depending on the state;
    4. otherwise waits on ``self._condition`` for at most the time until
       the nearest deadline (capped at 10 seconds).

    After the loop the lock is released, the channel is closed, the server
    is stopped and the ready/closed events are set.
    """
    fl_logging.debug("[Channel] thread _state_fn start")
    self._server.start()
    #self._channel.subscribe(self._channel_callback)
    self._lock.acquire()
    while True:
        now = time.time()
        # Remembered so a state change during this pass skips the wait.
        saved_state = self._state
        wait_timeout = 10
        self._stats_client.gauge("channel.status", self._state.value)
        if self._state in (Channel.State.DONE, Channel.State.ERROR):
            break

        # check disconnected
        if self._state not in Channel._CONNECTING_STATES:
            if now >= self._heartbeat_timeout_at:
                self._emit_error(
                    ChannelError(
                        "disconnected by heartbeat timeout: {}s".format(
                            self._heartbeat_timeout)))
                continue
            wait_timeout = min(wait_timeout,
                               self._heartbeat_timeout_at - now)

        # check peer disconnected
        if self._state not in Channel._PEER_UNCONNECTED_STATES:
            if now >= self._peer_heartbeat_timeout_at:
                self._emit_error(
                    ChannelError(
                        "peer disconnected by heartbeat timeout: {}s".
                        format(self._heartbeat_timeout)))
                continue
            wait_timeout = min(wait_timeout,
                               self._peer_heartbeat_timeout_at - now)

        if now >= self._next_retry_at:
            self._next_retry_at = 0
            if self._state in Channel._CONNECTING_STATES:
                # NOTE(review): _call_locked, as shown in this file,
                # already emits CONNECTED and refreshes the heartbeat
                # timeout on an OK response - confirm the repeat here is
                # intended.
                if self._call_locked(channel_pb2.CallType.CONNECT):
                    self._emit_event(Channel.Event.CONNECTED)
                    self._refresh_heartbeat_timeout()
                else:
                    self._next_retry_at = time.time() + self._retry_interval
                continue
            if self._state in Channel._CLOSING_STATES:
                # NOTE(review): same duplication question as for CONNECT
                # above (CLOSED event / heartbeat refresh).
                if self._call_locked(channel_pb2.CallType.CLOSE):
                    self._emit_event(Channel.Event.CLOSED)
                    self._refresh_heartbeat_timeout()
                else:
                    self._next_retry_at = 0 # fast retry
                continue
            if now >= self._next_heartbeat_at:
                if self._call_locked(channel_pb2.CallType.HEARTBEAT):
                    fl_logging.debug("[Channel] call heartbeat OK")
                    self._refresh_heartbeat_timeout()
                else:
                    fl_logging.warning("[Channel] call heartbeat failed")
                    # Retry no later than the next regular heartbeat.
                    interval = min(self._heartbeat_interval,
                                   self._retry_interval)
                    self._next_retry_at = time.time() + interval
                continue

            wait_timeout = min(wait_timeout, self._next_heartbeat_at - now)
        else:
            wait_timeout = min(wait_timeout, self._next_retry_at - now)

        # The state moved during this pass: re-evaluate immediately.
        if saved_state != self._state:
            continue

        self._condition.wait(wait_timeout)

    # done
    self._lock.release()

    self._channel.close()
    # Keep the server up until 2 * retry interval after the recorded peer
    # close time; per the log message, to reduce failed peer close calls.
    time_wait = 2 * self._retry_interval + self._peer_closed_at - time.time()
    if time_wait > 0:
        fl_logging.info(
            "[Channel] wait %0.2f sec "
            "for reducing peer close failed", time_wait)
        time.sleep(time_wait)

    self._server.stop(10)
    self._server.wait_for_termination()

    self._ready_event.set()
    self._closed_event.set()

    fl_logging.debug("[Channel] thread _state_fn stop")