Example #1
    def _check_token(self, token):
        if self._token != token:
            fl_logging.debug(
                "[Channel] peer unauthorized, got token: '%s', "
                "want: '%s'", token, self._token)
            return False
        return True
Example #2
    def evaluate(self, input_fn):

        with tf.Graph().as_default() as g, \
            g.device(self._cluster_server.device_setter):

            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, tf.estimator.ModeKeys.EVAL)
            spec, model = self._get_model_spec(features, labels,
                                               tf.estimator.ModeKeys.EVAL)

            # Track the average loss by default
            eval_metric_ops = spec.eval_metric_ops or {}
            if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
                loss_metric = tf.metrics.mean(spec.loss)
                eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

            # Create the real eval op
            update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
            update_ops.extend(model._train_ops)
            eval_op = tf.group(*update_ops)

            # Also track the global step
            if tf.GraphKeys.GLOBAL_STEP in eval_dict:
                raise ValueError(
                    'Metric with name `global_step` is not allowed, because '
                    'Estimator already defines a default metric with the '
                    'same name.')
            eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
                tf.train.get_or_create_global_step()

            # Prepare hooks
            all_hooks = []
            if spec.evaluation_hooks:
                all_hooks.extend(spec.evaluation_hooks)
            final_ops_hook = tf.train.FinalOpsHook(eval_dict)
            all_hooks.append(final_ops_hook)

            session_creator = tf.train.WorkerSessionCreator(
                master=self._cluster_server.target,
                config=self._cluster_server.cluster_config)
            # Evaluate over dataset
            self._bridge.connect()
            with tf.train.MonitoredSession(session_creator=session_creator,
                                           hooks=all_hooks) as sess:
                while not sess.should_stop():
                    start_time = time.time()
                    self._bridge.start()
                    sess.run(eval_op)
                    self._bridge.commit()
                    use_time = time.time() - start_time
                    fl_logging.debug("after session run. time: %f sec",
                                     use_time)
            self._bridge.terminate()

            # Log the evaluation metrics
            fl_logging.info('Metrics for evaluate: %s',
                            _dict_to_str(final_ops_hook.final_ops_values))

        return self
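
A usage note on the snippet above: evaluate() (and train() in example #10) consumes an Estimator-style input_fn. A minimal sketch of such a callable follows; the feature name "x", the shapes, and the random data are illustrative assumptions rather than part of the original code, and a real job would feed the data blocks fetched through the bridge instead.

import numpy as np
import tensorflow as tf

def toy_input_fn():
    # Hypothetical in-memory data; only the (features, labels) structure matters.
    features = {"x": np.random.rand(128, 4).astype(np.float32)}
    labels = np.random.randint(0, 2, size=(128, 1)).astype(np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.batch(32)

# e.g. estimator.evaluate(toy_input_fn) or estimator.train(toy_input_fn)
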
Example #3
    def load_data_block(self, count, block_id):
        req = tws2_pb.LoadDataBlockRequest(count=count, block_id=block_id)
        fl_logging.debug("[Bridge] sending DataBlock with id %s", block_id)
        resp = self._client.LoadDataBlock(req)
        if resp.status.code == common_pb.STATUS_SUCCESS:
            fl_logging.info("[Bridge] remote succeeded to load data block %s",
                            block_id)
            return True
        fl_logging.warning(
            "[Bridge] remote failed to load data block %s. "
            "code: %d", block_id, resp.status.code)
        return False
Example #4
    def start(self):
        with self._condition:
            self._assert_iter_committed()

            self._current_iter_id = self._next_iter_id
            self._next_iter_id += 1
            self._iter_started_at = time.time()
            self._transmit(
                tws2_pb.TransmitRequest(
                    start=tws2_pb.TransmitRequest.StartMessage(
                        iter_id=self._current_iter_id)))
            fl_logging.debug("[Bridge] send start iter_id: %d",
                             self._current_iter_id)
Example #5
    def _send(self, name, tensor=None, any_data=None):
        with self._condition:
            self._assert_iter_started()

            self._transmit(
                tws2_pb.TransmitRequest(
                    data=tws2_pb.TransmitRequest.DataMessage(
                        iter_id=self._current_iter_id,
                        name=name,
                        tensor=tensor,
                        any_data=any_data,
                    )))
            fl_logging.debug("[Bridge] Data: send iter_id: %d, name: %s",
                             self._current_iter_id, name)
Example #6
    def _receive(self, name):
        start_time = time.time()
        alert_count = 0
        with self._condition:
            self._assert_iter_started()
            iter_id = self._current_iter_id
            while (iter_id not in self._received_data
                   or name not in self._received_data[iter_id]):
                self._assert_iter_started()
                if iter_id != self._current_iter_id:
                    raise RuntimeError(
                        "[Bridge] iter change while waiting receive data, "
                        "iter_id: {}, name: {}".format(iter_id, name))
                if self._peer_commit_iter_id is not None \
                    and iter_id <= self._peer_commit_iter_id:
                    raise RuntimeError(
                        "[Bridge] peer committed without sending data "
                        "iter_id: {}, name: {}".format(iter_id, name))
                if self._peer_terminated:
                    raise RuntimeError(
                        "[Bridge] peer terminated without sending data "
                        "iter_id: {}, name: {}".format(iter_id, name))
                duration = time.time() - start_time
                if duration >= (alert_count + 1) * self._waiting_alert_timeout:
                    alert_count += 1
                    fl_logging.warning(
                        "[Bridge] Data: waiting to receive "
                        "iter_id: %d, name: %s timeout. duration: %f sec",
                        iter_id, name, duration)
                wait_timeout = self._waiting_alert_timeout - \
                    (duration % self._waiting_alert_timeout)
                self._condition.wait(wait_timeout)
            data = self._received_data[iter_id][name]

        duration = time.time() - start_time
        _gctx.stats_client.timing("trainer.bridge.receive_timing",
                                  duration * 1000,
                                  {"bridge_receive_name": name})
        fl_logging.debug(
            "[Bridge] Data: received iter_id: %d, name: %s "
            "after %f sec", iter_id, name, duration)
        return data
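
The waiting loop above follows a periodic-alert pattern: it sleeps on the condition variable in slices of _waiting_alert_timeout and logs a warning each time a full slice passes without the requested data arriving. A standalone sketch of the same timing arithmetic, using illustrative names that are not part of the original class:

import threading
import time

def wait_with_alerts(cond, predicate, alert_timeout=60.0):
    # Wait on `cond` until predicate() is true, warning every alert_timeout sec.
    start = time.time()
    alert_count = 0
    with cond:
        while not predicate():
            duration = time.time() - start
            if duration >= (alert_count + 1) * alert_timeout:
                alert_count += 1
                print("still waiting after %.1f sec" % duration)
            # Sleep only until the next alert boundary so warnings stay periodic.
            cond.wait(alert_timeout - (duration % alert_timeout))
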
Example #7
    def commit(self):
        with self._condition:
            self._assert_iter_started()

            self._transmit(
                tws2_pb.TransmitRequest(
                    commit=tws2_pb.TransmitRequest.CommitMessage(
                        iter_id=self._current_iter_id)))
            fl_logging.debug("[Bridge] send commit iter_id: %d",
                             self._current_iter_id)
            # delete committed data
            if self._current_iter_id in self._received_data:
                del self._received_data[self._current_iter_id]
            iter_id = self._current_iter_id
            duration = (time.time() - self._iter_started_at) * 1000
            self._current_iter_id = None

        with _gctx.stats_client.pipeline() as pipe:
            pipe.gauge("trainer.bridge.iterator_step", iter_id)
            pipe.timing("trainer.bridge.iterator_timing", duration)
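
Read together, examples #4-#7 imply a per-iteration protocol on the bridge: start() opens an iteration, _send() and _receive() exchange named payloads for that iter_id, and commit() closes the iteration and reports its timing. A minimal driver sketch under that assumption; the payload names and the helper itself are hypothetical:

def run_one_iteration(bridge, tensor_pb):
    # Illustrative only: one start/send/receive/commit cycle using the
    # methods shown above. "act1_f" and "act1_f_grad" are made-up names.
    bridge.start()
    bridge._send("act1_f", tensor=tensor_pb)
    peer_grad = bridge._receive("act1_f_grad")
    bridge.commit()
    return peer_grad
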
Example #8
    def request_data_block(self, block_id):
        request = tm_pb.DataBlockRequest(worker_rank=self._worker_rank,
                                         block_id=block_id)
        response = _grpc_with_retry(
            lambda: self._client.RequestDataBlock(request))
        if response.status.code == common_pb.StatusCode.STATUS_SUCCESS:
            fl_logging.debug(
                "succeeded to get datablock, id: %s, data_path: %s",
                response.block_id, response.data_path)
            return response
        if response.status.code == \
                common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK:
            fl_logging.error("invalid data block id: %s", request.block_id)
            return None
        if response.status.code == common_pb.StatusCode.STATUS_DATA_FINISHED:
            fl_logging.info("data block finished")
            return None
        raise RuntimeError("RequestDataBlock error, code: %s, msg: %s" % (
            common_pb.StatusCode.Name(response.status.code),
            response.status.error_message))
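
The helper _grpc_with_retry is referenced but not shown. A plausible sketch of such a wrapper, retrying on grpc.RpcError with a fixed sleep, is given below; this is an assumption about its behavior, not the actual implementation:

import logging
import time

import grpc

def _grpc_with_retry(call, interval=2):
    # Hypothetical retry wrapper: re-issue the gRPC call until it succeeds,
    # sleeping `interval` seconds between failures.
    while True:
        try:
            return call()
        except grpc.RpcError as e:
            logging.warning("grpc call failed, retrying: %s", e)
            time.sleep(interval)
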
Example #9
 def request_iterator():
     fl_logging.debug("[Bridge] stream transmitting")
     while True:
         with self._stream_condition:
             if len(self._stream_queue) == 0:
                 start = time.time()
                 while len(self._stream_queue) == 0:
                     duration = time.time() - start
                     if self._stream_terminated:
                         return
                     if duration >= IDLE_TIMEOUT:
                         fl_logging.debug(
                             "[Bridge] stream transmit "
                             "closed by idle timeout: %f sec",
                             IDLE_TIMEOUT)
                         return
                     self._stream_condition.wait(IDLE_TIMEOUT -
                                                 duration)
             msg = self._stream_queue.popleft()
             self._stream_condition.notify_all()
         yield msg
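
The generator above drains a deque guarded by a condition variable and shuts itself down after IDLE_TIMEOUT seconds without traffic. A self-contained sketch of the same pattern; the module-level queue, condition, and timeout value are illustrative stand-ins for the bridge's internals:

import collections
import threading
import time

IDLE_TIMEOUT = 30
queue = collections.deque()
cond = threading.Condition()

def drain_with_idle_timeout():
    # Yield queued items; stop once the queue stays empty for IDLE_TIMEOUT sec.
    while True:
        with cond:
            start = time.time()
            while not queue:
                remaining = IDLE_TIMEOUT - (time.time() - start)
                if remaining <= 0:
                    return
                cond.wait(remaining)
            item = queue.popleft()
            cond.notify_all()
        yield item
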
Example #10
    def train(self, input_fn):

        with tf.Graph().as_default() as g, \
            g.device(self._cluster_server.device_setter):

            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, tf.estimator.ModeKeys.TRAIN)
            spec, _ = self._get_model_spec(
                features, labels, tf.estimator.ModeKeys.TRAIN)

            hooks = []
            # stats
            hooks.append(TraceStatsHook(
                every_secs=30, stats_client=_gctx.stats_client))
            # user-defined chief hooks
            if spec.training_chief_hooks and self._is_chief:
                hooks.extend(spec.training_chief_hooks)

            if spec.training_hooks:
                hooks.extend(spec.training_hooks)

            session_creator = tf.train.WorkerSessionCreator(
                master=self._cluster_server.target,
                config=self._cluster_server.cluster_config)

            self._bridge.connect()
            with tf.train.MonitoredSession(
                session_creator=session_creator, hooks=hooks) as sess:
                while not sess.should_stop():
                    start_time = time.time()
                    self._bridge.start()
                    sess.run(spec.train_op, feed_dict={})
                    self._bridge.commit()
                    use_time = time.time() - start_time
                    fl_logging.debug("after session run. time: %f sec",
                                     use_time)
            self._bridge.terminate()

        return self
Example #11
    def _stream_transmit_fn(self):
        IDLE_TIMEOUT = 30
        fl_logging.debug("[Bridge] stream transmit started")

        def request_iterator():
            fl_logging.debug("[Bridge] stream transmitting")
            while True:
                with self._stream_condition:
                    if len(self._stream_queue) == 0:
                        start = time.time()
                        while len(self._stream_queue) == 0:
                            duration = time.time() - start
                            if self._stream_terminated:
                                return
                            if duration >= IDLE_TIMEOUT:
                                fl_logging.debug(
                                    "[Bridge] stream transmit "
                                    "closed by idle timeout: %f sec",
                                    IDLE_TIMEOUT)
                                return
                            self._stream_condition.wait(IDLE_TIMEOUT -
                                                        duration)
                    msg = self._stream_queue.popleft()
                    self._stream_condition.notify_all()
                yield msg

        while True:
            with self._stream_condition:
                while len(self._stream_queue) == 0:
                    if self._stream_terminated:
                        fl_logging.debug("[Bridge] stream transmit closed")
                        return
                    self._stream_condition.wait()
            response_iterator = \
                self._client.Transmit(request_iterator())
            for _ in response_iterator:
                pass
Example #12
    def _transmit_handler(self, request):
        with self._condition:
            if request.HasField("start"):
                if self._peer_commit_iter_id is not None \
                    and request.start.iter_id <= self._peer_commit_iter_id:
                    fl_logging.warning(
                        "[Bridge] received peer start iter_id: %d "
                        "which has been committed. "
                        "maybe caused by resend.(peer_commit_iter_id: %d)",
                        request.start.iter_id, self._peer_commit_iter_id)
                elif self._peer_start_iter_id is not None:
                    fl_logging.warning(
                        "[Bridge] received repeated peer start iter_id: %d. "
                        "maybe caused by resend.(peer_start_iter_id: %d)",
                        request.start.iter_id, self._peer_start_iter_id)
                else:
                    fl_logging.debug(
                        "[Bridge] received peer start iter_id: %d",
                        request.start.iter_id)
                    self._peer_start_iter_id = request.start.iter_id
                    self._condition.notify_all()

            elif request.HasField("data"):
                if self._peer_start_iter_id is None:
                    fl_logging.warning(
                        "[Bridge] received data iter_id: %d without start. "
                        "maybe caused by resend.", request.data.iter_id)
                elif self._peer_start_iter_id != request.data.iter_id:
                    fl_logging.warning(
                        "[Bridge] received data iter_id: %d does not match start. "
                        "maybe caused by resend.(peer_start_iter_id: %d)",
                        request.data.iter_id, self._peer_start_iter_id)
                else:
                    iter_id = self._current_iter_id \
                        if self._current_iter_id is not None \
                            else self._next_iter_id
                    if request.data.iter_id < iter_id:
                        fl_logging.debug(
                            "[Bridge] received data iter_id: %d, "
                            "name: %s, ignored by our commit."
                            "(current_iter_id: %s, next_iter_id: %d)",
                            request.data.iter_id, request.data.name,
                            self._current_iter_id, self._next_iter_id)
                    else:
                        fl_logging.debug(
                            "[Bridge] received data iter_id: %d, "
                            "name: %s", request.data.iter_id,
                            request.data.name)
                        self._received_data[request.data.iter_id][
                            request.data.name] = request.data
                        self._condition.notify_all()

            elif request.HasField("commit"):
                if self._peer_commit_iter_id is not None \
                    and request.commit.iter_id <= self._peer_commit_iter_id:
                    fl_logging.warning(
                        "[Bridge] received repeated peer commit iter_id: %d. "
                        "maybe caused by resend.(peer_commit_iter_id: %d)",
                        request.commit.iter_id, self._peer_commit_iter_id)
                elif self._peer_start_iter_id is None:
                    fl_logging.error(
                        "[Bridge] received peer commit iter_id: %d "
                        "without start", request.commit.iter_id)
                    # return error?
                elif request.commit.iter_id != self._peer_start_iter_id:
                    fl_logging.error(
                        "[Bridge] received peer commit iter_id: %d "
                        "does not match start.(peer_start_iter_id: %d)",
                        request.commit.iter_id, self._peer_start_iter_id)
                    # return error?
                else:
                    fl_logging.debug(
                        "[Bridge] received peer commit iter_id: %d",
                        request.commit.iter_id)
                    self._peer_start_iter_id = None
                    self._peer_commit_iter_id = request.commit.iter_id
                    self._condition.notify_all()

        return tws2_pb.TransmitResponse(status=common_pb.Status(
            code=common_pb.STATUS_SUCCESS))
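
_transmit_handler stores incoming payloads with self._received_data[iter_id][name] = ..., which only works if the outer mapping creates the inner dict on demand. The constructor is not shown; a sketch of the assumed buffer shape:

import collections

class _ReceiveBufferSketch:
    # Assumed shape of the receive buffer; the real class and its __init__
    # are not part of the snippet above.
    def __init__(self):
        # iter_id -> {name -> DataMessage}; defaultdict lets the handler
        # assign received[iter_id][name] without pre-creating the inner dict.
        self._received_data = collections.defaultdict(dict)
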
Example #13
    def _state_fn(self):
        fl_logging.debug("[Channel] thread _state_fn start")

        self._server.start()
        #self._channel.subscribe(self._channel_callback)

        self._lock.acquire()
        while True:
            now = time.time()
            saved_state = self._state
            wait_timeout = 10

            self._stats_client.gauge("channel.status", self._state.value)
            if self._state in (Channel.State.DONE, Channel.State.ERROR):
                break

            # check disconnected
            if self._state not in Channel._CONNECTING_STATES:
                if now >= self._heartbeat_timeout_at:
                    self._emit_error(
                        ChannelError(
                            "disconnected by heartbeat timeout: {}s".format(
                                self._heartbeat_timeout)))
                    continue
                wait_timeout = min(wait_timeout,
                                   self._heartbeat_timeout_at - now)

            # check peer disconnected
            if self._state not in Channel._PEER_UNCONNECTED_STATES:
                if now >= self._peer_heartbeat_timeout_at:
                    self._emit_error(
                        ChannelError(
                            "peer disconnected by heartbeat timeout: {}s".
                            format(self._heartbeat_timeout)))
                    continue
                wait_timeout = min(wait_timeout,
                                   self._peer_heartbeat_timeout_at - now)

            if now >= self._next_retry_at:
                self._next_retry_at = 0
                if self._state in Channel._CONNECTING_STATES:
                    if self._call_locked(channel_pb2.CallType.CONNECT):
                        self._emit_event(Channel.Event.CONNECTED)
                        self._refresh_heartbeat_timeout()
                    else:
                        self._next_retry_at = \
                            time.time() + self._retry_interval
                    continue
                if self._state in Channel._CLOSING_STATES:
                    if self._call_locked(channel_pb2.CallType.CLOSE):
                        self._emit_event(Channel.Event.CLOSED)
                        self._refresh_heartbeat_timeout()
                    else:
                        self._next_retry_at = 0  # fast retry
                    continue
                if now >= self._next_heartbeat_at:
                    if self._call_locked(channel_pb2.CallType.HEARTBEAT):
                        fl_logging.debug("[Channel] call heartbeat OK")
                        self._refresh_heartbeat_timeout()
                    else:
                        fl_logging.warning("[Channel] call heartbeat failed")
                        interval = min(self._heartbeat_interval,
                                       self._retry_interval)
                        self._next_retry_at = time.time() + interval
                    continue

                wait_timeout = min(wait_timeout, self._next_heartbeat_at - now)
            else:
                wait_timeout = min(wait_timeout, self._next_retry_at - now)

            if saved_state != self._state:
                continue

            self._condition.wait(wait_timeout)

        # done
        self._lock.release()

        self._channel.close()
        time_wait = 2 * self._retry_interval + \
            self._peer_closed_at - time.time()
        if time_wait > 0:
            fl_logging.info(
                "[Channel] wait %0.2f sec "
                "to reduce peer close failures", time_wait)
            time.sleep(time_wait)

        self._server.stop(10)
        self._server.wait_for_termination()

        self._ready_event.set()
        self._closed_event.set()

        fl_logging.debug("[Channel] thread _state_fn stop")
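
_state_fn drives the channel off three deadlines (_heartbeat_timeout_at, _peer_heartbeat_timeout_at, _next_heartbeat_at) and waits only until the nearest one. The refresh logic itself is not shown; a hypothetical sketch of how those deadlines might be advanced after a successful call, with assumed interval values:

import time

class _HeartbeatClockSketch:
    # Hypothetical bookkeeping; field names mirror the example, but the
    # class and the default intervals are assumptions.
    def __init__(self, heartbeat_interval=10, heartbeat_timeout=120):
        self._heartbeat_interval = heartbeat_interval
        self._heartbeat_timeout = heartbeat_timeout
        self._next_heartbeat_at = 0.0
        self._heartbeat_timeout_at = 0.0

    def _refresh_heartbeat_timeout(self):
        now = time.time()
        # Schedule the next outgoing heartbeat and push back the deadline
        # after which the channel counts as disconnected.
        self._next_heartbeat_at = now + self._heartbeat_interval
        self._heartbeat_timeout_at = now + self._heartbeat_timeout
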