def testFlattenAndUnflatten(self):
    """Round-trips tuples, namedtuples and scalars through flatten/unflatten_as."""
    # Nested tuples: the expected leaf order is left-to-right, depth-first.
    nested = ((3, 4), 5, (6, 7, (9, 10), 8))
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(tree.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        tree.unflatten_as(nested, letters),
        (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

    # Namedtuples must come back as namedtuples with working field access.
    point = collections.namedtuple("Point", ["x", "y"])
    nested = (point(x=4, y=2), ((point(x=1, y=0),),))
    leaves = [4, 2, 1, 0]
    self.assertEqual(tree.flatten(nested), leaves)
    rebuilt = tree.unflatten_as(nested, leaves)
    self.assertEqual(rebuilt, nested)
    outer_point = rebuilt[0]
    self.assertEqual(outer_point.x, 4)
    self.assertEqual(outer_point.y, 2)
    inner_point = rebuilt[1][0][0]
    self.assertEqual(inner_point.x, 1)
    self.assertEqual(inner_point.y, 0)

    # Scalars (including arrays and strings) are treated as single leaves.
    self.assertEqual([5], tree.flatten(5))
    self.assertEqual([np.array([5])], tree.flatten(np.array([5])))
    self.assertEqual("a", tree.unflatten_as(5, ["a"]))
    self.assertEqual(
        np.array([5]),
        tree.unflatten_as("scalar", [np.array([5])]))

    # Mismatched inputs are rejected.
    with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
        tree.unflatten_as("scalar", [4, 5])
    with self.assertRaisesRegex(TypeError, "flat_sequence"):
        tree.unflatten_as([4, 5], "bad_sequence")
    with self.assertRaises(ValueError):
        tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"])
def testFlattenAndUnflatten(self, test_type):
    """Checks flatten/unflatten_as against the fixtures built for `test_type`."""
    self._init_testdata(test_type)

    # The fixture on its own.
    flattened = tree.flatten(self.dcls_with_map)
    self.assertEqual(self.dcls_flattened, flattened)
    rebuilt = tree.unflatten_as(
        self.dcls_with_map_inc_ints, self.dcls_flattened)
    self.assertEqual(self.dcls_with_map, rebuilt)

    # The same fixture embedded in a list next to plain integers.
    mixed = [34, self.dcls_with_map, [1, 2]]
    mixed_leaves = [34] + self.dcls_flattened + [1, 2]
    self.assertEqual(mixed_leaves, tree.flatten(mixed))
    self.assertEqual(mixed, tree.unflatten_as(mixed, mixed_leaves))
async def load(self, version: int, structure: Any) -> Any:
    """Returns the program state for the given `version`.

    Args:
      version: An integer representing the version of a saved program state.
      structure: The nested structure of the saved program state for the given
        `version`, used to support serialization and deserialization of
        user-defined classes in the structure.

    Raises:
      ProgramStateManagerStateNotFoundError: If there is no program state for
        the given `version`.
      ProgramStateManagerStructureError: If `structure` does not match the
        value loaded for the given `version`.
    """
    py_typecheck.check_type(version, int)
    path = self._get_path_for_version(version)

    state_exists = await file_utils.exists(path)
    if not state_exists:
        raise program_state_manager.ProgramStateManagerStateNotFoundError(
            f'No program state found for version: {version}')

    flattened_state = await file_utils.read_saved_model(path)
    try:
        program_state = tree.unflatten_as(structure, flattened_state)
    except ValueError as e:
        # Re-raise as the manager's own error type, keeping the original
        # exception as the cause.
        message = (f'The structure of type {type(structure)}:\n'
                   f'{structure}\n'
                   f'does not match the value of type {type(flattened_state)}:\n'
                   f'{flattened_state}\n')
        raise program_state_manager.ProgramStateManagerStructureError(
            message) from e

    logging.info('Program state loaded: %s', path)
    return program_state
def pack_sequence_as(structure, flat_sequence, **kwargs):
    """Packs `flat_sequence` into the nesting described by `structure`.

    Args:
      structure: Nested structure whose shape the result should take.
      flat_sequence: Flat sequence of leaves to pack.
      **kwargs: Only `expand_composites` is read (default `False`); any other
        keyword is ignored.

    Returns:
      `flat_sequence` repacked like `structure`. JAX tree utilities are used
      when `expand_composites` is truthy and `JAX_MODE` is set; otherwise
      `dm_tree.unflatten_as` is used.
    """
    wants_composites = kwargs.pop('expand_composites', False)
    if not (wants_composites and JAX_MODE):
        return dm_tree.unflatten_as(structure, flat_sequence)

    from jax import tree_util  # pylint: disable=g-import-not-at-top
    treedef = tree_util.tree_structure(structure)
    return tree_util.tree_unflatten(treedef, flat_sequence)
def wrapped_method(*args, **kwargs):
    """A wrapped method around a TF-Hub module signature.

    Binds the call arguments against the method's stored specs, feeds the
    flattened inputs to the module keyed by their stringified position, and
    repacks the module outputs according to the stored output spec.
    """
    spec = self._method_specs[method]
    inputs = _getcallargs(spec["specs"], *args, **kwargs)
    nest.assert_same_structure(spec["inputs"], inputs)

    # The module takes a flat dict keyed "0", "1", ... in flattening order.
    flat_inputs = {
        str(position): value
        for position, value in enumerate(nest.flatten(inputs))
    }

    signature = "default" if method == "__call__" else method
    outputs_by_key = self._module(
        flat_inputs, signature=signature, as_dict=True)
    # NOTE(review): keys are ordered by plain sort; if they are stringified
    # integers, "10" sorts before "2" -- confirm against the exporter.
    flat_outputs = [outputs_by_key[key] for key in sorted(outputs_by_key)]

    output_spec = spec["outputs"]
    if output_spec is not None:
        return nest.unflatten_as(output_spec, flat_outputs)

    # Without an output spec the module must yield exactly one tensor.
    if len(flat_outputs) != 1:
        raise ValueError(
            "Expected output containing a single tensor, found {}".format(
                flat_outputs))
    return flat_outputs[0]
def testUnflattenDictOrder(self):
    """Leaves are assigned by sorted key for both OrderedDict and plain dict."""
    key_order = ["d", "b", "a", "c"]
    ordered = collections.OrderedDict((key, 0) for key in key_order)
    plain = {key: 0 for key in key_order}
    leaves = [0, 1, 2, 3]

    rebuilt_ordered = tree.unflatten_as(ordered, leaves)
    rebuilt_plain = tree.unflatten_as(plain, leaves)

    # Values follow sorted key order: a=0, b=1, c=2, d=3.
    expected_ordered = collections.OrderedDict(
        [("d", 3), ("b", 1), ("a", 0), ("c", 2)])
    expected_plain = {"d": 3, "b": 1, "a": 0, "c": 2}
    self.assertEqual(expected_ordered, rebuilt_ordered)
    self.assertEqual(expected_plain, rebuilt_plain)
def _get_torch_exploration_action(self,
                                  action_distribution: ActionDistribution,
                                  explore: bool,
                                  timestep: Union[int, TensorType]):
    """Torch method to produce an epsilon exploration action.

    Args:
        action_distribution (ActionDistribution): The instantiated
            ActionDistribution object to work with when creating
            exploration actions.
        explore (bool): If False, the distribution's deterministic sample
            is returned unchanged.
        timestep (Union[int, TensorType]): Stored as `self.last_timestep`
            and passed to the epsilon schedule.

    Returns:
        Tuple of (action, action_logp), where `action_logp` is a float
        tensor of zeros with one entry per batch row.
    """
    q_values = action_distribution.inputs
    self.last_timestep = timestep
    exploit_action = action_distribution.deterministic_sample()
    batch_size = q_values.size()[0]
    action_logp = torch.zeros(batch_size, dtype=torch.float)

    # No exploration: return the deterministic "sample" (argmax) over the
    # logits.
    if not explore:
        return exploit_action, action_logp

    # Get the current epsilon.
    epsilon = self.epsilon_schedule(self.last_timestep)

    if isinstance(action_distribution, TorchMultiActionDistribution):
        flat_actions = tree.flatten(exploit_action)
        for row in range(batch_size):
            if random.random() < epsilon:
                # TODO: (bcahlit) Mask out actions
                sampled = tree.flatten(self.action_space.sample())
                # Overwrite this row in every flattened action component.
                for component in range(len(flat_actions)):
                    flat_actions[component][row] = torch.tensor(
                        sampled[component])
        structured_action = tree.unflatten_as(
            action_distribution.action_space_struct, flat_actions)
        return structured_action, action_logp

    # Mask out actions, whose Q-values are -inf, so that we don't
    # even consider them for exploration.
    random_valid_action_logits = torch.where(
        q_values <= FLOAT_MIN,
        torch.ones_like(q_values) * 0.0,
        torch.ones_like(q_values))
    # A random action.
    random_actions = torch.squeeze(
        torch.multinomial(random_valid_action_logits, 1), axis=1)
    # Pick either random or greedy.
    pick_random = torch.empty(
        (batch_size, )).uniform_().to(self.device) < epsilon
    action = torch.where(pick_random, random_actions, exploit_action)
    return action, action_logp
def sample(self,
           table: str,
           data_dtypes,
           name: Optional[str] = None) -> replay_sample.ReplaySample:
    """Samples an item from the replay.

    This only allows sampling items with a data field.

    Args:
      table: Probability table to sample from.
      data_dtypes: Dtypes of the data output. Can be nested.
      name: Optional name for the Client operations.

    Returns:
      A ReplaySample with data nested according to data_dtypes. See
      ReplaySample for more details.
    """
    with tf.name_scope(name, f'{self._name}_sample', ['sample']) as scope:
        # The op takes and returns flat lists; nesting is restored below.
        flat_dtypes = tree.flatten(data_dtypes)
        op_outputs = gen_client_ops.reverb_client_sample(
            self._handle, table, flat_dtypes, name=scope)
        key, probability, table_size, priority, data = op_outputs
        info = replay_sample.SampleInfo(
            key=key,
            probability=probability,
            table_size=table_size,
            priority=priority)
        nested_data = tree.unflatten_as(data_dtypes, data)
        return replay_sample.ReplaySample(info, nested_data)
def create_reference_step(
        step_structure: tree.Structure[Any]) -> ReferenceStep:
    """Create a reference structure that can be used to build patterns.

    ```python

    step_structure = {
        'a': None,
        'b': {
            'c': None,
            'd': None,
        }
    }
    ref_step = create_reference_step(step_structure)
    pattern = {
        'last_two_a': ref_step['a'][-2:],
        'second_to_last_c': ref_step['b']['c'][-2],
        'most_recent_d': ref_step['b']['d'][-1],
    }

    ```

    Args:
      step_structure: Structure of the data which will be passed to
        `StructuredWriter.append`.

    Returns:
      An object with the same structure as `step_structure` except leaf nodes
      have been replaced with a helper object that builds
      `patterns_pb2.PatternNode` objects when __getitem__ is called.
    """
    # One reference node per leaf, numbered by its flattened position.
    num_leaves = len(tree.flatten(step_structure))
    ref_nodes = [_RefNode(flat_index) for flat_index in range(num_leaves)]
    return tree.unflatten_as(step_structure, ref_nodes)
def fast_map_structure(func, *structure):
    """Faster map_structure implementation which skips some error checking.

    Args:
      func: Callable applied to the corresponding leaves of every structure.
      *structure: One or more nested structures; they are flattened and zipped
        together without checking that their nesting matches.

    Returns:
      The results of `func`, packed like the last structure given.
    """
    flattened = map(tree.flatten, structure)
    # The last structure is (arbitrarily) used as the template for the output.
    template = structure[-1]
    mapped_leaves = [func(*leaves) for leaves in zip(*flattened)]
    return tree.unflatten_as(template, mapped_leaves)
def data(step_structure: tree.Structure[Any]):
    """Value of a scalar integer or bool in the source data.

    Args:
      step_structure: Nested structure whose leaves are replaced.

    Returns:
      A structure shaped like `step_structure` where each leaf is a
      `_ConditionBuilder` wrapping a `patterns_pb2.Condition` whose
      `flat_source_index` is the leaf's position in the flattened structure.
    """
    num_leaves = len(tree.flatten(step_structure))
    builders = []
    for source_index in range(num_leaves):
        condition = patterns_pb2.Condition(flat_source_index=source_index)
        builders.append(_ConditionBuilder(condition))
    return tree.unflatten_as(step_structure, builders)
def deterministic_sample(self):
    """Returns each child distribution's deterministic sample.

    The flat child distributions are first arranged like
    `self.action_space_struct`; the result has that same nesting.
    """

    def _sample(distribution):
        return distribution.deterministic_sample()

    structured_children = tree.unflatten_as(
        self.action_space_struct, self.flat_child_distributions)
    return tree.map_structure(_sample, structured_children)
def unbatch(batches_struct):
    """Converts input from (nested) struct of batches to batch of structs.

    Input: Struct of different batches (each batch has size=3):
        {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Args:
        batches_struct (any): The struct of component batches. Each leaf item
            in this struct represents the batch for a single component
            (in case struct is tuple/dict).
            Alternatively, `batches_struct` may also simply be a batch of
            primitives (non tuple/dict).

    Returns:
        List[struct[components]]: The list of rows. Each item
            in the returned list represents a single (maybe complex) struct.
    """
    flat_batches = tree.flatten(batches_struct)
    # The first component's length is taken as the batch size.
    batch_size = len(flat_batches[0])
    return [
        tree.unflatten_as(
            batches_struct,
            [flat_batches[component][row]
             for component in range(len(flat_batches))])
        for row in range(batch_size)
    ]
def _update_structure(self, new_structure: Any):
    """Replace the existing structure with a superset of the current one.

    Since the structure is allowed to evolve over time we are unable to simply
    map flattened data to column indices. For example, if the first step is
    `{'a': 1, 'c': 101}` and the second step is `{'a': 2, 'b': 12, 'c': 102}`
    then the flattened data would be `[1, 101]` and `[2, 12, 102]`. This would
    result in invalid behaviour as the second column (index 1) would receive
    `c` in the first step and `b` in the second.

    To mitigate this, `_structure` represents an explicit mapping of fields to
    the column number containing data of a given field. The mapping is allowed
    to grow over time and would in the above example be `{'a': 0, 'c': 1}` and
    `{'a': 0, 'b': 2, 'c': 1}` after the first and second step respectively.
    Data would thus be flattened as `[1, 101]` and `[2, 102, 12]`, which means
    that the columns in the C++ layer only receive data from a single field in
    the structure even if it evolves over time.

    Args:
      new_structure: The new structure to use. Must be a superset of the
        previous structure.
    """
    # Create columns for new paths and record column index numbers in the
    # `_structure` itself.
    column_entries = []
    for path, column_idx in tree.flatten_with_path(new_structure):
        column_entries.append(self._maybe_create_column(path, column_idx))
    self._structure = tree.unflatten_as(new_structure, column_entries)
def sample_trajectory(client: client_lib.Client, table: str,
                      structure: Any) -> replay_sample.ReplaySample:
    """Temporary helper method for sampling a trajectory.

    Note! This function is only intended to make it easier for alpha testers
    to experiment with the new API. It will be removed before this file is
    made public.

    Args:
      client: Client connected to the server to sample from.
      table: Name of the table to sample from.
      structure: Structure to unpack flat data as.

    Returns:
      ReplaySample with trajectory unpacked as `structure` in `data`-field.
    """
    sampler = client._client.NewSampler(table, 1, 1, 1)  # pylint: disable=protected-access
    sample = sampler.GetNextSample()

    # The first four entries hold the sample info (first element of each);
    # everything after them is the flat trajectory data.
    info = replay_sample.SampleInfo(
        key=int(sample[0][0]),
        probability=float(sample[1][0]),
        table_size=int(sample[2][0]),
        priority=float(sample[3][0]))
    trajectory = tree.unflatten_as(structure, sample[4:])
    return replay_sample.ReplaySample(info=info, data=trajectory)
def stack_observations(obs_list):
    """Stacks a list of observations leaf-wise into one observation dict.

    Each observation is flattened, matching leaves are stacked with
    `np.stack`, and the result is repacked like `observation_spec` with the
    "aux_tasks_reward" entry removed.
    """
    flat_observations = [tree.flatten(obs) for obs in obs_list]
    stacked_leaves = [np.stack(leaves) for leaves in zip(*flat_observations)]
    obs_dict = tree.unflatten_as(observation_spec, stacked_leaves)
    obs_dict.pop("aux_tasks_reward")
    return obs_dict
def testFlatten_bytearrayIsNotFlattened(self):
    """A bytearray is a single leaf for both flatten and unflatten_as."""
    payload = bytearray("bytes in an array", "ascii")
    leaves = tree.flatten(payload)
    self.assertLen(leaves, 1)
    self.assertEqual(leaves, [payload])

    # Unflattening against another bytearray hands back the lone leaf.
    template = bytearray("hello", "ascii")
    self.assertEqual(payload, tree.unflatten_as(template, leaves))
def unflatten_like(a, pytree):
    """Take 1-D array produced by flatten() and unflatten like pytree.

    Args:
      a: 1-D array holding the concatenated, raveled leaves of a structure
        shaped like `pytree`.
      pytree: Template structure. Only the shape of each leaf is used; leaves
        may be arrays or plain Python scalars.

    Returns:
      A structure nested like `pytree` whose leaves are consecutive slices of
      `a`, each reshaped to the shape of the corresponding template leaf.
    """
    leaves = tree.flatten(pytree)
    # np.size / np.shape also accept Python scalars, which have no `.shape`.
    sizes = [np.size(leaf) for leaf in leaves]
    ends = np.cumsum(sizes, dtype=int)
    starts = ends - sizes
    reshaped = [
        np.reshape(a[start:end], np.shape(leaf))
        for start, end, leaf in zip(starts, ends, leaves)
    ]
    return tree.unflatten_as(pytree, reshaped)
def test_iterate_nested_and_batched(self):
    """Samples nested, batched data and checks the batched leaf values."""
    with self._client.writer(100) as writer:
        for step in range(1000):
            timestep = {
                'observation': {
                    'data': np.zeros((3, 3), dtype=np.float32),
                    'extras': [
                        np.int64(10),
                        np.ones([1], dtype=np.int32),
                    ],
                },
                'reward': np.zeros((10, 10), dtype=np.float32),
            }
            writer.append(timestep)
            # From step 100 onwards, create an item on every fifth step.
            if step % 5 == 0 and step >= 100:
                writer.create_item(
                    table='dist', num_timesteps=100, priority=1)

    dtypes = ((tf.float32, (tf.int64, tf.int32)), tf.float32)
    shapes = ((tf.TensorShape([3, 3]),
               (tf.TensorShape(None), tf.TensorShape([1]))),
              tf.TensorShape([10, 10]))
    dataset = reverb_dataset.ReplayDataset(
        self._client.server_address,
        table='dist',
        dtypes=dtypes,
        shapes=shapes,
        max_in_flight_samples_per_worker=100)
    dataset = dataset.batch(3)

    # Used below only as a nesting template for `tree.unflatten_as`.
    structure = {
        'observation': {
            'data': tf.TensorSpec([3, 3], tf.float32),
            'extras': [
                tf.TensorSpec([], tf.int64),
                tf.TensorSpec([1], tf.int32),
            ],
        },
        'reward': tf.TensorSpec([], tf.int64),
    }

    got = self._sample_from(dataset, 10)
    self.assertLen(got, 10)

    # Every leaf gains a leading batch dimension of 3.
    expected_data = np.zeros([3, 3, 3], dtype=np.float32)
    expected_first_extra = np.ones([3], dtype=np.int64) * 10
    expected_second_extra = np.ones([3, 1], dtype=np.int32)
    expected_reward = np.zeros([3, 10, 10], dtype=np.float32)

    for sample in got:
        self.assertIsInstance(sample, replay_sample.ReplaySample)

        transition = tree.unflatten_as(structure, tree.flatten(sample.data))
        np.testing.assert_array_equal(
            transition['observation']['data'], expected_data)
        np.testing.assert_array_equal(
            transition['observation']['extras'][0], expected_first_extra)
        np.testing.assert_array_equal(
            transition['observation']['extras'][1], expected_second_extra)
        np.testing.assert_array_equal(
            transition['reward'], expected_reward)
def testFlattenAndUnflatten_withDicts(self):
    """Round-trips a mix of tuples, lists, dicts, and `OrderedDict`s."""
    named_tuple = collections.namedtuple("A", ("b", "c"))

    def build(first, pair, c_head, ordered_b, ordered_a, dict_b, last):
        # Same nesting every time; only the leaf values differ.
        return [
            first,
            named_tuple(*pair),
            {
                "c": [
                    c_head,
                    collections.OrderedDict([
                        ("b", ordered_b),
                        ("a", ordered_a),
                    ]),
                ],
                "b": dict_b,
            },
            last,
        ]

    mess = build("z", (3, 4), 1, 3, 2, 5, 17)
    flattened = tree.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

    structure_of_mess = build(14, ("a", True), 0, 9, 8, 3, "hi everybody")
    self.assertEqual(mess, tree.unflatten_as(structure_of_mess, flattened))

    # Check also that the OrderedDict was created, with the correct key order.
    rebuilt = tree.unflatten_as(structure_of_mess, flattened)
    unflattened_ordered_dict = rebuilt[2]["c"][1]
    self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
    self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
def testMappingProxyType(self):
    """`types.MappingProxyType` is handled by flatten, unflatten_as and map."""
    if six.PY2:
        self.skipTest("Python 2 does not support mapping proxy type.")

    source = types.MappingProxyType({"a": 1, "b": (2, 3)})
    shifted = types.MappingProxyType({"a": 4, "b": (5, 6)})

    self.assertEqual(tree.flatten(source), [1, 2, 3])
    rebuilt = tree.unflatten_as(source, [4, 5, 6])
    self.assertEqual(rebuilt, shifted)
    mapped = tree.map_structure(lambda v: v + 3, source)
    self.assertEqual(mapped, shifted)
def structure_like(tree1, tree2):  # pylint: disable=g-doc-args, g-doc-return-or-yield
    """Makes tree1 have same structure as tree2.

    Both trees must contain exactly the same set of leaf paths. The leaves of
    `tree1` are reordered to follow the flattening order of `tree2` and packed
    into the nesting of `tree2`.

    Args:
      tree1: Structure supplying the leaf values.
      tree2: Structure supplying the nesting and the leaf order.

    Returns:
      A structure nested like `tree2` whose leaf at every path is the leaf of
      `tree1` at that same path.

    Raises:
      AssertionError: If the leaf paths of the two trees differ.
    """
    # Index tree1's leaves by path once, instead of a linear search per leaf.
    leaf_by_path = dict(tree.flatten_with_path(tree1))
    target_paths = [path for path, _ in tree.flatten_with_path(tree2)]
    # Paths are unique within a tree, so comparing them as sets is equivalent
    # to comparing sorted lists, and does not require paths to be orderable.
    assert leaf_by_path.keys() == set(target_paths), (
        'paths of tree1 and tree2 do not match')
    reordered_flat_tree1 = [leaf_by_path[path] for path in target_paths]
    return tree.unflatten_as(tree2, reordered_flat_tree1)
def _tree_filter(source, filter_):
    """Extract `filter_` from `source`.

    Returns a structure nested like `filter_` in which each leaf is the leaf
    of `source` found at the same path, or `None` when `source` has no leaf at
    that path.
    """
    slot_of_path = {}
    for slot, (path, _) in enumerate(tree.flatten_with_path(filter_)):
        slot_of_path[path] = slot

    picked = [None] * len(slot_of_path)
    for path, leaf in tree.flatten_with_path(source):
        slot = slot_of_path.get(path)
        if slot is not None:
            picked[slot] = leaf
    return tree.unflatten_as(filter_, picked)
def fast_map_structure_with_path(func, *structure):
    """Faster map_structure_with_path implementation.

    Args:
      func: Callable invoked as `func(path, leaf_0, leaf_1, ...)`, where
        `path` is the leaf's path in the first structure.
      *structure: One or more nested structures, flattened and zipped
        together without checking that their nesting matches.

    Returns:
      The results of `func`, packed like the last structure given.
    """
    call_args = tree.flatten_with_path(structure[0])
    remaining = structure[1:]
    if remaining:
        flat_remaining = (tree.flatten(s) for s in remaining)
        # Extend each (path, leaf) pair with the matching leaves of the other
        # structures.
        call_args = [
            (path, leaf, *other_leaves)
            for (path, leaf), *other_leaves in zip(call_args, *flat_remaining)
        ]

    # Arbitrarily choose one of the structures of the original sequence (the
    # last) to match the structure for the flattened sequence.
    return tree.unflatten_as(
        structure[-1], [func(*args) for args in call_args])
async def materialize_value(value: Any) -> Any:
    """Returns a structure of materialized values.

    Args:
      value: A materialized value, a value reference, or a structure of
        materialized values and value references to materialize.
    """

    async def _resolve(leaf: Any) -> Any:
        # Only references need awaiting; everything else passes through.
        if not isinstance(leaf, MaterializableValueReference):
            return leaf
        return await leaf.get_value()

    pending = [_resolve(leaf) for leaf in tree.flatten(value)]
    resolved = await asyncio.gather(*pending)
    return tree.unflatten_as(value, resolved)
def _tree_merge_into(source, target):
    """Update `target` with content of substructure `source`.

    Returns a structure nested like `target` in which every leaf also present
    in `source` (matched by path) takes the value from `source`.

    Raises:
      ValueError: If `source` contains a path that `target` does not.
    """
    slot_of_path = {}
    for slot, (path, _) in enumerate(tree.flatten_with_path(target)):
        slot_of_path[path] = slot

    merged = tree.flatten(target)
    for path, leaf in tree.flatten_with_path(source):
        slot = slot_of_path.get(path)
        if slot is None:
            raise ValueError(
                f'Cannot expand {source} into {target} as it is not a sub structure.'
            )
        merged[slot] = leaf
    return tree.unflatten_as(target, merged)
def history(self):
    """References to data, grouped by column and structured like appended data.

    Allows recently added data references to be accessed with list indexing
    semantics. However, instead of returning the raw references, the result is
    wrapped in a TrajectoryColumn object before being returned to the caller.

    ```python

    writer = TrajectoryWriter(...)

    # Add three steps worth of data.
    first = writer.append({'a': 1, 'b': 100})
    second = writer.append({'a': 2, 'b': 200})
    third = writer.append({'a': 3, 'b': 300})

    # Create a trajectory using the _ColumnHistory helpers.
    from_history = {
        'all_a': writer.history['a'][:],
        'first_b': writer.history['b'][0],
        'last_b': writer.history['b'][-1],
    }
    writer.create_item(table='name', priority=1.0, trajectory=from_history)

    # Is the same as writing.
    explicit = {
        'all_a': TrajectoryColumn([first['a'], second['a'], third['a']]),
        'first_b': TrajectoryColumn([first['b']]),
        'last_b': TrajectoryColumn([third['b']]),
    }
    writer.create_item(table='name', priority=1.0, trajectory=explicit)

    ```

    Raises:
      RuntimeError: If `append` hasn't been called at least once before.
    """
    if self._structure is None:
        raise RuntimeError(
            'history cannot be accessed before `append` is called at least once.'
        )

    columns = self._column_history
    index_map = self._column_index_to_flat_structure_index
    # Reorder the column histories through the index map before packing them
    # like `_structure`.
    reordered_flat_history = [
        columns[index_map[position]] for position in range(len(columns))
    ]
    return tree.unflatten_as(self._structure, reordered_flat_history)
def get_table_content(self, idx: int, structure=None):
    """Returns the data of every item currently in table `TABLES[idx]`.

    Args:
      idx: Index into `TABLES` selecting the table to read.
      structure: Optional nesting to unpack each sample's data into.

    Returns:
      An empty list if the table holds no items. Otherwise a list with one
      entry per sample: the sample's data as returned by the sampler, or that
      data packed like `structure` when `structure` is truthy.
    """
    info = self.client.server_info(1)
    item_count = info[TABLES[idx]].current_size
    if item_count == 0:
        return []

    sampler = self.client.sample(
        TABLES[idx], num_samples=item_count, emit_timesteps=False)
    flat_samples = [sample.data for sample in sampler]

    # NOTE(review): `structure` is tested by truthiness, so an empty
    # container is treated the same as None.
    if not structure:
        return flat_samples
    return [tree.unflatten_as(structure, flat) for flat in flat_samples]
def restore_from_path(ckpt_dir: str) -> CheckpointState:
    """Restore the state stored in ckpt_dir.

    Reads the pickled exemplar and the saved arrays from `ckpt_dir`, packs the
    arrays like the exemplar, and converts every leaf whose exemplar entry is
    falsy back to a Python scalar.

    Args:
      ckpt_dir: Directory containing the `_ARRAY_NAME` and `_EXEMPLAR_NAME`
        files.

    Returns:
      The restored state, nested like the stored exemplar.
    """
    array_path = os.path.join(ckpt_dir, _ARRAY_NAME)
    exemplar_path = os.path.join(ckpt_dir, _EXEMPLAR_NAME)

    # SECURITY: both loads below unpickle file contents; only restore from
    # checkpoint directories you trust.
    with open(exemplar_path, 'rb') as f:
        exemplar = pickle.load(f)

    with open(array_path, 'rb') as f:
        files = np.load(f, allow_pickle=True)
        # Read every array while the file is still open.
        flat_state = [files[key] for key in files.files]
    unflattened_tree = tree.unflatten_as(exemplar, flat_state)

    def maybe_convert_to_python(value, numpy):
        # `np.asscalar` was removed in NumPy 1.23; `.item()` is its documented
        # replacement.
        return value if numpy else value.item()

    return tree.map_structure(maybe_convert_to_python, unflattened_tree,
                              exemplar)
def testAttrsFlattenAndUnflatten(self):
    """attrs classes flatten to their field values and rebuild by type."""

    class BadAttr(object):
        """Class that has a non-iterable __attrs_attrs__."""
        __attrs_attrs__ = None

    @attr.s
    class SampleAttr(object):
        field1 = attr.ib()
        field2 = attr.ib()

    values = [1, 2]
    instance = SampleAttr(*values)

    # Only the attrs instance is recognised, not the plain list.
    self.assertFalse(tree._is_attrs(values))
    self.assertTrue(tree._is_attrs(instance))

    leaves = tree.flatten(instance)
    self.assertEqual(values, leaves)

    rebuilt = tree.unflatten_as(instance, leaves)
    self.assertIsInstance(rebuilt, SampleAttr)
    self.assertEqual(rebuilt, instance)

    # Check that flatten fails if attributes are not iterable
    with self.assertRaisesRegex(TypeError, "object is not iterable"):
        tree.flatten(BadAttr())