Exemple #1
0
  def testFlattenDictItems(self):
    """Checks flatten_dict_items on one valid mapping and three bad inputs."""
    # A nested tuple key paired with an identically structured tuple value
    # is expected to flatten into one entry per leaf.
    nested_items = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
    expected_flat = {4: "a", 5: "b", 6: "c", 8: "d"}
    self.assertEqual(nest.flatten_dict_items(nested_items), expected_flat)

    # An argument that is not a dictionary must be rejected.
    with self.assertRaises(TypeError):
      nest.flatten_dict_items(4)

    # The leaf key 4 appears twice inside the nested key.
    duplicate_key_items = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))}
    with self.assertRaisesRegexp(ValueError, "not unique"):
      nest.flatten_dict_items(duplicate_key_items)

    # The value holds one more leaf than the key does.
    mismatched_items = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))}
    size_mismatch_regex = (
        "Key had [0-9]* elements, but value had [0-9]* elements")
    with self.assertRaisesRegexp(ValueError, size_mismatch_regex):
      nest.flatten_dict_items(mismatched_items)
Exemple #2
0
    def act(self, obs, policy_state, exploration=None):
        """Run one policy step in the default TensorFlow session.

        Args:
            obs: value fed for `self._obs`.
            policy_state: value fed for `self._policy_state`; the key/value
                pair is flattened with `flatten_dict_items` before feeding.
            exploration: optional value fed for `self.exploration`; nothing
                is fed for it when it is None.

        Returns:
            The pair `((log_probs, actions, entropy, utils),
            next_policy_state)` built from the fetched values.
        """
        self.maybe_build_mode()
        self.maybe_build_act()

        session = tf.get_default_session()

        # Start from the flattened policy-state feeds, then add the rest.
        feeds = flatten_dict_items({self._policy_state: policy_state})
        feeds[self._obs] = obs
        if exploration is not None:
            feeds[self.exploration] = exploration

        fetches = [
            self._log_probs,
            self._samples,
            self._entropy,
            self._utils,
            self._next_policy_state,
        ]
        outputs = session.run(fetches, feed_dict=feeds)
        log_probs, actions, entropy, utils, next_policy_state = outputs

        return (log_probs, actions, entropy, utils), next_policy_state
Exemple #3
0
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending on `handle`.

    Validates the session state, converts `feed_dict` into a mapping from
    tensor names (bytes) to numpy arrays, and delegates execution to
    `_do_run`. Results are shaped by a `_FetchHandler`.

    Args:
      handle: Forwarded unchanged to `_do_run`.
      fetches: Fetch structure handed to `_FetchHandler`.
      feed_dict: Optional mapping of feed keys to feed values. A falsy value
        (None or empty) skips feed processing. Otherwise it is flattened with
        `nest.flatten_dict_items` and every entry is validated.
      options: Forwarded unchanged to `_do_run`.
      run_metadata: Forwarded unchanged to `_do_run`.

    Returns:
      The value of `fetch_handler.build_results(self, results)`, where
      `results` is the output of `_do_run`, or an empty list when there is
      nothing to fetch and no target to run.

    Raises:
      RuntimeError: If the session is closed or the graph version is 0.
      TypeError: If a feed key has no registered expansion, cannot be
        interpreted as a tensor of the graph, if a feed value is an
        `ops.Tensor`, or if a Python int value does not survive conversion
        to the tensor's numpy dtype.
      ValueError: If a feed value's shape is incompatible with the tensor's
        shape, or the tensor may not be fed.
    """
    def _feed_fn(feed, feed_val):
      # Expand the feed with the feed function of the first registered
      # expansion whose type matches the feed key.
      for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r'
                      % (feed, type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Create request.
    feed_dict_string = {}  # tensor name (bytes) -> numpy value
    feed_map = {}  # tensor name (bytes) -> (tensor, original feed value)

    # Validate and process feed_dict.
    if feed_dict:
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
                                                    allow_operation=False)
          except Exception as e:
            # `e.args` may be empty or hold a non-string; %-formatting keeps
            # this handler from failing with IndexError/TypeError and hiding
            # the real problem. The text is unchanged for string messages.
            raise TypeError('Cannot interpret feed_dict key as Tensor: %s'
                            % (e.args[0] if e.args else e,))

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, or numpy ndarrays.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # A Python int must compare equal to itself after conversion to the
          # tensor's numpy dtype; otherwise the dtype cannot represent it.
          if isinstance(subfeed_val,
                        int) and subfeed_dtype(subfeed_val) != subfeed_val:
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' is not'
                ' compatible with Tensor type ' + str(subfeed_dtype) + '.'
                ' Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
            raise ValueError(
                'Cannot feed value of shape %r for Tensor %r, '
                'which has shape %r'
                % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)
          subfeed_name = compat.as_bytes(subfeed_t.name)
          feed_dict_string[subfeed_name] = np_val
          feed_map[subfeed_name] = (subfeed_t, subfeed_val)

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)

    # Run request and get response.
    # We need to keep the movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    movers = self._update_with_movers(feed_dict_string, feed_map)
    final_fetches = fetch_handler.fetches()
    final_targets = fetch_handler.targets()
    # Skip the runtime call entirely when there is nothing to fetch or run.
    if final_fetches or final_targets:
      results = self._do_run(handle, final_targets, final_fetches,
                             feed_dict_string, options, run_metadata)
    else:
      results = []
    return fetch_handler.build_results(self, results)
Exemple #4
0
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending on `handle`.

    Checks that the session is usable, turns `feed_dict` into a mapping from
    tensor names (bytes) to numpy arrays, and hands execution to `_do_run`.
    A `_FetchHandler` decides what is fetched and shapes the results.

    Args:
      handle: Forwarded unchanged to `_do_run`.
      fetches: Fetch structure handed to `_FetchHandler`.
      feed_dict: Optional mapping of feed keys to feed values. A falsy value
        (None or empty) skips feed processing. Otherwise it is flattened with
        `nest.flatten_dict_items` and every entry is validated.
      options: Forwarded unchanged to `_do_run`.
      run_metadata: Forwarded unchanged to `_do_run`.

    Returns:
      The value of `fetch_handler.build_results(self, results)`, where
      `results` is the output of `_do_run`, or an empty list when there is
      nothing to fetch and no target to run.

    Raises:
      RuntimeError: If the session is closed or the graph version is 0.
      TypeError: If a feed key has no registered expansion, cannot be
        interpreted as a tensor of the graph, if a feed value is an
        `ops.Tensor`, or if a Python int value does not survive conversion
        to the tensor's numpy dtype.
      ValueError: If a feed value's shape is incompatible with the tensor's
        shape, or the tensor may not be fed.
    """
    def _feed_fn(feed, feed_val):
      # The first registered expansion whose type matches the feed key
      # supplies the function that expands the feed.
      for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r'
                      % (feed, type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Create request.
    feed_dict_string = {}  # tensor name (bytes) -> numpy value
    feed_map = {}  # tensor name (bytes) -> (tensor, original feed value)

    # Validate and process feed_dict.
    if feed_dict:
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
                                                    allow_operation=False)
          except Exception as e:
            # Format with %s instead of concatenating `e.args[0]`: the args
            # tuple may be empty or contain a non-string, and failing here
            # would mask the original error. String messages are unchanged.
            raise TypeError('Cannot interpret feed_dict key as Tensor: %s'
                            % (e.args[0] if e.args else e,))

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, or numpy ndarrays.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # Reject Python ints whose value changes when converted to the
          # tensor's numpy dtype.
          if isinstance(subfeed_val,
                        int) and subfeed_dtype(subfeed_val) != subfeed_val:
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' is not'
                ' compatible with Tensor type ' + str(subfeed_dtype) + '.'
                ' Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
            raise ValueError(
                'Cannot feed value of shape %r for Tensor %r, '
                'which has shape %r'
                % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)
          subfeed_name = compat.as_bytes(subfeed_t.name)
          feed_dict_string[subfeed_name] = np_val
          feed_map[subfeed_name] = (subfeed_t, subfeed_val)

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)

    # Run request and get response.
    # We need to keep the movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    movers = self._update_with_movers(feed_dict_string, feed_map)
    final_fetches = fetch_handler.fetches()
    final_targets = fetch_handler.targets()
    # Only call into the runtime when something is fetched or targeted.
    if final_fetches or final_targets:
      results = self._do_run(handle, final_targets, final_fetches,
                             feed_dict_string, options, run_metadata)
    else:
      results = []
    return fetch_handler.build_results(self, results)
Exemple #5
0
    def _run(self, handle, fetches, feed_dict, options, run_metadata):
        """Perform either run or partial_run, depending on `handle`.

        Flattens `fetches`, validates `feed_dict`, delegates execution to
        `_do_run`, and maps the results back onto the structure of the
        original `fetches` argument.

        Args:
            handle: Forwarded unchanged to `_do_run`.
            fetches: A list/tuple (possibly nested), a dict, or a single
                fetch. Lists/tuples and dict values are flattened with
                `nest.flatten`; anything else is treated as one fetch.
            feed_dict: Optional mapping of feed keys to feed values. A falsy
                value (None or empty) skips feed processing. Otherwise it is
                flattened with `nest.flatten_dict_items` and every entry is
                validated.
            options: Forwarded unchanged to `_do_run`.
            run_metadata: Forwarded unchanged to `_do_run`.

        Returns:
            The fetched values arranged like `fetches`: repacked with
            `nest.pack_sequence_as` for a list/tuple, rebuilt with
            `_unflatten_fetches` for a dict, or the single value otherwise.
            Fetches with no fetch names yield None.

        Raises:
            RuntimeError: If the session is closed or the graph version is 0.
            TypeError: If a feed key has no registered expansion, cannot be
                interpreted as a tensor of the graph, if a feed value is an
                `ops.Tensor`, or if a Python int value does not survive
                conversion to the tensor's numpy dtype.
            ValueError: If a feed value's shape is incompatible with the
                tensor's shape, or the tensor may not be fed.
        """
        def _feed_fn(feed, feed_val):
            # Expand the feed with the feed function of the first registered
            # expansion whose type matches the feed key.
            for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
                if isinstance(feed, tensor_type):
                    return feed_fn(feed, feed_val)
            raise TypeError('Feed argument %r has invalid type %r' %
                            (feed, type(feed)))

        # Check session.
        if self._closed:
            raise RuntimeError('Attempted to use a closed Session.')
        if self.graph.version == 0:
            raise RuntimeError(
                'The Session graph is empty.  Add operations to the '
                'graph before calling run().')

        # Flatten/unflatten fetched values.
        if isinstance(fetches, (list, tuple)):
            # fetches is a (possibly nested) list or tuple; flatten it and
            # remember the original structure to repack the results.
            orig_fetches, fetches = fetches, nest.flatten(fetches)
            unflatten = lambda fetched: nest.pack_sequence_as(
                orig_fetches, fetched)
        elif isinstance(fetches, dict):
            # fetches is a dictionary; flatten the values and map fetched
            # values back into a dictionary.
            # nest.flatten does not accept iterators, next line is for python3
            # compatibility.
            fetches_values = list(fetches.values())
            orig_fetches, fetches = fetches, nest.flatten(fetches_values)
            unflatten = lambda fetched: _unflatten_fetches(
                orig_fetches, fetched)
        else:
            # fetches is a singleton.
            fetches = [fetches]
            unflatten = lambda fetched: fetched[0]

        # Validate and process fetches.
        processed_fetches = self._process_fetches(fetches)
        unique_fetches = processed_fetches[0]
        target_list = processed_fetches[1]
        fetch_info = processed_fetches[2]
        unique_handles = processed_fetches[3]

        # Create request.
        feed_dict_string = {}  # tensor name (bytes) -> numpy value
        feed_map = {}  # tensor name (bytes) -> (tensor, original feed value)

        # Validate and process feed_dict.
        if feed_dict:
            feed_dict = nest.flatten_dict_items(feed_dict)
            for feed, feed_val in feed_dict.items():
                for subfeed, subfeed_val in _feed_fn(feed, feed_val):
                    try:
                        subfeed_t = self.graph.as_graph_element(
                            subfeed, allow_tensor=True, allow_operation=False)
                    except Exception as e:
                        # `e.args` may be empty or hold a non-string;
                        # %-formatting keeps this handler from failing with
                        # IndexError/TypeError and hiding the real problem.
                        # The text is unchanged for string messages.
                        raise TypeError(
                            'Cannot interpret feed_dict key as Tensor: %s' %
                            (e.args[0] if e.args else e,))

                    if isinstance(subfeed_val, ops.Tensor):
                        raise TypeError(
                            'The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, or numpy ndarrays.')

                    subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
                    # A Python int must compare equal to itself after
                    # conversion to the tensor's numpy dtype.
                    if isinstance(
                            subfeed_val,
                            int) and subfeed_dtype(subfeed_val) != subfeed_val:
                        raise TypeError(
                            'Type of feed value ' + str(subfeed_val) +
                            ' is not'
                            ' compatible with Tensor type ' +
                            str(subfeed_dtype) + '.'
                            ' Try explicitly setting the type of the feed tensor'
                            ' to a larger type (e.g. int64).')

                    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

                    if not subfeed_t.get_shape().is_compatible_with(
                            np_val.shape):
                        raise ValueError(
                            'Cannot feed value of shape %r for Tensor %r, '
                            'which has shape %r' %
                            (np_val.shape, subfeed_t.name,
                             str(subfeed_t.get_shape())))
                    if not self.graph.is_feedable(subfeed_t):
                        raise ValueError('Tensor %s may not be fed.' %
                                         subfeed_t)
                    subfeed_name = compat.as_bytes(subfeed_t.name)
                    feed_dict_string[subfeed_name] = np_val
                    feed_map[subfeed_name] = (subfeed_t, subfeed_val)

        # Run request and get response.
        # We need to keep the movers alive for the following _do_run().
        # These movers are no longer needed when _do_run() completes, and
        # are deleted when `movers` goes out of scope when this _run() ends.
        # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
        # of a handle from a different device as an error.
        movers = self._update_with_movers(feed_dict_string, feed_map)
        results = self._do_run(handle, target_list, unique_fetches,
                               feed_dict_string, options, run_metadata)

        # User may have fetched the same tensor multiple times, but we
        # only fetch them from the runtime once.  Furthermore, they may
        # be wrapped as a tuple of tensors.  Here we map the results back
        # to what the client asked for.
        # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
        fetched_results = {}
        for fetch, result in zip(unique_fetches, results):
            # Fetches listed in unique_handles are wrapped in a TensorHandle
            # carrying the recorded dtype.
            dtype = unique_handles.get(fetch)
            if dtype:
                result = session_ops.TensorHandle(result, dtype, self)
            fetched_results[fetch] = result
        ret = []
        for fetch_names, fetch_contraction_fn in fetch_info:
            if fetch_names:
                fetched_vals = [fetched_results[name] for name in fetch_names]
                ret.append(fetch_contraction_fn(fetched_vals))
            else:
                # Nothing was fetched for this entry; keep its slot as None.
                ret.append(None)

        return unflatten(ret)
Exemple #6
0
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending on `handle`.

    Flattens `fetches`, validates `feed_dict`, hands execution to `_do_run`,
    and arranges the results like the original `fetches` argument.

    Args:
      handle: Forwarded unchanged to `_do_run`.
      fetches: A list/tuple (possibly nested), a dict, or a single fetch.
        Lists/tuples and dict values are flattened with `nest.flatten`;
        anything else is treated as one fetch.
      feed_dict: Optional mapping of feed keys to feed values. A falsy value
        (None or empty) skips feed processing. Otherwise it is flattened with
        `nest.flatten_dict_items` and every entry is validated.
      options: Forwarded unchanged to `_do_run`.
      run_metadata: Forwarded unchanged to `_do_run`.

    Returns:
      The fetched values arranged like `fetches`: repacked with
      `nest.pack_sequence_as` for a list/tuple, rebuilt with
      `_unflatten_fetches` for a dict, or the single value otherwise.
      Fetches with no fetch names yield None.

    Raises:
      RuntimeError: If the session is closed or the graph version is 0.
      TypeError: If a feed key has no registered expansion, cannot be
        interpreted as a tensor of the graph, if a feed value is an
        `ops.Tensor`, or if a Python int value does not survive conversion
        to the tensor's numpy dtype.
      ValueError: If a feed value's shape is incompatible with the tensor's
        shape, or the tensor may not be fed.
    """
    def _feed_fn(feed, feed_val):
      # The first registered expansion whose type matches the feed key
      # supplies the function that expands the feed.
      for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r'
                      % (feed, type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Flatten/unflatten fetched values.
    if isinstance(fetches, (list, tuple)):
      # fetches is a (possibly nested) list or tuple; flatten it and keep
      # the original structure around to repack the results.
      orig_fetches, fetches = fetches, nest.flatten(fetches)
      unflatten = lambda fetched: nest.pack_sequence_as(orig_fetches, fetched)
    elif isinstance(fetches, dict):
      # fetches is a dictionary; flatten the values and map fetched
      # values back into a dictionary.
      # nest.flatten does not accept iterators, next line is for python3
      # compatibility.
      fetches_values = list(fetches.values())
      orig_fetches, fetches = fetches, nest.flatten(fetches_values)
      unflatten = lambda fetched: _unflatten_fetches(orig_fetches, fetched)
    else:
      # fetches is a singleton.
      fetches = [fetches]
      unflatten = lambda fetched: fetched[0]

    # Validate and process fetches.
    processed_fetches = self._process_fetches(fetches)
    unique_fetches = processed_fetches[0]
    target_list = processed_fetches[1]
    fetch_info = processed_fetches[2]
    unique_handles = processed_fetches[3]

    # Create request.
    feed_dict_string = {}  # tensor name (bytes) -> numpy value
    feed_map = {}  # tensor name (bytes) -> (tensor, original feed value)

    # Validate and process feed_dict.
    if feed_dict:
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
                                                    allow_operation=False)
          except Exception as e:
            # Format with %s instead of concatenating `e.args[0]`: the args
            # tuple may be empty or contain a non-string, and failing here
            # would mask the original error. String messages are unchanged.
            raise TypeError('Cannot interpret feed_dict key as Tensor: %s'
                            % (e.args[0] if e.args else e,))

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, or numpy ndarrays.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # Reject Python ints whose value changes when converted to the
          # tensor's numpy dtype.
          if isinstance(subfeed_val,
                        int) and subfeed_dtype(subfeed_val) != subfeed_val:
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' is not'
                ' compatible with Tensor type ' + str(subfeed_dtype) + '.'
                ' Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
            raise ValueError(
                'Cannot feed value of shape %r for Tensor %r, '
                'which has shape %r'
                % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)
          subfeed_name = compat.as_bytes(subfeed_t.name)
          feed_dict_string[subfeed_name] = np_val
          feed_map[subfeed_name] = (subfeed_t, subfeed_val)

    # Run request and get response.
    # We need to keep the movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    movers = self._update_with_movers(feed_dict_string, feed_map)
    results = self._do_run(handle, target_list, unique_fetches,
                           feed_dict_string, options, run_metadata)

    # User may have fetched the same tensor multiple times, but we
    # only fetch them from the runtime once.  Furthermore, they may
    # be wrapped as a tuple of tensors.  Here we map the results back
    # to what the client asked for.
    # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
    fetched_results = {}
    for fetch, result in zip(unique_fetches, results):
      # Fetches listed in unique_handles are wrapped in a TensorHandle
      # carrying the recorded dtype.
      dtype = unique_handles.get(fetch)
      if dtype:
        result = session_ops.TensorHandle(result, dtype, self)
      fetched_results[fetch] = result
    ret = []
    for fetch_names, fetch_contraction_fn in fetch_info:
      if fetch_names:
        fetched_vals = [fetched_results[name] for name in fetch_names]
        ret.append(fetch_contraction_fn(fetched_vals))
      else:
        # Nothing was fetched for this entry; keep its slot as None.
        ret.append(None)

    return unflatten(ret)
Exemple #7
0
 def fn():
   """Call `nest.flatten_dict_items` on `nested` and discard the result.

   `nested` is a free variable resolved from outside this function, which
   takes no arguments and returns None.
   NOTE(review): presumably a zero-argument callable handed to a timing or
   benchmarking helper -- confirm against the enclosing code.
   """
   nest.flatten_dict_items(nested)