Example #1
0
def get_traverse_shallow_structure(traverse_fn, structure):
  """Builds a shallow structure of bools describing how to traverse `structure`.

  `traverse_fn` is called on `structure` and then, recursively, on every
  subtree it agrees to descend into. Each call must return either:

  * a scalar `bool` -- `True` to traverse every top-level subtree, `False` to
    traverse none of them; or
  * a depth=1 structure of bools with the same top level as the subtree,
    choosing subtree by subtree.

  Examples are available in the unit tests (nest_test.py).

  Args:
    traverse_fn: Callable taking a substructure and returning either a scalar
      `bool` or a depth=1 structure of bools shaped like that substructure's
      top level.
    structure: The structure to traverse.

  Returns:
    A shallow structure containing python bools, which can be passed to
    `map_up_to` and `flatten_up_to`. For a non-nested `structure` this is the
    scalar bool returned by `traverse_fn`; if `traverse_fn` returns `False`
    for a nested `structure`, this is `False`.

  Raises:
    TypeError: if `traverse_fn` returns a non-bool for a non-nested input,
      a non-bool scalar for a nested input, or a structure whose top-level
      entries are not all `bool`.
  """
  decision = traverse_fn(structure)

  # Leaves accept only a scalar bool, which is returned as-is.
  if not is_nested(structure):
    if isinstance(decision, bool):
      return decision
    raise TypeError('traverse_fn returned structure: %s for non-structure: %s'
                    % (decision, structure))

  # A scalar bool for a nested input applies to all of its top-level subtrees.
  if isinstance(decision, bool):
    if not decision:
      return False
    children = [get_traverse_shallow_structure(traverse_fn, child)
                for child in _yield_value(structure)]
    return _sequence_like(structure, children)

  if not is_nested(decision):
    raise TypeError('traverse_fn returned a non-bool scalar: %s for input: %s'
                    % (decision, structure))

  # Per-subtree decisions. NOTE(review): `assert_shallow_structure` is defined
  # elsewhere; presumably it rejects a `decision` whose shape does not match
  # the top level of `structure` -- confirm.
  assert_shallow_structure(decision, structure)
  children = []
  for flag, child in zip(_yield_value(decision), _yield_value(structure)):
    # Validate each flag right before acting on it, so traverse_fn is called
    # on earlier subtrees before a later bad flag is reported.
    if not isinstance(flag, bool):
      raise TypeError(
          'traverse_fn didn\'t return a depth=1 structure of bools.  saw: %s '
          ' for structure: %s' % (decision, structure))
    children.append(
        get_traverse_shallow_structure(traverse_fn, child) if flag else False)
  return _sequence_like(structure, children)
Example #2
0
def apply_to_structure(branch_fn, leaf_fn, structure):
  """Applies `branch_fn` to the branches and `leaf_fn` to the leaves.

  If `is_nested(structure)` is `False`, the result is simply
  `leaf_fn(structure)`. Otherwise `branch_fn(structure)` is called first, each
  value it yields is processed recursively with the same two functions, and
  the results are packed back into a structure shaped like `branch_fn`'s
  output (via `_sequence_like`).

  Args:
    branch_fn: A function to call on a struct if is_nested(struct) is `True`.
    leaf_fn: A function to call on a struct if is_nested(struct) is `False`.
    structure: A nested structure containing arguments to be applied to.

  Returns:
    A nested structure of function outputs, or `leaf_fn(structure)` directly
    when `structure` itself is not nested.

  Raises:
    TypeError: If `branch_fn` or `leaf_fn` is not callable.
  """
  # Wrap the offending object in a 1-tuple: with a bare `% obj`, a tuple
  # argument would be unpacked by `%` and produce a formatting error instead
  # of this message.
  if not callable(leaf_fn):
    raise TypeError('leaf_fn must be callable, got: %s' % (leaf_fn,))

  if not callable(branch_fn):
    raise TypeError('branch_fn must be callable, got: %s' % (branch_fn,))

  if not is_nested(structure):
    return leaf_fn(structure)

  processed = branch_fn(structure)

  new_structure = [
      apply_to_structure(branch_fn, leaf_fn, value)
      for value in _yield_value(processed)
  ]
  return _sequence_like(processed, new_structure)
Example #3
0
 def mirror_structure(value, reference_tree):
   """Returns `value` shaped like `reference_tree`.

   A non-nested `value` is broadcast: every leaf of `reference_tree` is
   replaced by `value`. A nested `value` must already match `reference_tree`
   (including container types) and is returned as-is.
   """
   if not tree.is_nested(value):
     return tree.map_structure(lambda _: value, reference_tree)
   # check_types=True keeps the resulting tree types independent of which
   # nested argument happened to be picked as `reference_tree`.
   tree.assert_same_structure(value, reference_tree, check_types=True)
   return value
Example #4
0
def broadcast_structures(*args: Any) -> Any:
    """Returns versions of the arguments that give them the same nested structure.

    Any nested items in *args must have the same structure.

    Any non-nested item will be replaced with a nested version that shares that
    structure. The leaves will all be references to the same original
    non-nested item.

    If all *args are nested, or all *args are non-nested, the items are
    returned unchanged (as a tuple). Calling with no arguments returns an
    empty tuple.

    Example:
    ```
    a = ('a', 'b')
    b = 'c'
    tree_a, tree_b = broadcast_structures(a, b)
    tree_a
    > ('a', 'b')
    tree_b
    > ('c', 'c')
    ```

    Args:
      *args: A sequence of nested or non-nested items.

    Returns:
      A tuple of `*args`, except with all items sharing the same nest
      structure.

    Raises:
      Propagates whatever `tree.assert_same_structure` (called with
      `check_types=True`) raises when two nested arguments do not match.
    """
    # The first nested argument (if any) is the template every other argument
    # is checked against or broadcast into. With no arguments, or none nested,
    # `args` itself (possibly the empty tuple) is returned, so callers can
    # always unpack/iterate the result.
    reference_tree = next((arg for arg in args if tree.is_nested(arg)), None)
    if reference_tree is None:
        return args

    def mirror_structure(value, reference_tree):
        if tree.is_nested(value):
            # Use check_types=True so that the types of the trees we construct
            # aren't dependent on our arbitrary choice of which nested arg to
            # use as the reference_tree.
            tree.assert_same_structure(value, reference_tree, check_types=True)
            return value
        else:
            return tree.map_structure(lambda _: value, reference_tree)

    return tuple(mirror_structure(arg, reference_tree) for arg in args)
Example #5
0
def flatten_dict_items(dictionary):
  """Returns a dictionary with flattened keys and values.

  This function flattens the keys and values of a dictionary, which can be
  arbitrarily nested structures, and returns the flattened version of such
  structures:

  >>> example_dictionary = {(4, 5, (6, 8)): ('a', 'b', ('c', 'd'))}
  >>> result = {4: 'a', 5: 'b', 6: 'c', 8: 'd'}
  >>> assert tree.flatten_dict_items(example_dictionary) == result

  The input dictionary must satisfy two properties:

  1. Its keys and values should have the same exact nested structure.
  2. The set of all flattened keys of the dictionary must not contain repeated
     keys.

  Only the first property's element count is checked: a nested key and its
  value are paired up element-by-element after flattening. A non-nested key
  is mapped to its whole value, even if that value is nested.

  Args:
    dictionary: the dictionary to flatten (any `collections.abc.Mapping`).

  Returns:
    A new dict mapping each flattened key to its corresponding flattened
    value.

  Raises:
    TypeError: If the input is not a mapping.
    ValueError: If a nested key and its value flatten to different numbers of
      elements, or if a flattened key is not unique.
  """
  # `collections.Mapping` was removed in Python 3.10; the ABC lives in
  # `collections.abc` (and `dict` is registered as a Mapping there).
  from collections import abc as collections_abc

  if not isinstance(dictionary, collections_abc.Mapping):
    raise TypeError('input must be a dictionary')

  flat_dictionary = {}
  for key, value in dictionary.items():
    if not is_nested(key):
      if key in flat_dictionary:
        raise ValueError(
            'Could not flatten dictionary: key %s is not unique.' % (key,))
      flat_dictionary[key] = value
      continue

    flat_keys = flatten(key)
    flat_values = flatten(value)
    if len(flat_keys) != len(flat_values):
      raise ValueError(
          'Could not flatten dictionary. Key had %d elements, but value had '
          '%d elements. Key: %s, value: %s.'
          % (len(flat_keys), len(flat_values), flat_keys, flat_values))
    for new_key, new_value in zip(flat_keys, flat_values):
      if new_key in flat_dictionary:
        raise ValueError(
            'Could not flatten dictionary: key %s is not unique.'
            % (new_key,))
      flat_dictionary[new_key] = new_value
  return flat_dictionary
Example #6
0
    def add(
        self,
        actions: Dict[str, types.NestedArray],
        next_timestep: dm_env.TimeStep,
        next_extras: "Optional[Dict[str, types.NestedArray]]" = None,
    ) -> None:
        """Record an action and the following timestep.

        Args:
            actions: Actions taken, keyed by string (presumably one entry per
                agent -- confirm against callers).
            next_timestep: The timestep that followed `actions`.
            next_extras: Extras to store with the step. Only used for the
                appended step when `self._use_next_extras` is False; when it
                is True they are kept in `self._next_extras` for the next
                call instead. Defaults to a fresh empty dict on every call.

        Raises:
            ValueError: If `add_first` has not been called yet, i.e.
                `self._next_observations` is still None.
        """
        if self._next_observations is None:
            raise ValueError(
                "adder.add_first must be called before adder.add.")

        # Build a new dict per call. A `{}` default would be one object shared
        # by every call; since it ends up in `Step.extras` and
        # `self._next_extras`, any mutation would leak into later steps.
        if next_extras is None:
            next_extras = {}

        discount = next_timestep.discount
        if next_timestep.last():
            # Terminal timesteps created by dm_env.termination() will have a scalar
            # discount of 0.0. This may not match the array shape / nested structure
            # of the previous timesteps' discounts. The below will match
            # next_timestep.discount's shape/structure to that of
            # self._buffer[-1].discount.
            if self._buffer and not tree.is_nested(next_timestep.discount):
                discount = tree.map_structure(
                    lambda d: np.broadcast_to(next_timestep.discount,
                                              np.shape(d)),
                    self._buffer[-1].discount,
                )

        self._buffer.append(
            Step(
                observations=self._next_observations,
                actions=actions,
                rewards=next_timestep.reward,
                discounts=discount,
                start_of_episode=self._start_of_episode,
                extras=self._next_extras
                if self._use_next_extras else next_extras,
            ))

        # Write the last "dangling" observation.
        if next_timestep.last():
            self._start_of_episode = False
            self._write()
            self._write_last()
            self.reset()
        else:
            # Record the next observation and write.
            # Possibly store next_extras
            if self._use_next_extras:
                self._next_extras = next_extras
            self._next_observations = next_timestep.observation
            self._start_of_episode = False
            self._write()
Example #7
0
File: base.py Project: wzyxwqx/acme
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep.

        Appends a `Step` built from the pending observation, `action` and
        `next_timestep`, then makes `next_timestep.observation` the pending
        observation and calls `self._write()`. If `next_timestep` is the last
        of the episode, `self._write_last()` and `self.reset()` follow.

        Raises:
            ValueError: If `add_first` has not been called yet, i.e.
                `self._next_observation` is still None.
        """
        if self._next_observation is None:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        # dm_env.termination() yields a scalar 0.0 discount, which may not
        # match the shape/nesting of earlier discounts; in that case reshape it
        # to mirror the previous step's discount.
        step_discount = next_timestep.discount
        needs_broadcast = (next_timestep.last() and self._buffer
                           and not tree.is_nested(next_timestep.discount))
        if needs_broadcast:
            step_discount = tree.map_structure(
                lambda leaf: np.broadcast_to(next_timestep.discount,
                                             np.shape(leaf)),
                self._buffer[-1].discount)

        new_step = Step(
            observation=self._next_observation,
            action=action,
            reward=next_timestep.reward,
            discount=step_discount,
            start_of_episode=self._start_of_episode,
            extras=extras,
        )
        self._buffer.append(new_step)

        # The incoming observation becomes the pending one for the next call.
        self._next_observation = next_timestep.observation
        self._start_of_episode = False
        self._write()

        # At episode end, flush the final "dangling" observation and reset.
        if next_timestep.last():
            self._write_last()
            self.reset()
Example #8
0
def crop_and_resize(images,
                    bboxes,
                    target_size,
                    methods=tf.image.ResizeMethod.BILINEAR,
                    extrapolation_value=0):
    """Crops `images` to the normalized `bboxes` and resizes the crops.

    Args:
        images: An image tensor or a nested structure of image tensors. A
            rank-3 tensor is given a temporary leading axis of size 1, which
            is removed again from its output.
        bboxes: Normalized boxes; cast to float32 and shared by every image
            leaf. NOTE(review): box indices are `tf.range(batch)`, so this
            appears to expect one box per image -- confirm the shape.
        target_size: Crop size; a tuple/list is used as-is, anything else is
            expanded to `[target_size, target_size]`.
        methods: A key into `_resize_methods`, a callable taking
            `(images, boxes=..., crop_size=..., extrapolation_value=...)`, or
            a nested structure of these matching `images`.
        extrapolation_value: Forwarded to the crop op / callable.

    Returns:
        The cropped-and-resized images, mapped over `images` with
        `tree.map_structure`.
    """
    float_boxes = tf.cast(bboxes, tf.float32)
    if isinstance(target_size, (tuple, list)):
        crop_size = target_size
    else:
        crop_size = [target_size, target_size]

    def _crop_leaf(leaf_images, resize_method, leaf_boxes):
        """Crops and resizes a single image tensor with a single method."""
        add_batch_axis = leaf_images.shape.rank == 3
        if add_batch_axis:
            leaf_images = leaf_images[None, Ellipsis]
            leaf_boxes = leaf_boxes[None, Ellipsis]
        if not callable(resize_method):
            # NOTE(review): `box_ind` is the TF1 keyword; TF2's
            # tf.image.crop_and_resize calls it `box_indices` -- confirm
            # against the TensorFlow version in use.
            cropped = tf.image.crop_and_resize(
                leaf_images,
                boxes=leaf_boxes,
                method=_resize_methods[resize_method],
                crop_size=crop_size,
                box_ind=tf.range(tf.shape(leaf_images)[0]),
                extrapolation_value=extrapolation_value)
        else:
            cropped = resize_method(leaf_images,
                                    boxes=leaf_boxes,
                                    crop_size=crop_size,
                                    extrapolation_value=extrapolation_value)
        return cropped[0] if add_batch_axis else cropped

    crop_leaf = functools.partial(_crop_leaf, leaf_boxes=float_boxes)
    if tree.is_nested(methods):
        # One resize method per image leaf, paired up by map_structure.
        return tree.map_structure(crop_leaf, images, methods)
    return tree.map_structure(
        functools.partial(crop_leaf, resize_method=methods), images)
Example #9
0
def is_sequence(structure):
  """Alias for `is_nested`.

  Args:
    structure: Any value.

  Returns:
    Exactly what `is_nested(structure)` returns.
  """
  nested = is_nested(structure)
  return nested
Example #10
0
 def testIsSequence(self):
     # Strings and byte-like objects are leaves, not sequences.
     for text_like in ("1234", b"1234", u"1234", bytearray("1234", "ascii")):
         self.assertFalse(tree.is_nested(text_like))
     # Lists (including an empty one), tuples and dicts are nested.
     for container in ([1, 3, [4, 5]], ((7, 8), (5, 6)), [], {"a": 1, "b": 2}):
         self.assertTrue(tree.is_nested(container))
     # Sets are not treated as nested.
     self.assertFalse(tree.is_nested({1, 2}))
     # NumPy arrays are leaves.
     ones = np.ones([2, 3])
     for array in (ones, np.tanh(ones), np.ones((4, 5))):
         self.assertFalse(tree.is_nested(array))