예제 #1
0
 def _run_fn(session, feed_dict, fetch_list, target_list, options,
             run_outputs):
     """Runs one step through ``tf_session.TF_Run`` after syncing the graph.

     ``self`` is not a parameter here; it is a free variable captured from
     the enclosing scope (presumably the owning session method -- not
     visible in this snippet).

     Args:
       session: Passed through as the first argument of ``TF_Run``.
       feed_dict: Forwarded unchanged to ``TF_Run``.
       fetch_list: Forwarded unchanged to ``TF_Run``.
       target_list: Forwarded unchanged to ``TF_Run``.
       options: Run options. When falsy, ``None`` is passed in its place.
       run_outputs: Forwarded only when ``options`` is truthy; otherwise
         ``None`` is passed instead.

     Returns:
       Whatever ``tf_session.TF_Run`` returns.
     """
     # Flush graph changes made since the last run into the runtime.
     self._extend_graph()
     # run_outputs travels together with options: both are replaced by
     # None when no options were supplied.
     if options:
         run_options, outputs = options, run_outputs
     else:
         run_options, outputs = None, None
     return tf_session.TF_Run(session, run_options, feed_dict, fetch_list,
                              target_list, outputs)
예제 #2
0
 def _run_fn(session, feed_dict, fetch_list, target_list, options,
             run_metadata):
     """Runs one step via ``tf_session.TF_Run`` inside a status guard.

     ``self`` is a free variable captured from the enclosing scope (not
     visible in this snippet).

     Args:
       session: Passed through as the first argument of ``TF_Run``.
       feed_dict: Forwarded unchanged to ``TF_Run``.
       fetch_list: Forwarded unchanged to ``TF_Run``.
       target_list: Forwarded unchanged to ``TF_Run``.
       options: Run options. When falsy, ``None`` is passed in its place.
       run_metadata: Forwarded only when ``options`` is truthy; otherwise
         ``None`` is passed instead.

     Returns:
       Whatever ``tf_session.TF_Run`` returns.
     """
     # Flush graph changes made since the last run into the runtime.
     self._extend_graph()
     # The guard yields the status object handed to TF_Run; presumably it
     # turns a non-OK status into a Python exception on exit (not visible).
     with errors.raise_exception_on_not_ok_status() as status:
         run_options, metadata = ((options, run_metadata) if options
                                  else (None, None))
         return tf_session.TF_Run(session, run_options, feed_dict,
                                  fetch_list, target_list, status,
                                  metadata)
예제 #3
0
    def _do_run(self, target_list, fetch_list, feed_dict):
        """Runs a step based on the given fetches and feeds.

        Before running, if the Python graph's version is newer than the
        version last pushed to the runtime, the delta is serialized and sent
        via ``tf_session.TF_ExtendGraph`` while holding ``self._extend_lock``.

        Args:
          target_list: A list of byte arrays corresponding to names of tensors
            or operations to be run to, but not fetched.
          fetch_list: A list of byte arrays corresponding to names of tensors to
            be fetched and operations to be run.
          feed_dict: A dictionary that maps tensor names (as byte arrays) to
            numpy ndarrays.

        Returns:
          A list of numpy ndarrays, corresponding to the elements of
          `fetch_list`.  If the ith element of `fetch_list` contains the
          name of an operation, the first Tensor output of that operation
          will be returned for that element.

        Raises:
          RuntimeError: If ``TF_ExtendGraph`` reports a non-zero status code.
          tf_session.StatusNotOK: Re-raised unchanged (presumably from
            ``TF_Run``) when its message does not match
            ``BaseSession._NODEDEF_NAME_RE``. When it does match, the
            exception returned by ``errors._make_specific_exception`` is
            raised instead.
        """
        try:
            # Ensure any changes to the graph are reflected in the runtime.
            with self._extend_lock:
                if self._graph.version > self._current_version:
                    # Only the part of the graph newer than the last pushed
                    # version is serialized.
                    graph_def = self._graph.as_graph_def(
                        from_version=self._current_version)

                    # Allocate the status *before* the try/finally. If
                    # TF_NewStatus raised inside the try, the finally clause
                    # would reference an unbound ``status`` and mask the real
                    # error with UnboundLocalError.
                    status = tf_session.TF_NewStatus()
                    try:
                        tf_session.TF_ExtendGraph(
                            self._session, graph_def.SerializeToString(),
                            status)
                        if tf_session.TF_GetCode(status) != 0:
                            raise RuntimeError(
                                compat.as_text(tf_session.TF_Message(status)))
                        self._opened = True
                    finally:
                        # Always release the status, success or failure.
                        tf_session.TF_DeleteStatus(status)

                    # Reached only if the extend succeeded (a failure raised
                    # above). NOTE(review): the version is re-read here rather
                    # than captured alongside as_graph_def; confirm the graph
                    # cannot change between the two reads.
                    self._current_version = self._graph.version

            return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                                     target_list)

        except tf_session.StatusNotOK as e:
            # Capture the exception info first so the fallback reraise below
            # keeps the original traceback.
            e_type, e_value, e_traceback = sys.exc_info()
            error_message = compat.as_text(e.error_message)
            m = BaseSession._NODEDEF_NAME_RE.search(error_message)
            if m is not None:
                # The message names a node: look it up so the specific
                # exception can carry its op / NodeDef when available.
                node_name = m.group(1)
                node_def = None
                try:
                    op = self._graph.get_operation_by_name(node_name)
                    node_def = op.node_def
                except KeyError:
                    op = None
                # pylint: disable=protected-access
                raise errors._make_specific_exception(node_def, op,
                                                      error_message, e.code)
                # pylint: enable=protected-access
            six.reraise(e_type, e_value, e_traceback)
예제 #4
0
def fast_tf():
    """Call ``tf_session.TF_Run`` with arguments taken from module scope.

    NOTE(review): this function has no parameters. ``session``, ``options``,
    ``feed_dict``, ``fetch_list``, ``target_list``, ``status`` and
    ``run_metadata`` are resolved as globals, so a call raises NameError
    unless the module defines every one of them. Confirm this is intended.

    Returns:
      Whatever ``tf_session.TF_Run`` returns.
    """
    # Resolve the callee first, then the arguments left to right, matching
    # the evaluation order of a direct call expression.
    run = tf_session.TF_Run
    return run(session, options, feed_dict, fetch_list, target_list,
               status, run_metadata)
예제 #5
0
 def _run_fn(session, feed_dict, fetch_list, target_list):
     """Syncs pending graph edits, then runs one step via ``TF_Run``.

     ``self`` is a free variable captured from the enclosing scope (not
     visible in this snippet).

     Args:
       session: Passed through as the first argument of ``TF_Run``.
       feed_dict: Forwarded unchanged to ``TF_Run``.
       fetch_list: Forwarded unchanged to ``TF_Run``.
       target_list: Forwarded unchanged to ``TF_Run``.

     Returns:
       Whatever ``tf_session.TF_Run`` returns.
     """
     self._extend_graph()  # make the runtime see graph edits since the last run
     result = tf_session.TF_Run(session, feed_dict, fetch_list, target_list)
     return result