def build_results(self, session, tensor_values):
    """Re-assemble fetched values into the structure of the original fetches.

    `tensor_values` must line up one-to-one with the list returned by
    `fetches()`. Each op position in `self._ops` yields None. Every other
    position consumes the next entry of `tensor_values`. That entry is wrapped
    in a `session_ops.TensorHandle` when `self._fetch_handles` records a dtype
    for the corresponding fetch. The resulting flat list is handed to the
    fetch mapper, which rebuilds the shape of the `fetches` argument that was
    given to the constructor.

    Args:
      session: The enclosing session; only used to construct tensor handles.
      tensor_values: List of fetched values, the same length as the list
        returned by `fetches()`.

    Returns:
      A structure shaped like the original `fetches` argument, holding the
      fetched values (or None for fetched ops).
    """
    assert len(self._fetches) == len(tensor_values)
    flat_results = []
    cursor = 0  # index of the next non-op fetch / value to consume
    for is_op in self._ops:
        if is_op:
            # Ops produce no value.
            flat_results.append(None)
            continue
        fetch = self._fetches[cursor]
        value = tensor_values[cursor]
        cursor += 1
        handle_dtype = self._fetch_handles.get(fetch)
        if handle_dtype:
            value = session_ops.TensorHandle(value, handle_dtype, session)
        flat_results.append(value)
    return self._fetch_mapper.build_results(flat_results)
def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending on the existence of `handle`.

    Args:
      handle: Partial-run handle or None; forwarded unchanged to
        `self._do_run`.
      fetches: A single fetch, a (possibly nested) list/tuple of fetches, or a
        dict whose values are fetches. The result is repacked into the same
        structure.
      feed_dict: Optional mapping from feed keys to feed values. It is
        flattened with `nest.flatten_dict_items`, and each key is expanded
        via `BaseSession._REGISTERED_EXPANSIONS`.
      options: Forwarded to `self._do_run`.
      run_metadata: Forwarded to `self._do_run`.

    Returns:
      The fetched values, repacked into the structure of `fetches`. Entries
      whose processed fetch names are empty come back as None.

    Raises:
      RuntimeError: If the session is closed or its graph is empty.
      TypeError: If a feed key has no registered expansion, cannot be
        interpreted as a Tensor, a feed value is a Tensor, or a Python int
        feed changes value when converted to the tensor's numpy dtype.
      ValueError: If a feed value's shape is incompatible with the tensor's
        shape, or the tensor is not feedable.
    """
    def _feed_fn(feed, feed_val):
        # Expand one user feed key into (subfeed, subfeed_val) pairs using the
        # first registered expansion whose type matches the key.
        for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
            if isinstance(feed, tensor_type):
                return feed_fn(feed, feed_val)
        raise TypeError('Feed argument %r has invalid type %r'
                        % (feed, type(feed)))

    # Check session.
    if self._closed:
        raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
        raise RuntimeError(
            'The Session graph is empty. Add operations to the '
            'graph before calling run().')

    # Flatten fetches, and remember how to repack the fetched values.
    if isinstance(fetches, (list, tuple)):
        # fetches is a list or tuple; flatten any nesting now and restore the
        # original structure on the way out.
        orig_fetches, fetches = fetches, nest.flatten(fetches)

        def unflatten(fetched):
            return nest.pack_sequence_as(orig_fetches, fetched)
    elif isinstance(fetches, dict):
        # fetches is a dictionary; flatten the values and map fetched
        # values back into a dictionary.
        # nest.flatten does not accept iterators; materializing the values
        # view keeps this working on Python 3.
        fetches_values = list(fetches.values())
        orig_fetches, fetches = fetches, nest.flatten(fetches_values)

        def unflatten(fetched):
            return _unflatten_fetches(orig_fetches, fetched)
    else:
        # fetches is a singleton.
        fetches = [fetches]

        def unflatten(fetched):
            return fetched[0]

    # Validate and process fetches.
    processed_fetches = self._process_fetches(fetches)
    unique_fetches = processed_fetches[0]
    target_list = processed_fetches[1]
    fetch_info = processed_fetches[2]
    unique_handles = processed_fetches[3]

    # Create request.
    feed_dict_string = {}
    feed_map = {}

    # Validate and process feed_dict.
    if feed_dict:
        feed_dict = nest.flatten_dict_items(feed_dict)
        for feed, feed_val in feed_dict.items():
            for subfeed, subfeed_val in _feed_fn(feed, feed_val):
                try:
                    subfeed_t = self.graph.as_graph_element(
                        subfeed, allow_tensor=True, allow_operation=False)
                except Exception as e:
                    # Guarded access: a bare e.args[0] raises IndexError for
                    # argument-less exceptions and TypeError for non-string
                    # arguments, hiding the real failure.
                    detail = str(e.args[0]) if e.args else ''
                    raise TypeError(
                        'Cannot interpret feed_dict key as Tensor: ' + detail)

                if isinstance(subfeed_val, ops.Tensor):
                    raise TypeError(
                        'The value of a feed cannot be a tf.Tensor object. '
                        'Acceptable feed values include Python scalars, '
                        'strings, lists, or numpy ndarrays.')

                subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
                # Reject Python ints whose value changes when cast to the
                # tensor's numpy dtype instead of silently feeding another
                # number.
                if (isinstance(subfeed_val, int) and
                        subfeed_dtype(subfeed_val) != subfeed_val):
                    raise TypeError(
                        'Type of feed value ' + str(subfeed_val) + ' is not'
                        ' compatible with Tensor type ' + str(subfeed_dtype) + '.'
                        ' Try explicitly setting the type of the feed tensor'
                        ' to a larger type (e.g. int64).')

                np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

                if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
                    raise ValueError(
                        'Cannot feed value of shape %r for Tensor %r, '
                        'which has shape %r'
                        % (np_val.shape, subfeed_t.name,
                           str(subfeed_t.get_shape())))
                if not self.graph.is_feedable(subfeed_t):
                    raise ValueError('Tensor %s may not be fed.' % subfeed_t)
                subfeed_name = compat.as_bytes(subfeed_t.name)
                feed_dict_string[subfeed_name] = np_val
                feed_map[subfeed_name] = (subfeed_t, subfeed_val)

    # Run request and get response.
    # We need to keep the movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    movers = self._update_with_movers(feed_dict_string, feed_map)
    results = self._do_run(handle, target_list, unique_fetches,
                           feed_dict_string, options, run_metadata)

    # User may have fetched the same tensor multiple times, but we
    # only fetch them from the runtime once. Furthermore, they may
    # be wrapped as a tuple of tensors. Here we map the results back
    # to what the client asked for.
    # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
    fetched_results = {}
    for fetch, result in zip(unique_fetches, results):
        dtype = unique_handles.get(fetch)
        if dtype:
            # Fetches listed in unique_handles are wrapped in a TensorHandle
            # of the recorded dtype.
            result = session_ops.TensorHandle(result, dtype, self)
        fetched_results[fetch] = result
    ret = []
    for fetch_names, fetch_contraction_fn in fetch_info:
        if fetch_names:
            fetched_vals = [fetched_results[name] for name in fetch_names]
            ret.append(fetch_contraction_fn(fetched_vals))
        else:
            ret.append(None)

    return unflatten(ret)
def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending on the existence of `handle`.

    Args:
      handle: Partial-run handle or None; forwarded unchanged to
        `self._do_run`.
      fetches: A single fetch or a list/tuple of fetches, passed as given to
        `self._process_fetches`.
      feed_dict: Optional mapping from feed keys to feed values. Each key is
        expanded via `BaseSession._REGISTERED_EXPANSIONS`.
      options: Forwarded to `self._do_run`.
      run_metadata: Forwarded to `self._do_run`.

    Returns:
      If `fetches` is a list or tuple, a list with one entry per processed
      fetch; otherwise the first (single) entry. Entries whose processed fetch
      names are empty are None.

    Raises:
      RuntimeError: If the session is closed or its graph is empty.
      TypeError: If a feed key has no registered expansion, cannot be
        interpreted as a Tensor, a feed value is a Tensor, or a Python int
        feed changes value when converted to the tensor's numpy dtype.
      ValueError: If a feed value's shape is incompatible with the tensor's
        shape, or the tensor is not feedable.
    """
    def _feed_fn(feed, feed_val):
        # Expand one user feed key into (subfeed, subfeed_val) pairs using the
        # first registered expansion whose type matches the key.
        for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
            if isinstance(feed, tensor_type):
                return feed_fn(feed, feed_val)
        raise TypeError('Feed argument %r has invalid type %r'
                        % (feed, type(feed)))

    # Check session.
    if self._closed:
        raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
        raise RuntimeError('The Session graph is empty. Add operations to the '
                           'graph before calling run().')

    # Validate and process fetches.
    processed_fetches = self._process_fetches(fetches)
    unique_fetches = processed_fetches[0]
    target_list = processed_fetches[1]
    fetch_info = processed_fetches[2]
    unique_handles = processed_fetches[3]

    # Create request.
    feed_dict_string = {}
    feed_map = {}

    # Validate and process feed_dict.
    if feed_dict:
        for feed, feed_val in feed_dict.items():
            for subfeed, subfeed_val in _feed_fn(feed, feed_val):
                try:
                    subfeed_t = self.graph.as_graph_element(
                        subfeed, allow_tensor=True, allow_operation=False)
                except Exception as e:
                    # Guarded access: a bare e.args[0] raises IndexError for
                    # argument-less exceptions and TypeError for non-string
                    # arguments, hiding the real failure.
                    detail = str(e.args[0]) if e.args else ''
                    raise TypeError(
                        'Cannot interpret feed_dict key as Tensor: ' + detail)

                if isinstance(subfeed_val, ops.Tensor):
                    raise TypeError(
                        'The value of a feed cannot be a tf.Tensor object. '
                        'Acceptable feed values include Python scalars, '
                        'strings, lists, or numpy ndarrays.')

                subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
                # Reject Python ints whose value changes when cast to the
                # tensor's numpy dtype instead of silently feeding another
                # number.
                if (isinstance(subfeed_val, int) and
                        subfeed_dtype(subfeed_val) != subfeed_val):
                    raise TypeError(
                        'Type of feed value ' + str(subfeed_val) + ' is not'
                        ' compatible with Tensor type ' + str(subfeed_dtype) + '.'
                        ' Try explicitly setting the type of the feed tensor'
                        ' to a larger type (e.g. int64).')

                np_val = np.array(subfeed_val, dtype=subfeed_dtype)

                if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
                    raise ValueError(
                        'Cannot feed value of shape %r for Tensor %r, '
                        'which has shape %r'
                        % (np_val.shape, subfeed_t.name,
                           str(subfeed_t.get_shape())))
                if not self.graph.is_feedable(subfeed_t):
                    raise ValueError('Tensor %s may not be fed.' % subfeed_t)
                subfeed_name = compat.as_bytes(subfeed_t.name)
                feed_dict_string[subfeed_name] = np_val
                feed_map[subfeed_name] = (subfeed_t, subfeed_val)

    # Run request and get response.
    movers = self._update_with_movers(feed_dict_string, feed_map)
    try:
        results = self._do_run(handle, target_list, unique_fetches,
                               feed_dict_string, options, run_metadata)
    finally:
        # The movers are no longer used. Delete them, even if _do_run raised.
        # The loop variable is deliberately not named `handle`, which would
        # shadow this method's parameter.
        for mover in movers:
            self._register_dead_handle(mover)

    # User may have fetched the same tensor multiple times, but we
    # only fetch them from the runtime once. Furthermore, they may
    # be wrapped as a tuple of tensors. Here we map the results back
    # to what the client asked for.
    # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
    fetched_results = {}
    for fetch, result in zip(unique_fetches, results):
        dtype = unique_handles.get(fetch)
        if dtype:
            # Fetches listed in unique_handles are wrapped in a TensorHandle
            # of the recorded dtype.
            result = session_ops.TensorHandle(result, dtype, self)
        fetched_results[fetch] = result
    ret = []
    for fetch_names, fetch_contraction_fn in fetch_info:
        if fetch_names:
            fetched_vals = [fetched_results[name] for name in fetch_names]
            ret.append(fetch_contraction_fn(fetched_vals))
        else:
            ret.append(None)

    if isinstance(fetches, (list, tuple)):
        return ret
    return ret[0]