예제 #1
0
    def _do_run(self, target_list, fetch_list, feed_dict):
        """Runs a step based on the given fetches and feeds.

        Before running, if ``self._graph.version`` is greater than
        ``self._current_version``, the newer part of the graph is serialized
        and passed to ``TF_ExtendGraph`` while holding ``self._extend_lock``.

        Args:
          target_list: A list of byte arrays corresponding to names of tensors
            or operations to be run to, but not fetched.
          fetch_list: A list of byte arrays corresponding to names of tensors to
            be fetched and operations to be run.
          feed_dict: A dictionary that maps tensor names (as byte arrays) to
            numpy ndarrays.

        Returns:
          A list of numpy ndarrays, corresponding to the elements of
          `fetch_list`.  If the ith element of `fetch_list` contains the
          name of an operation, the first Tensor output of that operation
          will be returned for that element.

        Raises:
          RuntimeError: If ``TF_ExtendGraph`` leaves a non-zero status code.
          The exception built by ``errors._make_specific_exception`` when a
          ``tf_session.StatusNotOK`` message matches
          ``BaseSession._NODEDEF_NAME_RE``; otherwise the original
          ``StatusNotOK`` is re-raised with its traceback.
        """
        try:
            # Ensure any changes to the graph are reflected in the runtime.
            with self._extend_lock:
                if self._graph.version > self._current_version:
                    graph_def = self._graph.as_graph_def(
                        from_version=self._current_version)

                    # Allocate the status before entering the try block so the
                    # finally clause can never reference an unbound name (and
                    # mask the real error) if TF_NewStatus itself raises.
                    status = tf_session.TF_NewStatus()
                    try:
                        tf_session.TF_ExtendGraph(
                            self._session, graph_def.SerializeToString(),
                            status)
                        if tf_session.TF_GetCode(status) != 0:
                            raise RuntimeError(
                                compat.as_text(tf_session.TF_Message(status)))
                        self._opened = True
                    finally:
                        # Release the status on both success and failure.
                        tf_session.TF_DeleteStatus(status)

                    self._current_version = self._graph.version

            return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                                     target_list)

        except tf_session.StatusNotOK as e:
            # Capture the original exception up front so it can be re-raised
            # with its own traceback if no node name can be extracted.
            e_type, e_value, e_traceback = sys.exc_info()
            error_message = compat.as_text(e.error_message)
            m = BaseSession._NODEDEF_NAME_RE.search(error_message)
            if m is not None:
                node_name = m.group(1)
                node_def = None
                try:
                    op = self._graph.get_operation_by_name(node_name)
                    node_def = op.node_def
                except KeyError:
                    # Lookup failed; build the exception without op/node_def.
                    op = None
                # pylint: disable=protected-access
                raise errors._make_specific_exception(node_def, op,
                                                      error_message, e.code)
                # pylint: enable=protected-access
            six.reraise(e_type, e_value, e_traceback)
예제 #2
0
    def _extend_graph(self):
        """Pushes graph nodes added since the last extension to the runtime.

        Holding ``self._extend_lock``, returns immediately unless
        ``self._graph.version`` is greater than ``self._current_version``.
        Otherwise the newer part of the graph is serialized and handed to
        ``TF_ExtendGraph``. ``self._opened`` is then set and
        ``self._current_version`` is advanced to the graph's current version.
        """
        with self._extend_lock:
            # The runtime already has everything up to the current version.
            if not self._graph.version > self._current_version:
                return

            delta = self._graph.as_graph_def(
                from_version=self._current_version)

            # The context manager presumably raises on a non-OK status (as its
            # name suggests); the lines after it run only on a clean exit.
            with errors.raise_exception_on_not_ok_status() as status:
                payload = delta.SerializeToString()
                tf_session.TF_ExtendGraph(self._session, payload, status)

            self._opened = True
            self._current_version = self._graph.version
예제 #3
0
  def _extend_graph(self):
    """Sends graph nodes the runtime has not seen yet to the session.

    Holding ``self._extend_lock``, returns immediately unless
    ``self._graph.version`` is greater than ``self._current_version``.
    Otherwise ``self._graph._as_graph_def`` yields both the serialized
    delta's GraphDef and the version it covers. That version is stored in
    ``self._current_version`` before the delta is passed to
    ``TF_ExtendGraph``, after which ``self._opened`` is set.
    """
    with self._extend_lock:
      # Nothing newer than what the runtime already holds.
      if not self._graph.version > self._current_version:
        return

      # pylint: disable=protected-access
      delta, covered_version = self._graph._as_graph_def(
          from_version=self._current_version,
          add_shapes=self._add_shapes)
      # pylint: enable=protected-access
      # NOTE(review): the version advances before TF_ExtendGraph is known to
      # succeed, so a failed extension is not retried on the next call.
      self._current_version = covered_version

      with errors.raise_exception_on_not_ok_status() as status:
        tf_session.TF_ExtendGraph(
            self._session, delta.SerializeToString(), status)
      self._opened = True
예제 #4
0
  def _extend_graph(self):
    """Pushes graph nodes added since the last extension to the runtime.

    Holding ``self._extend_lock``, does nothing unless ``self._graph.version``
    is greater than ``self._current_version``. Otherwise the newer part of
    the graph is serialized and passed to ``TF_ExtendGraph``. On success,
    ``self._opened`` is set and ``self._current_version`` is advanced.

    Raises:
      RuntimeError: If ``TF_ExtendGraph`` leaves a non-zero status code; the
        message is the runtime's status message.
    """
    # Ensure any changes to the graph are reflected in the runtime.
    with self._extend_lock:
      if self._graph.version > self._current_version:
        graph_def = self._graph.as_graph_def(
            from_version=self._current_version)

        # Create the status outside the try block: if TF_NewStatus raised
        # inside it, the finally clause would reference an unbound `status`
        # and mask the real error with an UnboundLocalError.
        status = tf_session.TF_NewStatus()
        try:
          tf_session.TF_ExtendGraph(
              self._session, graph_def.SerializeToString(), status)
          if tf_session.TF_GetCode(status) != 0:
            raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
          self._opened = True
        finally:
          # Release the status on both success and failure.
          tf_session.TF_DeleteStatus(status)

        self._current_version = self._graph.version