Example #1
def _create_client():
    """ Enable stats from env """
    global _enabled  #pylint: disable=global-statement

    enabled_str = os.getenv("FL_STATS_ENABLED")
    url = os.getenv("FL_STATS_URL")
    try:
        _enabled = strtobool(enabled_str) if enabled_str else bool(url)
    except ValueError:
        fl_logging.warning(
            "[Stats] env FL_STATS_ENABLED=%s "
            "is invalid truth value. Disable stats ", enabled_str)
        _enabled = False

    if not _enabled:
        fl_logging.info("[Stats] stats not enabled")
        return NoneClient()

    if not url:
        fl_logging.warning(
            "[Stats] FL_STATS_URL not found, redirect to stderr")
        url = "stderr://"

    try:
        return Client(url)
    except Exception as e:  #pylint: disable=broad-except
        fl_logging.error("[Stats] new client error: %s, redirect to stderr", e)
        return Client("stderr://")
Example #2
    def WorkerRegister(self, request, context):
        with self._lock:
            # for compatibility; for more information, see:
            #   protocal/fedlearner/common/trainer_master_service.proto
            if self._worker0_cluster_def is None and request.worker_rank == 0:
                self._worker0_cluster_def = request.cluster_def

            if self._status in (tm_pb.MasterStatus.WORKER_COMPLETED,
                                tm_pb.MasterStatus.COMPLETED):
                return tm_pb.WorkerRegisterResponse(status=common_pb.Status(
                    code=common_pb.StatusCode.STATUS_DATA_FINISHED))

            if self._status != tm_pb.MasterStatus.RUNNING:
                return tm_pb.WorkerRegisterResponse(
                    status=common_pb.Status(
                        code=common_pb.StatusCode. \
                            STATUS_WAIT_FOR_SYNCING_CHECKPOINT
                    ))

            if request.worker_rank in self._running_workers:
                fl_logging.warning("worker_%d:%s repeat registration",
                                   request.worker_rank, request.hostname)
            else:
                fl_logging.info("worker_%d:%s registration",
                                request.worker_rank, request.hostname)

            self._running_workers.add(request.worker_rank)
            if request.worker_rank in self._completed_workers:
                self._completed_workers.remove(request.worker_rank)
            return tm_pb.WorkerRegisterResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS))
Example #3
    def evaluate(self, input_fn):

        with tf.Graph().as_default() as g, \
            g.device(self._cluster_server.device_setter):

            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, tf.estimator.ModeKeys.EVAL)
            spec, model = self._get_model_spec(features, labels,
                                               tf.estimator.ModeKeys.EVAL)

            # Track the average loss by default
            eval_metric_ops = spec.eval_metric_ops or {}
            if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
                loss_metric = tf.metrics.mean(spec.loss)
                eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

            # Create the real eval op
            update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
            update_ops.extend(model._train_ops)
            eval_op = tf.group(*update_ops)

            # Also track the global step
            if tf.GraphKeys.GLOBAL_STEP in eval_dict:
                raise ValueError(
                    'Metric with name `global_step` is not allowed, because '
                    'Estimator already defines a default metric with the '
                    'same name.')
            eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
                tf.train.get_or_create_global_step()

            # Prepare hooks
            all_hooks = []
            if spec.evaluation_hooks:
                all_hooks.extend(spec.evaluation_hooks)
            final_ops_hook = tf.train.FinalOpsHook(eval_dict)
            all_hooks.append(final_ops_hook)

            session_creator = tf.train.WorkerSessionCreator(
                master=self._cluster_server.target,
                config=self._cluster_server.cluster_config)
            # Evaluate over dataset
            self._bridge.connect()
            with tf.train.MonitoredSession(session_creator=session_creator,
                                           hooks=all_hooks) as sess:
                while not sess.should_stop():
                    start_time = time.time()
                    self._bridge.start()
                    sess.run(eval_op)
                    self._bridge.commit()
                    use_time = time.time() - start_time
                    fl_logging.debug("after session run. time: %f sec",
                                     use_time)
            self._bridge.terminate()

            # Print result
            fl_logging.info('Metrics for evaluate: %s',
                            _dict_to_str(final_ops_hook.final_ops_values))

        return self
Example #4
 def before_checkpoint_save(self, session, global_step_value):
     data = self._visitor.dump()
     fl_logging.info(
         "DataVisitor save checkpoint for global step %d, "
         "size: %d", global_step_value, len(data))
     session.run(
         self._save_op,
         {self._ckpt_plhd: data},
     )
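The hook assumes that _ckpt_plhd and _save_op were built elsewhere. A minimal sketch of one plausible construction (an assumption for illustration, not the project's actual graph code): the serialized visitor state is assigned into a string variable that gets persisted with the checkpoint.

import tensorflow.compat.v1 as tf

# hypothetical names; scalar string tensors so a single bytes blob can be stored
ckpt_plhd = tf.placeholder(tf.string, shape=[], name="data_visitor_ckpt_plhd")
ckpt_var = tf.Variable("", dtype=tf.string, trainable=False,
                       name="data_visitor_ckpt")
save_op = ckpt_var.assign(ckpt_plhd)
# inside the hook: session.run(save_op, {ckpt_plhd: data})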
Example #5
 def wait_master_complete(self):
     request = tm_pb.IsCompletedRequest()
     while True:
         response = _grpc_with_retry(
             lambda: self._client.IsCompleted(request))
         if response.completed:
             fl_logging.info("master completed")
             return
         fl_logging.info("waiting master complete...")
         time.sleep(2)
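_grpc_with_retry itself is not shown in these excerpts. A minimal sketch of what such a helper might look like, assuming it simply retries the call on transient gRPC errors (fl_logging is the logger used throughout these examples):

import time
import grpc

def _grpc_with_retry(sender, interval=2):
    # sender is a zero-argument callable wrapping one stub call
    while True:
        try:
            return sender()
        except grpc.RpcError as e:
            fl_logging.warning("grpc error: %s, retry in %d sec", e, interval)
            time.sleep(interval)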
Example #6
 def _run_grpc_server(self, address):
     self._grpc_server = grpc.server(
         futures.ThreadPoolExecutor(
             max_workers=8,
             thread_name_prefix="TrainerMasterServerThreadPoolExecutor"))
     tm_grpc.add_TrainerMasterServiceServicer_to_server(
         self, self._grpc_server)
     self._grpc_server.add_insecure_port(address)
     self._grpc_server.start()
     fl_logging.info('Trainer Master Server start on address: %s', address)
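For symmetry, a shutdown counterpart is sketched below (assumed, not part of these excerpts); Example #22 uses the same stop / wait_for_termination calls on its channel server.

def _stop_grpc_server(self):
    # allow in-flight RPCs a short grace period before shutting down
    self._grpc_server.stop(grace=5)
    self._grpc_server.wait_for_termination()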
Example #7
 def _run(self):
     fl_logging.info("create estimator")
     estimator = self._create_estimator()
     fl_logging.info("start session_run")
     self._session_run(estimator)
     fl_logging.info("session_run done")
     fl_logging.info("start export_model")
     self._export_model(estimator)
     fl_logging.info("export_model done")
     self._transfer_status(tm_pb.MasterStatus.WORKER_COMPLETED,
                           tm_pb.MasterStatus.COMPLETED)
Example #8
 def load_data_block(self, count, block_id):
     req = tws2_pb.LoadDataBlockRequest(count=count, block_id=block_id)
     fl_logging.debug("[Bridge] sending DataBlock with id %s", block_id)
     resp = self._client.LoadDataBlock(req)
     if resp.status.code == common_pb.STATUS_SUCCESS:
         fl_logging.info("[Bridge] remote succeeded to load data block %s",
                         block_id)
         return True
     fl_logging.warning(
         "[Bridge] remoted failed to load data block %s. "
         "code: %d", block_id, resp.status.code)
     return False
Example #9
 def _data_block_handler(self, request):
     assert self._data_block_handler_fn is not None, \
         "[Bridge] receive DataBlockMessage but no handler registered."
     if self._data_block_handler_fn(request):
         fl_logging.info("[Bridge] succeeded to load data block %s",
                         request.block_id)
         return tws2_pb.LoadDataBlockResponse(status=common_pb.Status(
             code=common_pb.STATUS_SUCCESS))
     fl_logging.info("[Bridge] failed to load data block %s",
                     request.block_id)
     return tws2_pb.LoadDataBlockResponse(status=common_pb.Status(
         code=common_pb.STATUS_INVALID_DATA_BLOCK))
Example #10
 def _transfer_status(self, frm, to):
     if self._status != frm:
         raise RuntimeError(
             "Trainer Master status transfer failed, "
             "want from %s to %s, but current status: %s"% \
                         (tm_pb.MasterStatus.Name(frm),
                          tm_pb.MasterStatus.Name(to),
                          tm_pb.MasterStatus.Name(self._status))
             )
     self._status = to
     fl_logging.info("Trainer Master status transfer, from %s to %s",
                     tm_pb.MasterStatus.Name(frm),
                     tm_pb.MasterStatus.Name(to))
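A small illustration of the guard above: repeating a transition fails because the master is no longer in the expected source state.

# illustrative only, mirroring the call in run_forever (Example #20)
self._transfer_status(tm_pb.MasterStatus.CREATED,
                      tm_pb.MasterStatus.INITIALING)
# a second identical call raises RuntimeError: self._status is already
# INITIALING, so the frm check at the top of _transfer_status fails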
Example #11
 def _supervise_fn(self):
     check_handlers = []
     if self._supervise_iteration_timeout > 0:
         fl_logging.info("enable supervise iteartion timeout: %f",
                         self._supervise_iteration_timeout)
         check_handlers.append(self._check_iteration_timeout)
     if len(check_handlers) == 0:
         return
     while True:
         with self._condition:
             if self._terminated:
                 return
         for handler in check_handlers:
             handler()
         time.sleep(self._supervise_interval)
Example #12
    def _data_block_handler(self, msg):
        if self._count > msg.count:
            fl_logging.warn('DataBlock: ignore repeated datablock "%s" at %d',
                            msg.block_id, msg.count)
            return True

        fl_logging.info('DataBlock: recv "%s" at %d', msg.block_id, msg.count)
        assert self._count == msg.count
        if not msg.block_id:
            block = None
        else:
            block = self._trainer_master.request_data_block(msg.block_id)
            if block is None:
                return False
        self._count += 1
        self._block_queue.put(block)
        return True
Example #13
    def _create_tf_server(self, cluster_spec):
        self._tf_config = tf.ConfigProto()
        self._tf_config.inter_op_parallelism_threads = 64
        self._tf_config.intra_op_parallelism_threads = 64
        self._tf_config.experimental \
            .share_session_state_in_clusterspec_propagation = True
        self._tf_config.rpc_options.compression_algorithm = "gzip"
        self._tf_config.rpc_options.cache_rpc_response = True
        self._tf_config.rpc_options.disable_session_connection_sharing = True

        try:
            address = cluster_spec.task_address(
                self._job_name, self._task_index)
            self._tf_server = \
                tf.distribute.Server({"server": {
                                        self._task_index: address}
                                     },
                                     protocol="grpc",
                                     config=self._tf_config)
            self._tf_target = "grpc://" + address
        except ValueError:
            self._tf_server = \
                tf.distribute.Server({"server":
                                        {self._task_index: "localhost:0"}
                                     },
                                     protocol="grpc",
                                     config=self._tf_config)
            self._tf_target = self._tf_server.target

        # modify cluster_spec
        cluster_dict = dict()
        cluster_dict[self._job_name] = {
            self._task_index: self._tf_target[len("grpc://"):]
        }
        for job_name in cluster_spec.jobs:
            if job_name == self._job_name:
                continue
            if job_name in self._extra_reserve_jobs:
                cluster_dict[job_name] = cluster_spec.job_tasks(job_name)

        self._tf_cluster_spec = tf.train.ClusterSpec(cluster_dict)
        self._tf_config.cluster_def.CopyFrom(
            self._tf_cluster_spec.as_cluster_def())

        fl_logging.info("cluster server target: %s\nconfig: \n%s",
                        self._tf_target, self._tf_config)
Example #14
 def _emit_event(self, event, error=None):
     with self._lock:
         next_state = self._next_state(self._state, event)
         if self._state != next_state:
             fl_logging.info(
                 "[Channel] state changed from %s to %s, "
                 "event: %s", self._state.name, next_state.name, event.name)
             self._state = next_state
             if self._state == Channel.State.ERROR:
                 assert error is not None
                 self._error_event.set()
                 fl_logging.error("[Channel] occur error: %s", str(error))
                 self._error = error
             elif self._state == Channel.State.READY:
                 self._ready_event.set()
             self._event_callback(event)
             self._condition.notify_all()
Example #15
 def _request_data_block(self, request):
     data_block = self._data_visitor.get_datablock_by_id(request.block_id)
     if data_block:
         fl_logging.info("allocated worker_%d with block: %s",
                         request.worker_rank, data_block.id)
         response = tm_pb.DataBlockResponse(
             status=common_pb.Status(
                 code=common_pb.StatusCode.STATUS_SUCCESS),
             block_id=data_block.id,
             data_path=data_block.data_path,
         )
     else:
         fl_logging.error("invalid data block id: %s", request.block_id)
         response = tm_pb.DataBlockResponse(status=common_pb.Status(
             code=common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK,
             error_message="invalid data block"))
     return response
Example #16
 def _trigger_fn(self, global_step):
     now = time.time()
     if self._last_global_step >= 0:
         speed = (global_step-self._last_global_step) \
             / (now-self._last_trigger_time)
         total_datablock = self._data_visitor.datablock_size
         fl_logging.info(
             "global_step: %d, speed: %0.2f step/sec, "
             "datablock size: %d, "
             "worker: %d/%d(running/completed)", global_step, speed,
             total_datablock, len(self._running_workers),
             len(self._completed_workers))
         with _gctx.stats_client.pipeline() as pipe:
             pipe.gauge("trainer.global_step", global_step)
             pipe.gauge("trainer.datablock_total", total_datablock)
             pipe.gauge("trainer.speed", speed)
     self._last_trigger_time = now
     self._last_global_step = global_step
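How _trigger_fn gets invoked is not shown in these excerpts; one plausible wiring (purely an assumption) is a session hook that fetches the global step and fires the trigger every fixed number of steps:

import tensorflow.compat.v1 as tf

class _TriggerHook(tf.train.SessionRunHook):
    """Hypothetical hook: call trigger_fn(global_step) every N steps."""
    def __init__(self, trigger_fn, every_n_steps=100):
        self._trigger_fn = trigger_fn
        self._every_n_steps = every_n_steps
        self._global_step_tensor = None

    def begin(self):
        self._global_step_tensor = tf.train.get_or_create_global_step()

    def before_run(self, run_context):
        return tf.train.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        global_step = run_values.results
        if global_step % self._every_n_steps == 0:
            self._trigger_fn(global_step)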
Example #17
    def WorkerComplete(self, request, context):
        with self._lock:
            if request.worker_rank not in self._running_workers:
                return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
                    code=common_pb.StatusCode.STATUS_INVALID_REQUEST,
                    error_message="unregistered worker"))
            fl_logging.info("worker_%d completed", request.worker_rank)
            self._completed_workers.add(request.worker_rank)
            if request.worker_rank == 0:
                self._worker0_terminated_at = request.timestamp

            if len(self._running_workers) == len(self._completed_workers) \
                and 0 in self._running_workers:
                # worker 0 completed and all data blocks have finished
                self._transfer_status(tm_pb.MasterStatus.RUNNING,
                                      tm_pb.MasterStatus.WORKER_COMPLETED)

            return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS))
Example #18
 def request_data_block(self, block_id):
     request = tm_pb.DataBlockRequest(worker_rank=self._worker_rank,
                                      block_id=block_id)
     response = _grpc_with_retry(
         lambda: self._client.RequestDataBlock(request))
     if response.status.code == common_pb.StatusCode.STATUS_SUCCESS:
         fl_logging.debug(
             "succeeded to get datablock, id:%s, data_path: %s",
             response.block_id, response.data_path)
         return response
     if response.status.code == \
         common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK:
         fl_logging.error("invalid data block id: %s", request.block_id)
         return None
     if response.status.code == common_pb.StatusCode.STATUS_DATA_FINISHED:
         fl_logging.info("data block finished")
         return None
     raise RuntimeError("RequestDataBlock error, code: %s, msg: %s"% \
         (common_pb.StatusCode.Name(response.status.code),
          response.status.error_message))
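An illustrative consumption loop (hypothetical: master_client and the block_id=None convention are assumptions, matching the sequential allocation path in Example #21): the worker keeps pulling blocks until the master reports STATUS_DATA_FINISHED, at which point request_data_block returns None.

while True:
    block = master_client.request_data_block(block_id=None)
    if block is None:
        break   # either data finished or an invalid block id
    fl_logging.info("processing block %s at %s",
                    block.block_id, block.data_path)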
Example #19
 def worker_register(self, cluster_def=None):
     request = tm_pb.WorkerRegisterRequest(worker_rank=self._worker_rank,
                                           hostname=os.uname().nodename,
                                           cluster_def=cluster_def)
     while True:
         response = _grpc_with_retry(
             lambda: self._client.WorkerRegister(request))
         if response.status.code == common_pb.StatusCode.STATUS_SUCCESS:
             return True
         if response.status.code == \
             common_pb.StatusCode.STATUS_WAIT_FOR_SYNCING_CHECKPOINT:
             fl_logging.info("waiting master ready...")
             time.sleep(1)
             continue
         if response.status.code == \
             common_pb.StatusCode.STATUS_DATA_FINISHED:
             fl_logging.info("master completed, ignore worker register")
             return False
         raise RuntimeError("WorkerRegister error, code: %s, msg: %s"% \
             (common_pb.StatusCode.Name(response.status.code),
             response.status.error_message))
Example #20
    def run_forever(self, listen_port=None):
        with self._lock:
            self._transfer_status(tm_pb.MasterStatus.CREATED,
                                  tm_pb.MasterStatus.INITIALING)

        if listen_port:
            self._run_grpc_server(listen_port)

        while self._cluster_server is None:
            # waiting to receive cluster_def from worker_0
            with self._lock:
                if self._worker0_cluster_def:
                    fl_logging.info("received worker_0 cluster_def: %s",
                                    self._worker0_cluster_def)
                    self._cluster_server = ClusterServer(
                        tf.train.ClusterSpec(self._worker0_cluster_def),
                        "master")
                    break
            fl_logging.info("still waiting receive cluster_def from worker_0")
            time.sleep(2)

        self._run()

        sig = signal.sigwait([signal.SIGHUP, signal.SIGINT, signal.SIGTERM])
        fl_logging.info("Server shutdown by signal: %s",
                        signal.Signals(sig).name)
Example #21
    def _request_data_block(self, request):
        try:
            data_block = next(self._data_visitor)
        except StopIteration:
            data_block = None

        response = tm_pb.DataBlockResponse()
        if data_block:
            fl_logging.info("allocated worker_%d with block: %s",
                            request.worker_rank, data_block.id)
            response = tm_pb.DataBlockResponse(
                status=common_pb.Status(
                    code=common_pb.StatusCode.STATUS_SUCCESS),
                block_id=data_block.id,
                data_path=data_block.data_path,
            )
        else:
            response = tm_pb.DataBlockResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_DATA_FINISHED,
                error_message="data block finished"))

        return response
Example #22
    def _state_fn(self):
        fl_logging.debug("[Channel] thread _state_fn start")

        self._server.start()
        #self._channel.subscribe(self._channel_callback)

        self._lock.acquire()
        while True:
            now = time.time()
            saved_state = self._state
            wait_timeout = 10

            self._stats_client.gauge("channel.status", self._state.value)
            if self._state in (Channel.State.DONE, Channel.State.ERROR):
                break

            # check disconnected
            if self._state not in Channel._CONNECTING_STATES:
                if now >= self._heartbeat_timeout_at:
                    self._emit_error(
                        ChannelError(
                            "disconnected by heartbeat timeout: {}s".format(
                                self._heartbeat_timeout)))
                    continue
                wait_timeout = min(wait_timeout,
                                   self._heartbeat_timeout_at - now)

            # check peer disconnected
            if self._state not in Channel._PEER_UNCONNECTED_STATES:
                if now >= self._peer_heartbeat_timeout_at:
                    self._emit_error(
                        ChannelError(
                            "peer disconnected by heartbeat timeout: {}s".
                            format(self._heartbeat_timeout)))
                    continue
                wait_timeout = min(wait_timeout,
                                   self._peer_heartbeat_timeout_at - now)

            if now >= self._next_retry_at:
                self._next_retry_at = 0
                if self._state in Channel._CONNECTING_STATES:
                    if self._call_locked(channel_pb2.CallType.CONNECT):
                        self._emit_event(Channel.Event.CONNECTED)
                        self._refresh_heartbeat_timeout()
                    else:
                        self._next_retry_at = time.time(
                        ) + self._retry_interval
                    continue
                if self._state in Channel._CLOSING_STATES:
                    if self._call_locked(channel_pb2.CallType.CLOSE):
                        self._emit_event(Channel.Event.CLOSED)
                        self._refresh_heartbeat_timeout()
                    else:
                        self._next_retry_at = 0  # fast retry
                    continue
                if now >= self._next_heartbeat_at:
                    if self._call_locked(channel_pb2.CallType.HEARTBEAT):
                        fl_logging.debug("[Channel] call heartbeat OK")
                        self._refresh_heartbeat_timeout()
                    else:
                        fl_logging.warning("[Channel] call heartbeat failed")
                        interval = min(self._heartbeat_interval,
                                       self._retry_interval)
                        self._next_retry_at = time.time() + interval
                    continue

                wait_timeout = min(wait_timeout, self._next_heartbeat_at - now)
            else:
                wait_timeout = min(wait_timeout, self._next_retry_at - now)

            if saved_state != self._state:
                continue

            self._condition.wait(wait_timeout)

        # done
        self._lock.release()

        self._channel.close()
        time_wait = (2 * self._retry_interval
                     + self._peer_closed_at - time.time())
        if time_wait > 0:
            fl_logging.info(
                "[Channel] wait %0.2f sec "
                "to reduce peer close failures", time_wait)
            time.sleep(time_wait)

        self._server.stop(10)
        self._server.wait_for_termination()

        self._ready_event.set()
        self._closed_event.set()

        fl_logging.debug("[Channel] thread _state_fn stop")