def get_traverse_shallow_structure(traverse_fn, structure):
  """Builds a shallow structure of bools from `traverse_fn` and `structure`.

  `traverse_fn` is called on `structure` and then, recursively, on every
  subtree it allows to be traversed. For each call it must return either:

  * a scalar `True` / `False`, meaning "traverse every top-level subtree" /
    "do not traverse this substructure at all", or
  * a depth=1 structure of bools that is shallow-compatible with its input,
    marking which top-level subtrees may be traversed.

  Examples are available in the unit tests (nest_test.py).

  Args:
    traverse_fn: Function taking a substructure and returning either a scalar
      `bool` or a depth=1 shallow structure of bools as described above.
    structure: The structure to traverse.

  Returns:
    A shallow structure containing python bools, which can be passed to
    `map_up_to` and `flatten_up_to`. A scalar bool is returned when
    `structure` is not nested, and `False` is returned when `traverse_fn`
    rejects the whole of a nested `structure`.

  Raises:
    TypeError: if `traverse_fn` returns a non-bool for a non-nested input, a
      non-bool scalar for a nested input, or a structure whose top-level
      entries are not all of type `bool`.
  """
  decision = traverse_fn(structure)
  decision_is_bool = isinstance(decision, bool)

  # Leaf input: the only acceptable answer is a plain bool.
  if not is_nested(structure):
    if decision_is_bool:
      return decision
    raise TypeError('traverse_fn returned structure: %s for non-structure: %s'
                    % (decision, structure))

  if decision_is_bool:
    if not decision:
      # The whole substructure is off limits; no need to look any deeper.
      return False
    # Everything at this level may be traversed, so recurse into each branch.
    children = [
        get_traverse_shallow_structure(traverse_fn, subtree)
        for subtree in _yield_value(structure)
    ]
    return _sequence_like(structure, children)

  if not is_nested(decision):
    raise TypeError('traverse_fn returned a non-bool scalar: %s for input: %s'
                    % (decision, structure))

  # Per-branch decisions: only recurse into the branches flagged True.
  assert_shallow_structure(decision, structure)
  children = []
  for flag, subtree in zip(_yield_value(decision), _yield_value(structure)):
    if not isinstance(flag, bool):
      raise TypeError(
          'traverse_fn didn\'t return a depth=1 structure of bools. saw: %s '
          ' for structure: %s' % (decision, structure))
    children.append(
        get_traverse_shallow_structure(traverse_fn, subtree) if flag else False)
  return _sequence_like(structure, children)
def apply_to_structure(branch_fn, leaf_fn, structure):
  """Applies `branch_fn` to nested nodes and `leaf_fn` to leaves, recursively.

  A non-nested `structure` is passed straight to `leaf_fn`. A nested one is
  first transformed by `branch_fn`; the function then recurses into every
  value of that transformed result and repacks the outputs in its shape.

  Args:
    branch_fn: A function to call on a struct if is_nested(struct) is `True`.
    leaf_fn: A function to call on a struct if is_nested(struct) is `False`.
    structure: A nested structure containing arguments to be applied to.

  Returns:
    A nested structure of function outputs.

  Raises:
    TypeError: If `branch_fn` or `leaf_fn` is not callable.
  """
  # `leaf_fn` is validated before `branch_fn`, so it is reported first when
  # both are invalid.
  if not callable(leaf_fn):
    raise TypeError('leaf_fn must be callable, got: %s' % leaf_fn)
  if not callable(branch_fn):
    raise TypeError('branch_fn must be callable, got: %s' % branch_fn)

  if is_nested(structure):
    transformed = branch_fn(structure)
    outputs = []
    # Recurse into what `branch_fn` returned, not into the original input.
    for child in _yield_value(transformed):
      outputs.append(apply_to_structure(branch_fn, leaf_fn, child))
    return _sequence_like(transformed, outputs)

  return leaf_fn(structure)
def mirror_structure(value, reference_tree):
  """Returns `value` in the nested shape of `reference_tree`.

  Args:
    value: A nested or non-nested item.
    reference_tree: The nested structure whose layout should be matched.

  Returns:
    `value` itself when it is already nested (after asserting that it has the
    same structure as `reference_tree`), otherwise a copy of `reference_tree`
    in which every leaf is `value`.
  """
  if not tree.is_nested(value):
    # Non-nested item: place the very same object at every leaf position.
    return tree.map_structure(lambda _: value, reference_tree)

  # Use check_types=True so that the types of the trees we construct aren't
  # dependent on an arbitrary choice of which nested arg serves as the
  # reference_tree.
  tree.assert_same_structure(value, reference_tree, check_types=True)
  return value
def broadcast_structures(*args: Any) -> Any:
  """Returns versions of the arguments that share one nested structure.

  Any nested items in *args must have the same structure. Every non-nested
  item is replaced with a nested version that has that structure, whose
  leaves are all references to the same original non-nested item.

  If all *args are nested, or all *args are non-nested, the arguments are
  returned unchanged.

  Example:
  ```
  a = ('a', 'b')
  b = 'c'
  tree_a, tree_b = broadcast_structures(a, b)
  tree_a
  > ('a', 'b')
  tree_b
  > ('c', 'c')
  ```

  Args:
    *args: A Sequence of nested or non-nested items.

  Returns:
    `*args`, except with all items sharing the same nest structure. `None` is
    returned when called with no arguments.
  """
  if not args:
    return None

  # The first nested argument (if any) dictates the structure for everyone.
  reference_tree = next((arg for arg in args if tree.is_nested(arg)), None)

  # Nothing is nested: the rest of the function would be a no-op.
  if reference_tree is None:
    return args

  def _match_reference(item):
    if not tree.is_nested(item):
      return tree.map_structure(lambda _: item, reference_tree)
    # Use check_types=True so that the types of the trees we construct aren't
    # dependent on our arbitrary choice of which nested arg to use as the
    # reference_tree.
    tree.assert_same_structure(item, reference_tree, check_types=True)
    return item

  return tuple(map(_match_reference, args))
def flatten_dict_items(dictionary):
  """Returns a dictionary with flattened keys and values.

  This function flattens the keys and values of a dictionary, which can be
  arbitrarily nested structures, and returns the flattened version of such
  structures:

  >>> example_dictionary = {(4, 5, (6, 8)): ('a', 'b', ('c', 'd'))}
  >>> result = {4: 'a', 5: 'b', 6: 'c', 8: 'd'}
  >>> assert tree.flatten_dict_items(example_dictionary) == result

  The input dictionary must satisfy two properties:

  1. Its keys and values should have the same exact nested structure.
  2. The set of all flattened keys of the dictionary must not contain repeated
     keys.

  Args:
    dictionary: the dictionary (any `Mapping`) to zip.

  Returns:
    The zipped dictionary: a new plain `dict` mapping each flattened key to
    the value found at the matching flattened position.

  Raises:
    TypeError: If the input is not a dictionary.
    ValueError: If a nested key and its value do not flatten to the same
      number of elements, or if keys are not unique.
  """
  # The ABC aliases on the `collections` module itself (`collections.Mapping`)
  # were removed in Python 3.10; the ABC lives in `collections.abc`. Imported
  # locally so this function does not depend on the file's import block.
  from collections.abc import Mapping  # pylint: disable=g-import-not-at-top

  # `dict` is itself a `Mapping`, so one check covers both cases.
  if not isinstance(dictionary, Mapping):
    raise TypeError('input must be a dictionary')

  flat_dictionary = {}
  for key, value in dictionary.items():
    if is_nested(key):
      flat_keys = flatten(key)
      flat_values = flatten(value)
      if len(flat_keys) != len(flat_values):
        raise ValueError(
            'Could not flatten dictionary. Key had %d elements, but value had '
            '%d elements. Key: %s, value: %s.'
            % (len(flat_keys), len(flat_values), flat_keys, flat_values))
      pairs = zip(flat_keys, flat_values)
    else:
      # A non-nested key maps to its value as-is (the value is not flattened).
      pairs = [(key, value)]

    for flat_key, flat_value in pairs:
      # Keys flattened out of different entries may collide with each other.
      if flat_key in flat_dictionary:
        raise ValueError(
            'Could not flatten dictionary: key %s is not unique.'
            % (flat_key,))
      flat_dictionary[flat_key] = flat_value
  return flat_dictionary
def add(
    self,
    actions: Dict[str, types.NestedArray],
    next_timestep: dm_env.TimeStep,
    next_extras: Dict[str, types.NestedArray] = None,
) -> None:
    """Record an action and the following timestep.

    Args:
        actions: the actions to store in the new step.
        next_timestep: the timestep that followed `actions`. Its reward and
            discount are stored in the new step; its observation is kept for
            the step after that unless it is the last timestep.
        next_extras: extras accompanying `next_timestep`. `None` (the
            default) means "no extras" and is replaced by a fresh empty dict
            on every call.

    Raises:
        ValueError: if `add_first` has not been called yet, i.e. there is no
            pending observation to pair the action with.
    """
    # A literal `{}` default would be one dict object shared by every call
    # (and by every step / `self._next_extras` that stores it), so build a
    # new one per call instead.
    if next_extras is None:
        next_extras = {}

    if self._next_observations is None:
        raise ValueError("adder.add_first must be called before adder.add.")

    discount = next_timestep.discount
    if next_timestep.last():
        # Terminal timesteps created by dm_env.termination() will have a scalar
        # discount of 0.0. This may not match the array shape / nested structure
        # of the previous timesteps' discounts. The below will match
        # next_timestep.discount's shape/structure to that of
        # self._buffer[-1].discounts.
        if self._buffer and not tree.is_nested(next_timestep.discount):
            discount = tree.map_structure(
                lambda d: np.broadcast_to(next_timestep.discount, np.shape(d)),
                # Steps are built below with a `discounts` field, so that is
                # the attribute to read back from the buffer.
                self._buffer[-1].discounts,
            )

    self._buffer.append(
        Step(
            observations=self._next_observations,
            actions=actions,
            rewards=next_timestep.reward,
            discounts=discount,
            start_of_episode=self._start_of_episode,
            extras=self._next_extras if self._use_next_extras else next_extras,
        )
    )

    if next_timestep.last():
        # Write the last "dangling" observation.
        self._start_of_episode = False
        self._write()
        self._write_last()
        self.reset()
    else:
        # Record the next observation and write.
        # Possibly store next_extras.
        if self._use_next_extras:
            self._next_extras = next_extras
        self._next_observations = next_timestep.observation
        self._start_of_episode = False
        self._write()
def add(self,
        action: types.NestedArray,
        next_timestep: dm_env.TimeStep,
        extras: types.NestedArray = ()):
  """Record an action and the following timestep.

  Args:
    action: the action to store in the new step.
    next_timestep: the timestep that followed `action`; its reward and
      discount go into the new step and its observation is kept for the
      step after that.
    extras: extra data to store alongside the step.

  Raises:
    ValueError: if `add_first` has not been called yet.
  """
  if self._next_observation is None:
    raise ValueError('adder.add_first must be called before adder.add.')

  def _match_shape(previous):
    # Give the scalar terminal discount the shape of a previous discount.
    return np.broadcast_to(next_timestep.discount, np.shape(previous))

  discount = next_timestep.discount
  # Terminal timesteps created by dm_env.termination() will have a scalar
  # discount of 0.0. This may not match the array shape / nested structure
  # of the previous timesteps' discounts, so in that case it is broadcast to
  # the shape/structure of self._buffer[-1].discount.
  if (next_timestep.last() and self._buffer and
      not tree.is_nested(next_timestep.discount)):
    discount = tree.map_structure(_match_shape, self._buffer[-1].discount)

  # Add the timestep to the buffer.
  step = Step(
      observation=self._next_observation,
      action=action,
      reward=next_timestep.reward,
      discount=discount,
      start_of_episode=self._start_of_episode,
      extras=extras,
  )
  self._buffer.append(step)

  # Record the next observation and write.
  self._next_observation = next_timestep.observation
  self._start_of_episode = False
  self._write()

  if not next_timestep.last():
    return

  # Write the last "dangling" observation.
  self._write_last()
  self.reset()
def crop_and_resize(images,
                    bboxes,
                    target_size,
                    methods=tf.image.ResizeMethod.BILINEAR,
                    extrapolation_value=0):
  """Does crop and resize given normalized boxes.

  Args:
    images: an image tensor or a nested structure of them. Rank-3 tensors get
      a leading dimension added for the crop and removed from the result.
    bboxes: the boxes to crop; cast to `tf.float32` before use.
    target_size: output crop size. A value that is not a tuple or list is
      used for both entries of the crop size.
    methods: a single resize method applied to every image, or a nested
      structure of methods mapped together with `images`. A callable method
      is invoked directly with `boxes`, `crop_size` and `extrapolation_value`
      keyword arguments; any other value is looked up in `_resize_methods`
      and passed to `tf.image.crop_and_resize`.
    extrapolation_value: forwarded unchanged to the resize call.

  Returns:
    The cropped and resized images, in the same structure as `images`.
  """
  bboxes = tf.cast(bboxes, tf.float32)
  if not isinstance(target_size, (tuple, list)):
    target_size = [target_size, target_size]

  def _crop_one(image, method):
    """Crops and resizes a single tensor with a single method."""
    boxes = bboxes
    unbatched = image.shape.rank == 3
    if unbatched:
      # Temporarily add a leading dimension to both image and boxes.
      image = image[None, Ellipsis]
      boxes = boxes[None, Ellipsis]

    if callable(method):
      cropped = method(
          image,
          boxes=boxes,
          crop_size=target_size,
          extrapolation_value=extrapolation_value)
    else:
      cropped = tf.image.crop_and_resize(
          image,
          boxes=boxes,
          method=_resize_methods[method],
          crop_size=target_size,
          # Box i is cropped out of image i.
          box_ind=tf.range(tf.shape(image)[0]),
          extrapolation_value=extrapolation_value)

    return cropped[0] if unbatched else cropped

  if tree.is_nested(methods):
    # One method per image: map over both structures in lockstep.
    return tree.map_structure(_crop_one, images, methods)
  # A single method shared by every image.
  return tree.map_structure(lambda image: _crop_one(image, methods), images)
def is_sequence(structure):
  """Returns the result of `is_nested(structure)`.

  This is a thin forwarding wrapper: it adds no logic of its own and exists
  only as another name for `is_nested`.

  Args:
    structure: The value to check.

  Returns:
    Whatever `is_nested(structure)` returns.
  """
  return is_nested(structure)
def testIsSequence(self):
  """Checks which values `tree.is_nested` reports as nested structures."""
  ones = np.ones([2, 3])
  # (value, whether it is expected to count as nested), checked in order.
  cases = [
      # Text and byte strings are leaves.
      ("1234", False),
      (b"1234", False),
      (u"1234", False),
      (bytearray("1234", "ascii"), False),
      # Lists, tuples and dicts are nested, even when empty.
      ([1, 3, [4, 5]], True),
      (((7, 8), (5, 6)), True),
      ([], True),
      ({"a": 1, "b": 2}, True),
      # Sets and numpy arrays are leaves.
      (set([1, 2]), False),
      (ones, False),
      (np.tanh(ones), False),
      (np.ones((4, 5)), False),
  ]
  for value, expected_nested in cases:
    check = self.assertTrue if expected_nested else self.assertFalse
    check(tree.is_nested(value))