def _run_fn(session, feed_dict, fetch_list, target_list, options, run_outputs):
    """Execute one step through ``tf_session.TF_Run``.

    Nested helper: ``self`` and ``tf_session`` are not parameters, so they
    are resolved from an enclosing scope not visible here.

    ``options`` and ``run_outputs`` are forwarded to ``TF_Run`` only when
    ``options`` is truthy; otherwise ``None`` is passed in both positions.

    Returns:
      Whatever ``tf_session.TF_Run`` returns.
    """
    # Push any pending graph modifications to the runtime before running.
    self._extend_graph()
    # Without options, neither the options nor the output slot is forwarded.
    if not options:
        options, run_outputs = None, None
    return tf_session.TF_Run(session, options, feed_dict, fetch_list,
                             target_list, run_outputs)
def _run_fn(session, feed_dict, fetch_list, target_list, options,
            run_metadata):
    """Execute one step through ``tf_session.TF_Run`` with status checking.

    Nested helper: ``self``, ``errors`` and ``tf_session`` are resolved from
    an enclosing scope not visible here.

    The call runs inside ``errors.raise_exception_on_not_ok_status()``, which
    supplies the ``status`` handle passed to ``TF_Run``. ``options`` and
    ``run_metadata`` are forwarded only when ``options`` is truthy; otherwise
    ``None`` is passed for both.

    Returns:
      Whatever ``tf_session.TF_Run`` returns.
    """
    # Push any pending graph modifications to the runtime before running.
    self._extend_graph()
    with errors.raise_exception_on_not_ok_status() as status:
        # Metadata is only collected when run options were supplied.
        if not options:
            options, run_metadata = None, None
        return tf_session.TF_Run(session, options, feed_dict, fetch_list,
                                 target_list, status, run_metadata)
def _do_run(self, target_list, fetch_list, feed_dict):
    """Runs a step based on the given fetches and feeds.

    Before running, any graph operations added since the last extension
    (``self._graph.version > self._current_version``) are sent to the
    runtime with ``TF_ExtendGraph``, under ``self._extend_lock``.

    Args:
      target_list: A list of byte arrays corresponding to names of tensors
        or operations to be run to, but not fetched.
      fetch_list: A list of byte arrays corresponding to names of tensors to
        be fetched and operations to be run.
      feed_dict: A dictionary that maps tensor names (as byte arrays) to
        numpy ndarrays.

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`. If the ith element of `fetch_list` contains the name of
      an operation, the first Tensor output of that operation will be
      returned for that element.

    Raises:
      RuntimeError: If ``TF_ExtendGraph`` reports a non-zero status code.
      Exception from ``errors._make_specific_exception``: If a
        ``tf_session.StatusNotOK`` message names a node (matched by
        ``BaseSession._NODEDEF_NAME_RE``).
      tf_session.StatusNotOK: Re-raised with its original traceback when no
        node name can be extracted from the message.
    """
    try:
        # Ensure any changes to the graph are reflected in the runtime.
        with self._extend_lock:
            if self._graph.version > self._current_version:
                graph_def = self._graph.as_graph_def(
                    from_version=self._current_version)
                # Allocate the status *before* the try: if allocation itself
                # fails, the finally clause must not reference an unbound
                # name (which would mask the real error).
                status = tf_session.TF_NewStatus()
                try:
                    tf_session.TF_ExtendGraph(
                        self._session, graph_def.SerializeToString(), status)
                    if tf_session.TF_GetCode(status) != 0:
                        raise RuntimeError(
                            compat.as_text(tf_session.TF_Message(status)))
                    self._opened = True
                finally:
                    tf_session.TF_DeleteStatus(status)
                # NOTE(review): the version is re-read after as_graph_def();
                # ops added concurrently in between (if graph mutation does
                # not hold _extend_lock) would be skipped -- confirm.
                self._current_version = self._graph.version

        return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                                 target_list)

    except tf_session.StatusNotOK as e:
        # Capture the traceback now so the generic path can re-raise intact.
        e_type, e_value, e_traceback = sys.exc_info()
        error_message = compat.as_text(e.error_message)
        m = BaseSession._NODEDEF_NAME_RE.search(error_message)
        if m is not None:
            node_name = m.group(1)
            node_def = None
            try:
                op = self._graph.get_operation_by_name(node_name)
                node_def = op.node_def
            except KeyError:
                # The named node is not in this Python graph; report without it.
                op = None
            # pylint: disable=protected-access
            raise errors._make_specific_exception(node_def, op, error_message,
                                                  e.code)
            # pylint: enable=protected-access
        six.reraise(e_type, e_value, e_traceback)
def fast_tf():
    """Invoke ``tf_session.TF_Run`` using only captured variables.

    Takes no arguments. ``session``, ``options``, ``feed_dict``,
    ``fetch_list``, ``target_list``, ``status`` and ``run_metadata`` are all
    resolved from an enclosing scope not visible here.

    Returns:
      Whatever ``tf_session.TF_Run`` returns.
    """
    run = tf_session.TF_Run
    return run(session, options, feed_dict, fetch_list, target_list,
               status, run_metadata)
def _run_fn(session, feed_dict, fetch_list, target_list):
    """Execute one step through the four-argument ``tf_session.TF_Run``.

    Nested helper: ``self`` and ``tf_session`` are resolved from an
    enclosing scope not visible here.

    Returns:
      Whatever ``tf_session.TF_Run`` returns.
    """
    # Push any pending graph modifications to the runtime before running.
    self._extend_graph()
    run = tf_session.TF_Run
    return run(session, feed_dict, fetch_list, target_list)