Example no. 1
def _create_client():
    """ Enable stats from env """
    global _enabled  #pylint: disable=global-statement

    enabled_str = os.getenv("FL_STATS_ENABLED")
    url = os.getenv("FL_STATS_URL")
    try:
        _enabled = strtobool(enabled_str) if enabled_str else bool(url)
    except ValueError:
        fl_logging.warning(
            "[Stats] env FL_STATS_ENABLED=%s "
            "is invalid truth value. Disable stats ", enabled_str)
        _enabled = False

    if not _enabled:
        fl_logging.info("[Stats] stats not enabled")
        return NoneClient()

    if not url:
        fl_logging.warning(
            "[Stats] FL_STATS_URL not found, redirect to stderr")
        url = "stderr://"

    try:
        return Client(url)
    except Exception as e:  #pylint: disable=broad-except
        fl_logging.error("[Stats] new client error: %s, redirect to stderr", e)
        return Client("stderr://")
Example no. 2
    def WorkerRegister(self, request, context):
        with self._lock:
            # for compatibility; for more information see:
            #   protocal/fedlearner/common/trainer_master_service.proto
            if self._worker0_cluster_def is None and request.worker_rank == 0:
                self._worker0_cluster_def = request.cluster_def

            if self._status in (tm_pb.MasterStatus.WORKER_COMPLETED,
                                tm_pb.MasterStatus.COMPLETED):
                return tm_pb.WorkerRegisterResponse(status=common_pb.Status(
                    code=common_pb.StatusCode.STATUS_DATA_FINISHED))

            if self._status != tm_pb.MasterStatus.RUNNING:
                code = common_pb.StatusCode.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
                return tm_pb.WorkerRegisterResponse(
                    status=common_pb.Status(code=code))

            if request.worker_rank in self._running_workers:
                fl_logging.warning("worker_%d:%s repeat registration",
                                   request.worker_rank, request.hostname)
            else:
                fl_logging.info("worker_%d:%s registration",
                                request.worker_rank, request.hostname)

            self._running_workers.add(request.worker_rank)
            if request.worker_rank in self._completed_workers:
                self._completed_workers.remove(request.worker_rank)
            return tm_pb.WorkerRegisterResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS))
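
Registration stays idempotent because the worker bookkeeping lives in plain sets; a stripped-down sketch of just that bookkeeping (names are illustrative, not the master's real attributes):

running_workers, completed_workers = set(), set()

def register(rank):
    # repeated registration is tolerated: adding to a set is a no-op
    if rank in running_workers:
        print("worker_%d repeat registration" % rank)
    running_workers.add(rank)
    # a re-registering worker is no longer considered completed
    completed_workers.discard(rank)

register(0)
register(0)          # logged as a repeat, but harmless
print(running_workers, completed_workers)   # {0} set()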
Example no. 3
def _transmit(self, msg):
    with self._stream_condition:
        assert not self._stream_terminated
        while len(self._stream_queue) == self._stream_queue_size:
            fl_logging.warning(
                "[Bridge] transmit stream queue is full, "
                "size: %d", len(self._stream_queue))
            self._stream_condition.wait()
        self._stream_queue.append(msg)
        self._stream_condition.notify_all()
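
The method above is the producer side of a bounded queue guarded by a condition variable. A generic, self-contained sketch of the full pattern, assuming nothing about the Bridge class itself:

import threading
from collections import deque

class BoundedQueue:
    """Blocking bounded queue built on a single Condition, as in _transmit."""

    def __init__(self, maxsize):
        self._cond = threading.Condition()
        self._queue = deque()
        self._maxsize = maxsize

    def put(self, item):
        with self._cond:
            # producer blocks while the queue is at capacity
            while len(self._queue) >= self._maxsize:
                self._cond.wait()
            self._queue.append(item)
            self._cond.notify_all()      # wake consumers waiting in get()

    def get(self):
        with self._cond:
            while not self._queue:
                self._cond.wait()
            item = self._queue.popleft()
            self._cond.notify_all()      # wake producers waiting in put()
            return item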
Example no. 4
def _grpc_with_retry(call, interval=1):
    while True:
        try:
            return call()
        except grpc.RpcError as e:
            fl_logging.warning(
                "TrainerMasterClient error, status: %s"
                ", details: %s, wait %ds for retry", e.code(), e.details(),
                interval)
            time.sleep(interval)
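
The helper takes a zero-argument callable and retries forever on grpc.RpcError; in practice the call argument is usually a lambda wrapping a stub method so it can be re-invoked on each attempt. The retry-forever shape can be exercised without a live gRPC server; below is a self-contained sketch with a stand-in exception type (TransientError and flaky are purely illustrative):

import time

class TransientError(Exception):
    pass

def with_retry(call, interval=1):
    # keep retrying the zero-argument callable until it succeeds
    while True:
        try:
            return call()
        except TransientError as e:
            print("call failed: %s, retry in %ds" % (e, interval))
            time.sleep(interval)

attempts = {"n": 0}

def flaky():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise TransientError("attempt %d" % attempts["n"])
    return "ok"

print(with_retry(flaky, interval=0))  # "ok" after two failed attempts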
Example no. 5
def _grpc_error_get_http_status(details):
    try:
        if "http2 header with status" in details:
            fields = details.split(":")
            if len(fields) == 2:
                return int(fields[1])
    except Exception as e:
        fl_logging.warning(
            "[Channel] grpc_error_get_http_status except: %s, details: %s",
            repr(e), details)
    return None
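
The parser only recognizes a two-field "prefix: status" shape. A quick standalone check of that string handling (the details text is representative, not a verbatim gRPC message):

details = "received http2 header with status: 503"
fields = details.split(":")
status = int(fields[1]) if ("http2 header with status" in details
                            and len(fields) == 2) else None
print(status)  # 503  (int() tolerates the leading space)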
Example no. 6
def load_data_block(self, count, block_id):
    req = tws2_pb.LoadDataBlockRequest(count=count, block_id=block_id)
    fl_logging.debug("[Bridge] sending DataBlock with id %s", block_id)
    resp = self._client.LoadDataBlock(req)
    if resp.status.code == common_pb.STATUS_SUCCESS:
        fl_logging.info("[Bridge] remote succeeded to load data block %s",
                        block_id)
        return True
    fl_logging.warning(
        "[Bridge] remote failed to load data block %s. "
        "code: %d", block_id, resp.status.code)
    return False
Example no. 7
    def _call_locked(self, call_type):
        self._lock.release()
        try:
            req = channel_pb2.CallRequest(
                type=call_type,
                token=self._token,
                identifier=self._identifier,
                peer_identifier=self._peer_identifier)
            timer = self._stats_client.timer("channel.call_timing").start()
            res = self._channel_call.Call(req,
                                          timeout=self._heartbeat_interval,
                                          wait_for_ready=True)
            timer.stop()
        except Exception as e:
            self._stats_client.incr("channel.call_error")
            if isinstance(e, grpc.RpcError):
                fl_logging.warning(
                    "[Channel] grpc error, code: %s, "
                    "details: %s.(call type: %s)", e.code(), e.details(),
                    channel_pb2.CallType.Name(call_type))
            else:
                fl_logging.error("[Channel] call error: %s.(call type: %s", e,
                                 channel_pb2.CallType.Name(call_type))
            self._lock.acquire()
            return False
        self._lock.acquire()

        if res.code == channel_pb2.Code.OK:
            self._refresh_heartbeat_timeout()
            if call_type == channel_pb2.CallType.CONNECT:
                self._connected_at = res.timestamp
                self._emit_event(Channel.Event.CONNECTED)
            elif call_type == channel_pb2.CallType.CLOSE:
                self._closed_at = res.timestamp
                self._emit_event(Channel.Event.CLOSED)
            return True

        if res.code == channel_pb2.Code.UNAUTHORIZED:
            self._emit_event(Channel.Event.ERROR, ChannelError("unauthorized"))
        elif res.code == channel_pb2.Code.UNIDENTIFIED:
            if not self._peer_identifier:
                fl_logging.warning("[Channel] unidentified by peer, "
                                   "but channel is clean. wait next retry")
            else:
                self._emit_error(ChannelError("unidentified"))
        else:
            msg = "unexcepted code: {} for call type: {}".format(
                channel_pb2.Code.Name(res.code),
                channel_pb2.CallType.Name(call_type))
            self._emit_error(ChannelError(msg))
        return False
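
The noteworthy detail is the lock discipline: the method is entered with self._lock held, releases it so other threads can run while the RPC blocks, then re-acquires it on both the success and the error path before touching shared state. A minimal sketch of just that bracket, with a generic callable standing in for the gRPC Call (all names illustrative):

import threading

class Caller:
    def __init__(self):
        self._lock = threading.Lock()

    def call_locked(self, do_blocking_call):
        # precondition: self._lock is already held by the caller
        self._lock.release()
        try:
            result = do_blocking_call()   # may take a long time; lock is free
        except Exception:
            self._lock.acquire()          # retake the lock on the error path
            return False
        self._lock.acquire()              # ... and on the success path
        return result is not None

c = Caller()
with c._lock:                             # simulate the caller holding the lock
    print(c.call_locked(lambda: "pong"))  # True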
Example no. 8
def recv(self, name, dtype=tf.float32, require_grad=False, shape=None):
    with tf.control_dependencies([self._example_ids]):
        receive_op = self._bridge.receive_op(name, dtype)
    if shape:
        receive_op = tf.ensure_shape(receive_op, shape)
    else:
        fl_logging.warning(
            'Receiving tensor %s without checking shape. '
            'Consider setting shape at model.recv(shape=(...)). '
            'shape can have None dimensions, which match any length.', name)
    self._train_ops.append(receive_op)
    self._recvs.append((name, receive_op, require_grad))
    return receive_op
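
tf.ensure_shape is what enforces the declared shape when the received tensor is actually evaluated; None dimensions act as wildcards. A standalone illustration in TF 2.x eager mode (values arbitrary):

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# (None, 2) accepts any leading dimension, so this passes through unchanged
y = tf.ensure_shape(x, (None, 2))
print(y.shape)          # (2, 2)

try:
    tf.ensure_shape(x, (None, 3))   # incompatible second dimension
except (tf.errors.InvalidArgumentError, ValueError):
    print("shape check failed")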
Example no. 9
def _grpc_with_retry(call, interval=1):
    while True:
        try:
            result = call()
            if not result.running() and result.exception() is not None:
                raise result.exception()
            return result
        except grpc.RpcError as e:
            if _grpc_error_need_recover(e):
                fl_logging.warning(
                    "[Channel] grpc error, status: %s, "
                    "details: %s, wait %ds for retry", e.code(), e.details(),
                    interval)
                time.sleep(interval)
                continue
            raise e
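
Compared with the earlier _grpc_with_retry, the extra branch handles calls that return a future: a finished future reports failure through exception() instead of raising, so the error is surfaced explicitly to reach the same retry path. The same idea with a concurrent.futures future standing in for a gRPC future (purely illustrative):

from concurrent.futures import ThreadPoolExecutor

def boom():
    raise RuntimeError("failed inside the future")

with ThreadPoolExecutor() as pool:
    fut = pool.submit(boom)
    fut.exception()                 # blocks until done; returns, not raises
    if not fut.running() and fut.exception() is not None:
        # re-raise here in real code so the retry loop can catch it
        print("recovered:", fut.exception())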
Example no. 10
    def _receive(self, name):
        start_time = time.time()
        alert_count = 0
        with self._condition:
            self._assert_iter_started()
            iter_id = self._current_iter_id
            while (iter_id not in self._received_data
                   or name not in self._received_data[iter_id]):
                self._assert_iter_started()
                if iter_id != self._current_iter_id:
                    raise RuntimeError(
                        "[Bridge] iter change while waiting receive data, "
                        "iter_id: {}, name: {}".format(iter_id, name))
                if self._peer_commit_iter_id is not None \
                    and iter_id <= self._peer_commit_iter_id:
                    raise RuntimeError(
                        "[Bridge] peer committed without sending data "
                        "iter_id: {}, name: {}".format(iter_id, name))
                if self._peer_terminated:
                    raise RuntimeError(
                        "[Bridge] peer terminated without sending data "
                        "iter_id: {}, name: {}".format(iter_id, name))
                duration = time.time() - start_time
                if duration >= (alert_count + 1) * self._waiting_alert_timeout:
                    alert_count += 1
                    fl_logging.warning(
                        "[Bridge] Data: waiting to receive "
                        "iter_id: %d, name: %s timeout. duration: %f sec",
                        iter_id, name, duration)
                wait_timeout = self._waiting_alert_timeout - \
                    (duration % self._waiting_alert_timeout)
                self._condition.wait(wait_timeout)
            data = self._received_data[iter_id][name]

        duration = time.time() - start_time
        _gctx.stats_client.timing("trainer.bridge.receive_timing",
                                  duration * 1000,
                                  {"bridge_receive_name": name})
        fl_logging.debug(
            "[Bridge] Data: received iter_id: %d, name: %s "
            "after %f sec", iter_id, name, duration)
        return data
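
The duration/alert_count arithmetic implements "warn every waiting_alert_timeout seconds while still waking up exactly on the next boundary". The same timing logic in isolation, without the bridge state (timeouts are arbitrary):

import threading
import time

cond = threading.Condition()
alert_timeout = 2.0                 # seconds between repeated warnings
start = time.time()
alert_count = 0

with cond:
    while time.time() - start < 5.0:          # demo stops after ~5 seconds
        duration = time.time() - start
        if duration >= (alert_count + 1) * alert_timeout:
            alert_count += 1
            print("still waiting after %.1f sec" % duration)
        # sleep just long enough to land on the next alert boundary
        cond.wait(alert_timeout - (duration % alert_timeout))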
Example no. 11
def response_iterator(init_stream_response):
    stream_response = init_stream_response
    while True:
        try:
            for response in stream_response:
                self._check_fn(response)
                if consumer.ack(response.ack):
                    yield method_details.response_deserializer(
                        response.payload)
            return
        except grpc.RpcError as e:
            if _grpc_error_need_recover(e):
                fl_logging.warning(
                    "[Channel] grpc error, status: %s, "
                    "details: %s, wait %ds for retry", e.code(),
                    e.details(), self._retry_interval)
                time.sleep(self._retry_interval)
                stream_response = _grpc_with_retry(
                    call, self._retry_interval)
                continue
            raise e
Example no. 12
def convert_to_datetime(value, enable_tz=False):
    """
    Args:
        value: datetime object | bytes | str | int | float.
            Value to be converted. Expected to be a numeric in the format of
            yyyymmdd or yyyymmddhhnnss, or a datetime object.
        enable_tz: bool. whether converts to UTC and contains timezone info

    Returns: str.
    Try to convert a datetime str or numeric to a UTC iso format str.
        1. Try to convert based on the length of str.
        2. Try to convert assuming it is a timestamp.
        3. If it does not match any pattern, return iso format of timestamp=0.
        Timezone will be set according to system TZ env if unset and
        then converted back to UTC if enable_tz is True.
    """
    assert isinstance(value, (bytes, str, int, float))
    if isinstance(value, bytes):
        value = value.decode()
    elif isinstance(value, (int, float)):
        value = str(value)
    # 1. try to parse datetime from value
    try:
        date_time = convert_time_string_to_datetime(value)
    except ValueError:  # Not fitting any of above patterns
        # 2. try to convert assuming it is a timestamp
        # not in the same `try` block b/c the length of some strings might
        # be equal to 8 or 14 but it does not match any of the patterns
        try:
            date_time = datetime.datetime.fromtimestamp(float(value))
        except ValueError:  # might be a non-number str
            # 3. default to 0
            fl_logging.warning(
                'Unable to parse time %s to iso format, '
                'defaulting to 0.', value)
            date_time = INVALID_DATETIME
    if enable_tz:
        date_time = set_timezone(date_time)
    return date_time
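
The three-step fallback (known fixed-width formats, then a raw timestamp, then a fixed sentinel) can be reproduced with just the standard library. The format list and sentinel below are assumptions for illustration; the real helpers (convert_time_string_to_datetime, set_timezone, INVALID_DATETIME) are not shown here:

import datetime

SENTINEL = datetime.datetime.fromtimestamp(0)   # stand-in for INVALID_DATETIME

def to_datetime(value):
    value = str(value)
    # 1. try the fixed-width date formats first
    for fmt in ("%Y%m%d%H%M%S", "%Y%m%d"):
        try:
            return datetime.datetime.strptime(value, fmt)
        except ValueError:
            pass
    # 2. fall back to interpreting the value as a unix timestamp
    try:
        return datetime.datetime.fromtimestamp(float(value))
    except ValueError:
        # 3. unparseable input defaults to the sentinel
        return SENTINEL

print(to_datetime("20240131"))     # 2024-01-31 00:00:00
print(to_datetime(1700000000))     # timestamp path (local time)
print(to_datetime("not-a-date"))   # sentinel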
Example no. 13
    def _transmit_handler(self, request):
        with self._condition:
            if request.HasField("start"):
                if self._peer_commit_iter_id is not None \
                    and request.start.iter_id <= self._peer_commit_iter_id:
                    fl_logging.warning(
                        "[Bridge] received peer start iter_id: %d "
                        "which has been committed. "
                        "maybe caused by resend.(peer_commit_iter_id: %d)",
                        request.start.iter_id, self._peer_commit_iter_id)
                elif self._peer_start_iter_id is not None:
                    fl_logging.warning(
                        "[Bridge] received repeated peer start iter_id: %d. "
                        "maybe caused by resend.(peer_start_iter_id: %d)",
                        request.start.iter_id, self._peer_start_iter_id)
                else:
                    fl_logging.debug(
                        "[Bridge] received peer start iter_id: %d",
                        request.start.iter_id)
                    self._peer_start_iter_id = request.start.iter_id
                    self._condition.notify_all()

            elif request.HasField("data"):
                if self._peer_start_iter_id is None:
                    fl_logging.warning(
                        "[Bridge] received data iter_id: %d without start. "
                        "maybe caused by resend.", request.data.iter_id)
                elif self._peer_start_iter_id != request.data.iter_id:
                    fl_logging.warning(
                        "[Bridge] received data iter_id: %d that does not "
                        "match start. maybe caused by resend."
                        "(peer_start_iter_id: %d)",
                        request.data.iter_id, self._peer_start_iter_id)
                else:
                    iter_id = self._current_iter_id \
                        if self._current_iter_id is not None \
                            else self._next_iter_id
                    if request.data.iter_id < iter_id:
                        fl_logging.debug(
                            "[Bridge] received data iter_id: %d, "
                            "name: %s, ignored by our commit."
                            "(current_iter_id: %s, next_iter_id: %d)",
                            request.data.iter_id, request.data.name,
                            self._current_iter_id, self._next_iter_id)
                    else:
                        fl_logging.debug(
                            "[Bridge] received data iter_id: %d, "
                            "name: %s", request.data.iter_id,
                            request.data.name)
                        self._received_data[request.data.iter_id][
                            request.data.name] = request.data
                        self._condition.notify_all()

            elif request.HasField("commit"):
                if self._peer_commit_iter_id is not None \
                    and request.commit.iter_id <= self._peer_commit_iter_id:
                    fl_logging.warning(
                        "[Bridge] receive repeated peer commit iter_id: %d. "
                        "maybe caused by resend.(peer_commit_iter_id: %d)",
                        request.commit.iter_id, self._peer_commit_iter_id)
                elif self._peer_start_iter_id is None:
                    fl_logging.error(
                        "[Bridge] receive peer commit iter_id: %d "
                        "without start", request.commit.iter_id)
                    # return error?
                elif request.commit.iter_id != self._peer_start_iter_id:
                    fl_logging.error(
                        "[Bridge] received peer commit iter_id: %d that "
                        "does not match start.(peer_start_iter_id: %d)",
                        request.commit.iter_id, self._peer_start_iter_id)
                    # return error?
                else:
                    fl_logging.debug(
                        "[Bridge] receive peer commit iter_id: %d",
                        request.commit.iter_id)
                    self._peer_start_iter_id = None
                    self._peer_commit_iter_id = request.commit.iter_id
                    self._condition.notify_all()

        return tws2_pb.TransmitResponse(status=common_pb.Status(
            code=common_pb.STATUS_SUCCESS))
Example no. 14
    def _state_fn(self):
        fl_logging.debug("[Channel] thread _state_fn start")

        self._server.start()
        #self._channel.subscribe(self._channel_callback)

        self._lock.acquire()
        while True:
            now = time.time()
            saved_state = self._state
            wait_timeout = 10

            self._stats_client.gauge("channel.status", self._state.value)
            if self._state in (Channel.State.DONE, Channel.State.ERROR):
                break

            # check disconnected
            if self._state not in Channel._CONNECTING_STATES:
                if now >= self._heartbeat_timeout_at:
                    self._emit_error(
                        ChannelError(
                            "disconnected by heartbeat timeout: {}s".format(
                                self._heartbeat_timeout)))
                    continue
                wait_timeout = min(wait_timeout,
                                   self._heartbeat_timeout_at - now)

            # check peer disconnected
            if self._state not in Channel._PEER_UNCONNECTED_STATES:
                if now >= self._peer_heartbeat_timeout_at:
                    self._emit_error(
                        ChannelError(
                            "peer disconnected by heartbeat timeout: {}s".
                            format(self._heartbeat_timeout)))
                    continue
                wait_timeout = min(wait_timeout,
                                   self._peer_heartbeat_timeout_at - now)

            if now >= self._next_retry_at:
                self._next_retry_at = 0
                if self._state in Channel._CONNECTING_STATES:
                    if self._call_locked(channel_pb2.CallType.CONNECT):
                        self._emit_event(Channel.Event.CONNECTED)
                        self._refresh_heartbeat_timeout()
                    else:
                        self._next_retry_at = (
                            time.time() + self._retry_interval)
                    continue
                if self._state in Channel._CLOSING_STATES:
                    if self._call_locked(channel_pb2.CallType.CLOSE):
                        self._emit_event(Channel.Event.CLOSED)
                        self._refresh_heartbeat_timeout()
                    else:
                        self._next_retry_at = 0  # fast retry
                    continue
                if now >= self._next_heartbeat_at:
                    if self._call_locked(channel_pb2.CallType.HEARTBEAT):
                        fl_logging.debug("[Channel] call heartbeat OK")
                        self._refresh_heartbeat_timeout()
                    else:
                        fl_logging.warning("[Channel] call heartbeat failed")
                        interval = min(self._heartbeat_interval,
                                       self._retry_interval)
                        self._next_retry_at = time.time() + interval
                    continue

                wait_timeout = min(wait_timeout, self._next_heartbeat_at - now)
            else:
                wait_timeout = min(wait_timeout, self._next_retry_at - now)

            if saved_state != self._state:
                continue

            self._condition.wait(wait_timeout)

        # done
        self._lock.release()

        self._channel.close()
        time_wait = (2 * self._retry_interval + self._peer_closed_at -
                     time.time())
        if time_wait > 0:
            fl_logging.info(
                "[Channel] wait %0.2f sec to reduce the chance "
                "that the peer's close fails", time_wait)
            time.sleep(time_wait)

        self._server.stop(10)
        self._server.wait_for_termination()

        self._ready_event.set()
        self._closed_event.set()

        fl_logging.debug("[Channel] thread _state_fn stop")