Example #1
  def build_results(self, session, tensor_values):
    """Build results matching the original fetch shape.

    `tensor_values` must be a list of the same length as the list
    returned by `fetches()`, containing the requested fetch values.

    This method builds a structure with the same shape as the original
    `fetches` passed to the constructor, in which each fetch is replaced
    by its fetched value.

    Args:
      session: The enclosing session.  Used for tensor handles.
      tensor_values: List of values matching the list returned
        by fetches().

    Returns:
      A structure of the same shape as the original `fetches` argument but
        containing tensors or None (for fetched ops).
    """
    full_values = []
    assert len(self._fetches) == len(tensor_values)
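    # `i` indexes tensor_values, which only contains entries for the
    # non-op fetches; op fetches get a None placeholder below.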
    i = 0
    for is_op in self._ops:
      if is_op:
        full_values.append(None)
      else:
        dtype = self._fetch_handles.get(self._fetches[i])
        if dtype:
          full_values.append(session_ops.TensorHandle(
              tensor_values[i], dtype, session))
        else:
          full_values.append(tensor_values[i])
        i += 1
    return self._fetch_mapper.build_results(full_values)
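The interleaving logic above is easier to see in isolation. Here is a minimal standalone sketch (the name `interleave_results` is invented for illustration; this is not TensorFlow API) showing how op fetches yield None while tensor fetches consume `tensor_values` in order:

def interleave_results(ops_mask, tensor_values):
  # ops_mask has one entry per fetch; True marks an op fetch,
  # which contributes a None placeholder instead of a value.
  full_values = []
  i = 0
  for is_op in ops_mask:
    if is_op:
      full_values.append(None)
    else:
      full_values.append(tensor_values[i])
      i += 1
  return full_values

# Two tensor fetches around one op fetch: values fill the non-op slots.
assert interleave_results([False, True, False], [1.0, 2.0]) == [1.0, None, 2.0]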
Example #2
    def _run(self, handle, fetches, feed_dict, options, run_metadata):
        """Perform either run or partial_run, depending the exitence of `handle`."""
        def _feed_fn(feed, feed_val):
            for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
                if isinstance(feed, tensor_type):
                    return feed_fn(feed, feed_val)
            raise TypeError('Feed argument %r has invalid type %r' %
                            (feed, type(feed)))

        # Check session.
        if self._closed:
            raise RuntimeError('Attempted to use a closed Session.')
        if self.graph.version == 0:
            raise RuntimeError(
                'The Session graph is empty.  Add operations to the '
                'graph before calling run().')

        # Flatten/unflatten fetched values.
        if isinstance(fetches, (list, tuple)):
            # fetches is a list or tuple; flatten it and remember the
            # original structure so the results can be repacked.
            orig_fetches, fetches = fetches, nest.flatten(fetches)
            unflatten = lambda fetched: nest.pack_sequence_as(
                orig_fetches, fetched)
        elif isinstance(fetches, dict):
            # fetches is a dictionary; flatten the values and map the
            # fetched values back into a dictionary.
            # nest.flatten does not accept iterators; the list() call below
            # is needed for Python 3 compatibility.
            fetches_values = list(fetches.values())
            orig_fetches, fetches = fetches, nest.flatten(fetches_values)
            unflatten = lambda fetched: _unflatten_fetches(
                orig_fetches, fetched)
        else:
            # fetches is a singleton.
            fetches = [fetches]
            unflatten = lambda fetched: fetched[0]

        # Validate and process fetches.
        (unique_fetches, target_list, fetch_info,
         unique_handles) = self._process_fetches(fetches)

        # Create request.
        feed_dict_string = {}
        feed_map = {}
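        # feed_dict_string maps tensor names (as bytes) to numpy values;
        # feed_map keeps the original (tensor, value) pairs so that
        # _update_with_movers below can examine tensor-handle feeds.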

        # Validate and process feed_dict.
        if feed_dict:
            feed_dict = nest.flatten_dict_items(feed_dict)
            for feed, feed_val in feed_dict.items():
                for subfeed, subfeed_val in _feed_fn(feed, feed_val):
                    try:
                        subfeed_t = self.graph.as_graph_element(
                            subfeed, allow_tensor=True, allow_operation=False)
                    except Exception as e:
                        raise TypeError(
                            'Cannot interpret feed_dict key as Tensor: ' +
                            e.args[0])

                    if isinstance(subfeed_val, ops.Tensor):
                        raise TypeError(
                            'The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, or numpy ndarrays.')

                    subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
                    if (isinstance(subfeed_val, int) and
                            subfeed_dtype(subfeed_val) != subfeed_val):
                        raise TypeError(
                            'Type of feed value ' + str(subfeed_val) +
                            ' is not'
                            ' compatible with Tensor type ' +
                            str(subfeed_dtype) + '.'
                            ' Try explicitly setting the type of the feed tensor'
                            ' to a larger type (e.g. int64).')

                    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

                    if not subfeed_t.get_shape().is_compatible_with(
                            np_val.shape):
                        raise ValueError(
                            'Cannot feed value of shape %r for Tensor %r, '
                            'which has shape %r' %
                            (np_val.shape, subfeed_t.name,
                             str(subfeed_t.get_shape())))
                    if not self.graph.is_feedable(subfeed_t):
                        raise ValueError('Tensor %s may not be fed.' %
                                         subfeed_t)
                    subfeed_name = compat.as_bytes(subfeed_t.name)
                    feed_dict_string[subfeed_name] = np_val
                    feed_map[subfeed_name] = (subfeed_t, subfeed_val)

        # Run request and get response.
        # We need to keep the movers alive for the following _do_run().
        # The movers are no longer needed once _do_run() completes, and are
        # deleted when `movers` goes out of scope at the end of this _run().
        # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
        # of a handle from a different device as an error.
        movers = self._update_with_movers(feed_dict_string, feed_map)
        results = self._do_run(handle, target_list, unique_fetches,
                               feed_dict_string, options, run_metadata)

        # User may have fetched the same tensor multiple times, but we
        # only fetch them from the runtime once.  Furthermore, they may
        # be wrapped as a tuple of tensors.  Here we map the results back
        # to what the client asked for.
        # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
        fetched_results = {}
        for fetch, result in zip(unique_fetches, results):
            dtype = unique_handles.get(fetch)
            if dtype:
                result = session_ops.TensorHandle(result, dtype, self)
            fetched_results[fetch] = result
        ret = []
        for fetch_names, fetch_contraction_fn in fetch_info:
            if fetch_names:
                fetched_vals = [fetched_results[name] for name in fetch_names]
                ret.append(fetch_contraction_fn(fetched_vals))
            else:
                ret.append(None)

        return unflatten(ret)
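`_unflatten_fetches` itself is not shown in this example. Assuming it simply pairs the dictionary's keys with the fetched values in the same iteration order used to flatten them (and that the values are not further nested), a toy stand-in would look like this:

def unflatten_fetches_sketch(orig_fetches, fetched):
    # Rebuild a dict with the original keys; relies on the fetched
    # values arriving in the dict's iteration order.
    return dict(zip(orig_fetches.keys(), fetched))

fetches = {'loss': 'loss:0', 'accuracy': 'accuracy:0'}
flat = list(fetches.values())  # what _run hands to _process_fetches
results = [0.25, 0.9]          # stand-ins for values from _do_run
assert unflatten_fetches_sketch(fetches, results) == {'loss': 0.25,
                                                      'accuracy': 0.9}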
Example #3
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending the exitence of `handle`."""
    def _feed_fn(feed, feed_val):
      for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r'
                      % (feed, type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Validate and process fetches.
    (unique_fetches, target_list, fetch_info,
     unique_handles) = self._process_fetches(fetches)

    # Create request.
    feed_dict_string = {}
    feed_map = {}

    # Validate and process feed_dict.
    if feed_dict:
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
                                                    allow_operation=False)
          except Exception as e:
            raise TypeError('Cannot interpret feed_dict key as Tensor: '
                            + e.args[0])

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, or numpy ndarrays.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          if (isinstance(subfeed_val, int) and
              subfeed_dtype(subfeed_val) != subfeed_val):
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' is not'
                ' compatible with Tensor type ' + str(subfeed_dtype) + '.'
                ' Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          np_val = np.array(subfeed_val, dtype=subfeed_dtype)

          if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
            raise ValueError(
                'Cannot feed value of shape %r for Tensor %r, '
                'which has shape %r'
                % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)
          subfeed_name = compat.as_bytes(subfeed_t.name)
          feed_dict_string[subfeed_name] = np_val
          feed_map[subfeed_name] = (subfeed_t, subfeed_val)

    # Run request and get response.
    movers = self._update_with_movers(feed_dict_string, feed_map)
    try:
      results = self._do_run(handle, target_list, unique_fetches,
                             feed_dict_string, options, run_metadata)
    finally:
      # The movers are no longer used. Delete them.
      for mover in movers:
        self._register_dead_handle(mover)

    # User may have fetched the same tensor multiple times, but we
    # only fetch them from the runtime once.  Furthermore, they may
    # be wrapped as a tuple of tensors.  Here we map the results back
    # to what the client asked for.
    # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
    fetched_results = {}
    for fetch, result in zip(unique_fetches, results):
      dtype = unique_handles.get(fetch)
      if dtype:
        result = session_ops.TensorHandle(result, dtype, self)
      fetched_results[fetch] = result
    ret = []
    for fetch_names, fetch_contraction_fn in fetch_info:
      if fetch_names:
        fetched_vals = [fetched_results[name] for name in fetch_names]
        ret.append(fetch_contraction_fn(fetched_vals))
      else:
        ret.append(None)

    if isinstance(fetches, (list, tuple)):
      return ret
    else:
      return ret[0]
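The `get_shape().is_compatible_with(np_val.shape)` check above accepts a fed array whenever every known dimension matches. A rough model of that rule, assuming a fully known rank and using a plain tuple with None wildcards in place of a real TensorShape (which also handles unknown rank):

import numpy as np

def shape_compatible(static_shape, actual_shape):
  # static_shape uses None for unknown dimensions, e.g. (None, 2).
  if len(static_shape) != len(actual_shape):
    return False
  return all(dim is None or dim == actual
             for dim, actual in zip(static_shape, actual_shape))

val = np.asarray([[1.0, 2.0], [3.0, 4.0]])  # what a fed value becomes
assert shape_compatible((None, 2), val.shape)
assert not shape_compatible((None, 3), val.shape)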