def test_fast_map_structure_with_path(self):
    """Compares `fast_map_structure_with_path` with `tree.map_structure_with_path`.

    The two implementations are compared twice: once when mapping over a
    single structure and once when mapping over two structures at the same
    time.
    """
    structure = {
        'a': {'b': np.array([0.0])},
        'c': (np.array([1.0]), np.array([2.0])),
        'd': [np.array(3.0), np.array(4.0)],
    }

    def map_fn(path: Sequence[str], x: np.ndarray, y: np.ndarray):
        # Mixes both leaves and the path depth into the result.
        return x + y + len(path)

    single_arg_map_fn = functools.partial(map_fn, y=np.array([0.0]))

    # Case 1: a single input structure.
    reference = tree.map_structure_with_path(single_arg_map_fn, structure)
    mapped_structure = tree_utils.fast_map_structure_with_path(
        single_arg_map_fn, structure)
    self.assertEqual(mapped_structure, reference)

    # Case 2: two input structures traversed together.
    double_reference = tree.map_structure_with_path(
        map_fn, structure, mapped_structure)
    double_mapped = tree_utils.fast_map_structure_with_path(
        map_fn, structure, mapped_structure)
    self.assertEqual(double_mapped, double_reference)
def _save_value_functions(self, checkpoint_dir):
    """Saves the weights of every entry of `self.Qs` below `checkpoint_dir`.

    Each entry is written through `save_weights(..., save_format='tf')` to a
    location named 'Q' followed by the entry's path components within
    `self.Qs`, all joined by '-'.
    """

    def save_one(path, q_function):
        name_parts = ['Q']
        name_parts.extend(str(component) for component in path)
        target = os.path.join(checkpoint_dir, '-'.join(name_parts))
        return q_function.save_weights(target, save_format='tf')

    tree.map_structure_with_path(save_one, self.Qs)
def add_samples(self, samples):
    """Writes a batch of samples into `self.data` and advances the pointer.

    If the episode index entries are absent from `samples`, they are added to
    `samples` in place, filled with the default value of the matching field.
    """
    num_samples = tree.flatten(samples)[0].shape[0]

    has_forwards = 'episode_index_forwards' in samples.keys()
    has_backwards = 'episode_index_backwards' in samples.keys()
    # Either both episode index entries are given, or neither is.
    assert has_forwards is has_backwards

    if not has_forwards:
        for key in ('episode_index_forwards', 'episode_index_backwards'):
            samples[key] = np.full(
                (num_samples, *self.fields[key].shape),
                self.fields[key].default_value,
                dtype=self.fields[key].dtype)

    # Target positions wrap around once the end of the storage is reached.
    first = self._pointer
    index = np.arange(first, first + num_samples) % self._max_size

    def add_sample(path, data, new_values, field):
        assert new_values.shape[0] == num_samples, (
            new_values.shape, num_samples)
        data[index] = new_values

    tree.map_structure_with_path(
        add_sample, self.data, samples, self.fields)

    self._advance(num_samples)
def add_samples(self, samples):
    """Stores `samples` in `self.data` at the current pointer, then advances it.

    Missing episode index entries are inserted into `samples` in place using
    the defaults of the corresponding fields.
    """
    # Original author's note: samples: dict(7(1000)) -- presumably a dict with
    # 7 entries of 1000 samples each; TODO confirm.
    num_samples = tree.flatten(samples)[0].shape[0]

    forwards_key = 'episode_index_forwards'
    backwards_key = 'episode_index_backwards'
    assert ((forwards_key in samples.keys())
            is (backwards_key in samples.keys()))

    if forwards_key not in samples.keys():
        forwards_field = self.fields[forwards_key]
        samples[forwards_key] = np.full(
            (num_samples, *forwards_field.shape),
            forwards_field.default_value,
            dtype=forwards_field.dtype)
        backwards_field = self.fields[backwards_key]
        samples[backwards_key] = np.full(
            (num_samples, *backwards_field.shape),
            backwards_field.default_value,
            dtype=backwards_field.dtype)

    # The indices never exceed the maximum capacity (they wrap around).
    index = np.arange(
        self._pointer, self._pointer + num_samples) % self._max_size

    def add_sample(path, data, new_values, field):
        assert new_values.shape[0] == num_samples, (
            new_values.shape, num_samples)
        data[index] = new_values

    # In map_structure_with_path a dict is part of the outer structure, not a
    # leaf: all dicts passed in must share the same keys so that the leaves
    # under matching keys can be processed together. The `path` handed to the
    # function is the dict key; for nested dicts it is a tuple of the keys.
    # This applies `add_sample` to the value stored under every key of
    # `self.data`.
    tree.map_structure_with_path(
        add_sample, self.data, samples, self.fields)

    self._advance(num_samples)
def compress(
        self,
        bulk: bool = False,
        columns: Set[str] = frozenset(["obs", "new_obs"])
) -> "SampleBatch":
    """Packs the selected columns of this batch in place.

    Args:
        bulk: If True, each leaf is packed as a whole (including the batch
            dimension). If False, every item of a leaf is packed on its own
            and the results are collected into a numpy array.
        columns: Top-level keys whose leaves get packed. Default: only
            "obs" and "new_obs".

    Returns:
        This very SampleBatch (after packing).
    """

    def _compress_in_place(path, value):
        # Only leaves below one of the requested top-level columns.
        if path[0] not in columns:
            return
        last_depth = len(path) - 1
        node = self
        for depth, key in enumerate(path):
            if depth == last_depth:
                if bulk:
                    node[key] = pack(value)
                else:
                    node[key] = np.array([pack(item) for item in value])
            node = node[key]

    tree.map_structure_with_path(_compress_in_place, self)

    return self
def decompress_if_needed(
        self,
        columns: Set[str] = frozenset(["obs", "new_obs"])
) -> "SampleBatch":
    """Unpacks the selected columns in place, where they are compressed.

    Args:
        columns: Top-level keys whose leaves are inspected. Default: only
            "obs" and "new_obs".

    Returns:
        This very SampleBatch (after unpacking).
    """

    def _decompress_in_place(path, value):
        if path[0] not in columns:
            return
        *parent_keys, leaf_key = path
        container = self
        for key in parent_keys:
            container = container[key]
        if is_compressed(value):
            # The whole leaf was packed at once (bulk).
            container[leaf_key] = unpack(value)
        elif len(value) > 0 and is_compressed(value[0]):
            # Each item of the leaf was packed on its own (non-bulk).
            container[leaf_key] = np.array([unpack(item) for item in value])

    tree.map_structure_with_path(_decompress_in_place, self)

    return self
def assert_tree_shape_prefix(tree: ArrayTree,
                             shape_prefix: Sequence[int],
                             *,
                             ignore_nones: bool = False):
    """Asserts all tree leaves' shapes have the same prefix.

    Args:
        tree: a tree to assert.
        shape_prefix: an expected shapes' prefix. Any sequence of ints is
            accepted (e.g. a list or a tuple).
        ignore_nones: whether to ignore `None`s in the tree.

    Raises:
        AssertionError: if some leaf's shape doesn't start with the expected
            prefix; if `ignore_nones` isn't set and the tree contains `None`s.
    """
    if not ignore_nones:
        assert_tree_no_nones(tree)

    # Normalize once: a tuple never compares equal to a list, so a list
    # `shape_prefix` would otherwise fail against every tuple-valued shape.
    expected_prefix = tuple(shape_prefix)

    def _assert_fn(path, leaf):
        if leaf is None:
            return

        prefix = tuple(leaf.shape[:len(expected_prefix)])
        if prefix != expected_prefix:
            raise AssertionError(
                f"Tree leaf '{_format_tree_path(path)}' has a shape prefix "
                f"diffent from expected: {prefix} != {expected_prefix}.")

    dm_tree.map_structure_with_path(_assert_fn, tree)
def log_stats(ex, stats, step=None, sep='.'):
    """Logs every leaf of `stats` as a scalar through `ex.log_scalar`.

    The scalar's key is the leaf's path joined by `sep`. `tf.Tensor` leaves
    are converted with `.numpy()` and numpy arrays are reduced to their mean
    before logging.
    """

    def log(path, value):
        if isinstance(value, tf.Tensor):
            value = value.numpy()
        if isinstance(value, np.ndarray):
            value = value.mean()
        ex.log_scalar(
            sep.join(str(component) for component in path), value, step=step)

    tree.map_structure_with_path(log, stats)
def _merge_adresses_in(synthesis, not_criterias, to_merge):
    """Merges the leaves of `to_merge` into `synthesis`, in place.

    `synthesis` is a list of `(path, values)` pairs. For each leaf of
    `to_merge` whose path (as a list) is not in `not_criterias`, the leaf is
    appended to the values recorded for that path unless already present; a
    new pair is created when the path has not been seen yet.
    """

    def _add_entry(path, val):
        if list(path) in not_criterias:
            return
        known_paths = [entry[0] for entry in synthesis]
        if path in known_paths:
            values = synthesis[known_paths.index(path)][1]
            if val not in values:
                values.append(val)
        else:
            synthesis.append((path, [val]))

    tree.map_structure_with_path(_add_entry, to_merge)
def add_hyperparam_group(self, group, suffix, defaults):
    """Registers the hyperparameters of one param group.

    Every key of `group` other than 'params' is stored in
    `self._hyperparameters` under '<key>_<suffix>'. Keys that only appear in
    `defaults` map to themselves. Each parameter path under
    `group['params']` (joined by '/') is then pointed at that mapping.
    """
    # Default hyperparameters keep their own name unless the group overrides
    # them.
    group_dict = {}
    for key in defaults:
        if key not in group:
            group_dict[key] = key

    for key in group:
        if key == 'params':  # Reserved keyword 'params'
            continue
        suffixed = f'{key}_{suffix}'
        group_dict[key] = suffixed
        self._hyperparameters[suffixed] = group[key]

    # Point every parameter of this group at the group's mapping.
    def set_p2h(k, _):
        self._params2hyperparams['/'.join(k)] = group_dict

    tree.map_structure_with_path(set_p2h, group['params'])
def create_param_groups(self, params, defaults):
    """Builds the mapping from parameter paths to hyperparameter names.

    A list `params` is treated as a sequence of groups, each registered via
    `add_hyperparam_group` with its position as suffix. Anything else is
    treated as a single parameter structure in which every parameter shares
    the identity mapping over `self._hyperparameters`.
    """
    if not isinstance(params, list):
        mapping = {}
        for key in self._hyperparameters:
            mapping[key] = key

        def set_p2h(k, _):
            self._params2hyperparams['/'.join(k)] = mapping

        tree.map_structure_with_path(set_p2h, params)
        return

    for group_index, group in enumerate(params):
        # Registers the group's hyperparameters and its param mapping.
        self.add_hyperparam_group(group, group_index, defaults)
def assert_dtype(self, test_dtype, module_fn: ModuleFn, shape, input_dtype):
    """Checks that modules accepting float32 input_dtype output test_dtype.

    The module is traced with `jax.eval_shape` (twice, feeding the state
    back in), with parameters created through a custom creator that asserts
    the requested dtype. Floating point outputs and state must then have
    `test_dtype`. The test is skipped when `input_dtype` is not float32.
    """
    if input_dtype != jnp.float32:
        self.skipTest('Skipping module with non-f32 input')

    def ones_creator(next_creator, shape, dtype, init, context):
        if context.full_name == 'vector_quantizer/embeddings':
            # NOTE: vector_quantizer/embeddings is created using a ctor
            # argument so dtype is not expected to follow input to __call__.
            dtype = test_dtype
        else:
            self.assertEqual(dtype, test_dtype, msg=context.full_name)

        # NOTE: We need to do this since some initializers (e.g.
        # random.uniform) do not support <32bit dtypes. This also makes the
        # test run a bit faster.
        init = jnp.ones
        return next_creator(shape, dtype, init)

    def g(x):
        with hk.custom_creator(ones_creator):
            mod = module_fn()
            return mod(x)

    g = hk.transform_with_state(g)

    # No custom creator for state so we need to do this manually.
    def cast_if_floating(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            x = x.astype(test_dtype)
        return x

    def init_fn(rng, x):
        params, state = g.init(rng, x)
        # `jax.tree_map` is deprecated (and removed in newer JAX releases);
        # `jax.tree_util.tree_map` is the supported spelling.
        state = jax.tree_util.tree_map(cast_if_floating, state)
        return params, state

    def assert_dtype(path, v):
        if jnp.issubdtype(v.dtype, jnp.floating):
            self.assertEqual(v.dtype, test_dtype, msg=path)

    x = np.ones(shape, test_dtype)
    rng = jax.random.PRNGKey(42)
    params, state = jax.eval_shape(init_fn, rng, x)

    for _ in range(2):
        y, state = jax.eval_shape(g.apply, params, state, rng, x)
        tree.map_structure_with_path(assert_dtype, y)
        tree.map_structure_with_path(assert_dtype, state)
def call(*args, **kwargs):
    """Runs `fn` symbolically through the session, feeding the given values.

    On the first call, placeholders mirroring `args` and `kwargs` are created
    and `fn` is built on them once; later calls reuse the same placeholders
    and only feed new values.
    """
    # Splice the items of any top-level list argument into the argument list.
    flattened = []
    for arg in args:
        if type(arg) is list:
            flattened.extend(arg)
        else:
            flattened.append(arg)
    args = flattened

    # We have not built any placeholders yet: Do this once here,
    # then reuse the same placeholders each time we call this
    # function again.
    if symbolic_out[0] is None:
        with session_or_none.graph.as_default():

            def _create_placeholders(path, value):
                if not dynamic_shape:
                    shape = value.shape
                elif len(value.shape) > 0:
                    # Leave the leading dimension open.
                    shape = (None,) + value.shape[1:]
                else:
                    shape = ()
                return tf1.placeholder(
                    dtype=value.dtype,
                    shape=shape,
                    name=".".join([str(p) for p in path]),
                )

            positional = tree.map_structure_with_path(
                _create_placeholders, args
            )
            args_placeholders.extend(tree.flatten(positional))

            named = tree.map_structure_with_path(
                _create_placeholders, kwargs
            )
            for key, placeholder in named.items():
                kwargs_placeholders[key] = placeholder

            symbolic_out[0] = fn(*args_placeholders, **kwargs_placeholders)

    feed_dict = dict(zip(args_placeholders, tree.flatten(args)))

    def _feed(placeholder, fed_value):
        feed_dict[placeholder] = fed_value

    tree.map_structure(_feed, kwargs_placeholders, kwargs)

    return session_or_none.run(symbolic_out[0], feed_dict)
def assert_tree_all_equal_comparator(equality_comparator: TLeavesEqCmpFn,
                                     error_msg_fn: TLeavesEqCmpErrorFn,
                                     *trees: Sequence[ArrayTree]):
    """Asserts that all trees are equal, comparing leaves with a custom function.

    Nothing is checked when fewer than two trees are given. Otherwise the
    trees' structures are asserted equal first, then the leaves at each path
    are compared via `equality_comparator`; `error_msg_fn` supplies the
    leaf-level part of the error message.
    """
    if len(trees) < 2:
        return

    assert_tree_all_equal_structs(*trees)

    def _tree_error_msg_fn(l_1: TLeaf, l_2: TLeaf, path: str, i_1: int,
                           i_2: int):
        leaf_message = error_msg_fn(l_1, l_2)
        return (f"Trees {i_1} and {i_2} differ in leaves '{path}': "
                f"{leaf_message}.")

    def _compare_leaves(*path_and_leaves):
        return _assert_leaves_all_eq_comparator(
            equality_comparator, _tree_error_msg_fn, *path_and_leaves)

    tree.map_structure_with_path(_compare_leaves, *trees)
def test_bias_parameter_shape(self):
    """Checks the shapes of the network's two parameters.

    The parameter whose path ends in 'b' must have shape `(1,)` (and differ
    from `self.output_shape`); the one ending in 'w' must have
    `self.weights_shape`. Any other parameter fails the test.
    """
    params = self.network.init(self.init_rng_key,
                               _sample_input(self.input_shape))
    self.assertLen(tree.flatten(params), 2)

    def check_params(path, param):
        if path[-1] == 'b':
            self.assertNotEqual(self.output_shape, param.shape)
            chex.assert_shape(param, (1,))
        elif path[-1] == 'w':
            chex.assert_shape(param, self.weights_shape)
        else:
            # `path` is a tuple: wrap it so `%` formats it as one value
            # instead of raising TypeError for multi-element paths.
            self.fail('Unexpected parameter %s.' % (path,))

    tree.map_structure_with_path(check_params, params)
def assert_tree_no_nones(tree: ArrayTree):
    """Checks that no leaf of the tree is `None`.

    Args:
        tree: the tree to check.

    Raises:
        AssertionError: if a `None` leaf is found; the message contains the
            formatted path of that leaf.
    """

    def _assert_fn(path, leaf):
        if leaf is not None:
            return
        raise AssertionError(
            f"`None` detected at '{_format_tree_path(path)}'.")

    dm_tree.map_structure_with_path(_assert_fn, tree)
def get_placeholder(*,
                    space=None,
                    value=None,
                    name=None,
                    time_axis=False,
                    flatten=True):
    """Returns a tf1 placeholder (or a structure of them) from a space or value.

    Args:
        space: Optional gym space hinting the placeholder's shape and dtype.
            For Dict/Tuple spaces with `flatten=False`, a structure of
            placeholders (one per component) is returned.
        value: Optional value hinting shape and dtype. Required when `space`
            is None; its leading dimension is replaced by None.
        name: Optional name for the placeholder. For per-component
            placeholders, the component's path is appended after a ".".
        time_axis: Whether to add a second open (None) dimension.
        flatten: Whether Dict/Tuple spaces are handled via
            `ModelCatalog.get_action_placeholder` instead of per component.

    Returns:
        The placeholder, or a structure of placeholders.
    """
    from ray.rllib.models.catalog import ModelCatalog

    if space is not None:
        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
            if flatten:
                return ModelCatalog.get_action_placeholder(space, None)

            def _component_placeholder(path, component):
                suffix = ".".join([str(p) for p in path])
                # `name` defaults to None; concatenating it directly would
                # raise TypeError, so fall back to the bare component path.
                component_name = (
                    suffix if name is None else name + "." + suffix)
                return get_placeholder(space=component, name=component_name)

            return tree.map_structure_with_path(
                _component_placeholder,
                get_base_struct_from_space(space),
            )
        return tf1.placeholder(
            shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
            name=name,
        )
    else:
        assert value is not None
        shape = value.shape[1:]
        return tf1.placeholder(
            shape=(None, ) + ((None, ) if time_axis else ()) +
            (shape if isinstance(shape, tuple) else tuple(shape.as_list())),
            dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
            name=name,
        )
def signature(cls,
              environment_spec: specs.EnvironmentSpec,
              extras_spec: types.NestedSpec = ()):
    """Builds a signature for Reverb tables from an environment spec.

    Signatures are useful for validating data types and shapes, see Reverb's
    documentation for details on how they are used.

    Args:
        environment_spec: A `specs.EnvironmentSpec` whose fields are nested
            structures with leaf nodes that have `.shape` and `.dtype`
            attributes. This should come from the environment that will be
            used to generate the data inserted into the Reverb table.
        extras_spec: A nested structure with leaf nodes that have `.shape`
            and `.dtype` attributes. The structure (and shapes/dtypes) of
            this must be the same as the `extras` passed into
            `ReverbAdder.add`.

    Returns:
        A `Step` whose leaves are the results of `array_spec_to_tensor_spec`.
    """
    step_fields = dict(
        observation=environment_spec.observations,
        action=environment_spec.actions,
        reward=environment_spec.rewards,
        discount=environment_spec.discounts,
        extras=extras_spec,
    )
    return tree.map_structure_with_path(array_spec_to_tensor_spec,
                                        Step(**step_fields))
def signature(cls,
              environment_spec: specs.EnvironmentSpec,
              extras_spec: types.NestedSpec = ()):
    """Builds the transition signature from an environment spec.

    The reward and discount specs are first broadcast to a common structure;
    the reward spec is then combined leaf-wise with the discount spec via
    `_broadcast_specs`, and the discount spec is deep-copied. The resulting
    `types.Transition` is converted leaf by leaf with
    `base.spec_like_to_tensor_spec`.
    """
    # This function currently assumes that self._discount is a scalar.
    # If it ever becomes a nested structure and/or a np.ndarray, this method
    # will need to know its structure / shape. This is because the signature
    # discount shape is the environment's discount shape and this adder's
    # discount shape broadcasted together. Also, the reward shape is this
    # signature discount shape broadcasted together with the environment
    # reward shape. As long as self._discount is a scalar, it will not affect
    # either the signature discount shape nor the signature reward shape, so
    # we can ignore it.
    broadcast = tree_utils.broadcast_structures(
        environment_spec.rewards, environment_spec.discounts)
    rewards_spec, step_discounts_spec = broadcast

    rewards_spec = tree.map_structure(
        _broadcast_specs, rewards_spec, step_discounts_spec)
    step_discounts_spec = tree.map_structure(
        copy.deepcopy, step_discounts_spec)

    observations_spec = environment_spec.observations
    transition_spec = types.Transition(
        observations_spec,
        environment_spec.actions,
        rewards_spec,
        step_discounts_spec,
        environment_spec.observations,  # next_observation
        extras_spec)

    return tree.map_structure_with_path(
        base.spec_like_to_tensor_spec, transition_spec)
def add_learn_on_batch_results(
        self,
        results: Dict,
        policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> None:
    """Records one learn-on-batch result dict for the given policy.

    Args:
        results: The results returned by Policy.learn_on_batch or
            Policy.learn_on_loaded_batch. In the multi-tower case the
            "tower_*" entries are popped from this dict.
        policy_id: The policy's ID, whose learn_on_(loaded)_batch method
            returned `results`.
    """
    assert not self.is_finalized, \
        "LearnerInfo already finalized! Cannot add more results."

    # No towers: Single CPU. Store the results unchanged.
    if "tower_0" not in results:
        self.results_all_towers[policy_id].append(results)
        return

    # Multi-GPU case: reduce across the per-tower results first.
    per_tower = [
        results.pop("tower_{}".format(tower_num))
        for tower_num in range(self.num_devices)
    ]
    self.results_all_towers[policy_id].append(
        tree.map_structure_with_path(all_tower_reduce, *per_tower))

    # Copy the remaining (non-tower) entries on top of the reduced results.
    for k, v in results.items():
        if k == LEARNER_STATS_KEY:
            for k1, v1 in v.items():
                self.results_all_towers[policy_id][-1][LEARNER_STATS_KEY][
                    k1] = v1
        else:
            self.results_all_towers[policy_id][-1][k] = v
def rows(self) -> Iterator[Dict[str, TensorType]]:
    """Iterates over the rows of this batch, one dict of column values each.

    Note that if `seq_lens` is set in self, we set it to [1] in the rows.

    Yields:
        Dict[str, TensorType]: The column values of the row in this
            iteration.

    Examples:
        >>> batch = SampleBatch({
        ...    "a": [1, 2, 3],
        ...    "b": [4, 5, 6],
        ...    "seq_lens": [1, 2]
        ... })
        >>> for row in batch.rows():
               print(row)
        {"a": 1, "b": 4, "seq_lens": [1]}
        {"a": 2, "b": 5, "seq_lens": [1]}
        {"a": 3, "b": 6, "seq_lens": [1]}
    """
    # Do we add seq_lens=[1] to each row?
    if self.get(SampleBatch.SEQ_LENS) is None:
        seq_lens = None
    else:
        seq_lens = np.array([1])

    self_as_dict = dict(self.items())

    for row_index in range(self.count):

        def pick(path, column):
            if path[0] != self.SEQ_LENS:
                return column[row_index]
            return seq_lens

        yield tree.map_structure_with_path(pick, self_as_dict)
def rows(self) -> Iterator[Dict[str, TensorType]]:
    """Iterates over the rows of this batch, one dict of column values each.

    Note that if `seq_lens` is set in self, we set it to 1 in the rows.

    Yields:
        The column values of the row in this iteration.

    Examples:
        >>> from ray.rllib.policy.sample_batch import SampleBatch
        >>> batch = SampleBatch({ # doctest: +SKIP
        ...    "a": [1, 2, 3],
        ...    "b": [4, 5, 6],
        ...    "seq_lens": [1, 2]
        ... })
        >>> for row in batch.rows(): # doctest: +SKIP
        ...    print(row) # doctest: +SKIP
        {"a": 1, "b": 4, "seq_lens": 1}
        {"a": 2, "b": 5, "seq_lens": 1}
        {"a": 3, "b": 6, "seq_lens": 1}
    """
    # With the fallback of 1, the lookup is None only when the stored value
    # itself is None.
    if self.get(SampleBatch.SEQ_LENS, 1) is None:
        seq_lens = None
    else:
        seq_lens = 1

    self_as_dict = dict(self.items())

    for row_index in range(self.count):

        def pick(path, column):
            if path[0] != self.SEQ_LENS:
                return column[row_index]
            return seq_lens

        yield tree.map_structure_with_path(pick, self_as_dict)
def _slice(self, slice_: slice):
    """Helper method to handle SampleBatch slicing using a slice object.

    The returned SampleBatch uses the same underlying data object as
    `self`, so changing the slice will also change `self`.

    Note that only zero or positive bounds are allowed for both start
    and stop values. The slice step must be 1 (or None, which is the
    same).

    Args:
        slice_ (slice): The python slice object to slice by.

    Returns:
        SampleBatch: A new SampleBatch, however "linking" into the same
            data (sliced) as self.
    """
    start = slice_.start or 0
    # Only a missing stop means "until the end". `slice_.stop or len(self)`
    # would also replace an explicit stop of 0, turning an empty slice such
    # as `batch[:0]` into the whole batch.
    stop = len(self) if slice_.stop is None else slice_.stop
    assert start >= 0 and stop >= 0 and slice_.step in [1, None]

    if self.get(SampleBatch.SEQ_LENS) is not None and \
            len(self[SampleBatch.SEQ_LENS]) > 0:
        # Build our slice-map, if not done already. Entry `t` holds
        # (index of the sequence containing row t, first row of that
        # sequence); one extra entry marks the very end.
        if not self._slice_map:
            sum_ = 0
            for i, l in enumerate(self[SampleBatch.SEQ_LENS]):
                self._slice_map.extend((i, sum_) for _ in range(l))
                sum_ += l
            self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

        start_seq_len, start = self._slice_map[start]
        stop_seq_len, stop = self._slice_map[stop]
        if self.zero_padded:
            # Every sequence occupies exactly `max_seq_len` rows.
            start = start_seq_len * self.max_seq_len
            stop = stop_seq_len * self.max_seq_len

        def map_(path, value):
            if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
                    "state_in_"):
                return value[start:stop]
            else:
                # These columns are sliced by the sequence indices instead.
                return value[start_seq_len:stop_seq_len]

        data = tree.map_structure_with_path(map_, self)
        return SampleBatch(
            data,
            _is_training=self.is_training,
            _time_major=self.time_major,
            _zero_padded=self.zero_padded,
            _max_seq_len=self.max_seq_len if self.zero_padded else None,
        )
    else:
        data = tree.map_structure(lambda value: value[start:stop], self)
        return SampleBatch(
            data,
            _is_training=self.is_training,
            _time_major=self.time_major,
        )
def assert_dtype(
        self,
        test_dtype: DType,
        module_fn: descriptors.ModuleFn,
        shape: Shape,
        input_dtype: DType,
):
    """Checks that modules accepting float32 input_dtype output test_dtype.

    Parameters and state are traced in float32 with `jax.eval_shape`, then
    rebuilt as arrays of ones with float32 leaves switched to `test_dtype`.
    The module is applied twice (feeding the state back in) and every
    non-int32 leaf of the output and state must have `test_dtype`. The test
    is skipped off TPU or when `input_dtype` is not float32.
    """
    if jax.local_devices()[0].platform != 'tpu':
        self.skipTest('bfloat16 only supported on TPU')

    if input_dtype != jnp.float32:
        self.skipTest('Skipping module without float32 input')

    rng = jax.random.PRNGKey(42)

    def g(x):
        mod = module_fn()
        return mod(x)

    init_fn, apply_fn = hk.transform_with_state(g)

    # Create state in f32 to start.
    # NOTE: We need to do this since some initializers (e.g. random.uniform)
    # do not support <32bit dtypes.
    x = jax.random.uniform(rng, shape)
    params, state = jax.eval_shape(init_fn, rng, x)

    # Cast f32 to test_dtype.
    def make_param(v):
        dtype = test_dtype if v.dtype == jnp.float32 else v.dtype
        return jnp.ones(v.shape, dtype)

    # `jax.tree_map` is deprecated (and removed in newer JAX releases);
    # `jax.tree_util.tree_map` is the supported spelling.
    params, state = jax.tree_util.tree_map(make_param, (params, state))

    def assert_dtype(path, v):
        if v.dtype != jnp.int32:
            self.assertEqual(v.dtype, test_dtype, msg=path)

    # test_dtype in should result in test_dtype out.
    x = x.astype(test_dtype)

    for _ in range(2):
        y, state = jax.eval_shape(apply_fn, params, state, rng, x)
        tree.map_structure_with_path(assert_dtype, y)
        tree.map_structure_with_path(assert_dtype, state)
def field_from_gym_space(name, space):
    """Builds a `Field` (or a structure of them) describing a gym space.

    A `spaces.Box` yields one `Field` carrying the space's dtype and shape; a
    list/tuple `name` is joined by '/'. A `spaces.Dict` is converted leaf by
    leaf, each leaf receiving its path as the name. Other spaces are not
    supported.

    Raises:
        NotImplementedError: if `space` is neither a Box nor a Dict.
    """
    if isinstance(space, spaces.Box):
        field_name = name
        if isinstance(field_name, (list, tuple)):
            field_name = '/'.join(field_name)
        return Field(name=field_name, dtype=space.dtype, shape=space.shape)

    if isinstance(space, spaces.Dict):
        return tree.map_structure_with_path(
            field_from_gym_space, space.spaces)

    raise NotImplementedError(space)
def get_placeholder(
    *,
    space: Optional[gym.Space] = None,
    value: Optional[Any] = None,
    name: Optional[str] = None,
    time_axis: bool = False,
    flatten: bool = True
) -> "tf1.placeholder":
    """Returns a tf1.placeholder object given optional hints, such as a space.

    Note that the returned placeholder will always have a leading batch
    dimension (None).

    Args:
        space: An optional gym.Space to hint the shape and dtype of the
            placeholder.
        value: An optional value to hint the shape and dtype of the
            placeholder. Required when `space` is None.
        name: An optional name for the placeholder. For the per-component
            placeholders of a Dict/Tuple space (`flatten=False`), the
            component's path is appended after a "."; if `name` is None, the
            path alone is used.
        time_axis: Whether the placeholder should also receive a time
            dimension (None).
        flatten: Whether to flatten the given space into a plain Box space
            and then create the placeholder from the resulting space.

    Returns:
        The tf1 placeholder (a structure of placeholders for Dict/Tuple
        spaces with `flatten=False`).
    """
    from ray.rllib.models.catalog import ModelCatalog

    if space is not None:
        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
            if flatten:
                return ModelCatalog.get_action_placeholder(space, None)

            def _component_placeholder(path, component):
                suffix = ".".join([str(p) for p in path])
                # `name` defaults to None; concatenating it directly would
                # raise TypeError, so fall back to the bare component path.
                component_name = (
                    suffix if name is None else name + "." + suffix)
                return get_placeholder(space=component, name=component_name)

            return tree.map_structure_with_path(
                _component_placeholder,
                get_base_struct_from_space(space),
            )
        return tf1.placeholder(
            shape=(None,) + ((None,) if time_axis else ()) + space.shape,
            dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
            name=name,
        )
    else:
        assert value is not None
        shape = value.shape[1:]
        return tf1.placeholder(
            shape=(None,)
            + ((None,) if time_axis else ())
            + (shape if isinstance(shape, tuple) else tuple(shape.as_list())),
            dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
            name=name,
        )
def observation_spec(self):
    """Derives the observation spec from a current observation.

    Every array of `self._observation()` becomes an `ArraySpec` with the
    array's shape and dtype, named after its path joined by '_' and suffixed
    with '_spec'.
    """
    # Note: this function will be called before reset call is issued and
    # observation might not be ready.
    # But for this specific environment it works.
    obs = self._observation()

    def mk_spec(path_tuple, np_arr):
        spec_name = '_'.join(path_tuple) + '_spec'
        return ArraySpec(np_arr.shape, np_arr.dtype, name=spec_name)

    return nest.map_structure_with_path(mk_spec, obs)
def testMapWithPathCompatibleStructures(self, s1, s2, check_types, expected):
    """Maps `(path, sum of leaves)` over `s1` and `s2` and checks the result."""

    def path_and_sum(path, *values):
        total = sum(values)
        return path, total

    result = tree.map_structure_with_path(
        path_and_sum, s1, s2, check_types=check_types)
    self.assertEqual(expected, result)
def finalize(self):
    """Marks this builder as finalized and returns the reduced results.

    Returns:
        A dict mapping each policy ID to its collected results, combined
        path by path with `all_tower_reduce`.
    """
    self.is_finalized = True

    # Reduce mean across all minibatch SGD steps (axis=0 to keep
    # all shapes as-is).
    return {
        policy_id: tree.map_structure_with_path(
            all_tower_reduce, *results_all_towers)
        for policy_id, results_all_towers in self.results_all_towers.items()
    }
def signature(
    cls,
    environment_spec: mava_specs.EnvironmentSpec,
    extras_spec: tf.TypeSpec = {},
) -> tf.TypeSpec:
    """Builds the multi-agent transition signature from an environment spec.

    Args:
        environment_spec: Provides the per-agent specs, the agent ids and the
            environment's extra specs.
        extras_spec: Additional extras specs. They are merged with the
            environment's extra specs (which win on key clashes). The passed
            mapping itself is left unmodified.

    Returns:
        A tuple (observations, actions, extras, rewards, discounts,
        next observations, next extras) whose leaves were converted with
        `base.spec_like_to_tensor_spec`.
    """
    # This function currently assumes that self._discount is a scalar.
    # If it ever becomes a nested structure and/or a np.ndarray, this method
    # will need to know its structure / shape. This is because the signature
    # discount shape is the environment's discount shape and this adder's
    # discount shape broadcasted together. Also, the reward shape is this
    # signature discount shape broadcasted together with the environment
    # reward shape. As long as self._discount is a scalar, it will not affect
    # either the signature discount shape nor the signature reward shape, so
    # we can ignore it.

    agent_specs = environment_spec.get_agent_specs()
    agents = environment_spec.get_agent_ids()
    env_extras_spec = environment_spec.get_extra_specs()

    # Merge into a shallow copy: updating `extras_spec` in place would mutate
    # the shared default `{}` (leaking entries into later calls) as well as
    # any dict passed in by the caller.
    extras_spec = copy.copy(extras_spec)
    extras_spec.update(env_extras_spec)

    obs_specs = {}
    act_specs = {}
    reward_specs = {}
    step_discount_specs = {}
    for agent in agents:
        rewards_spec, step_discounts_spec = tree_utils.broadcast_structures(
            agent_specs[agent].rewards, agent_specs[agent].discounts
        )
        rewards_spec = tree.map_structure(
            _broadcast_specs, rewards_spec, step_discounts_spec
        )
        step_discounts_spec = tree.map_structure(
            copy.deepcopy, step_discounts_spec
        )

        obs_specs[agent] = agent_specs[agent].observations
        act_specs[agent] = agent_specs[agent].actions
        reward_specs[agent] = rewards_spec
        step_discount_specs[agent] = step_discounts_spec

    transition_spec = [
        obs_specs,
        act_specs,
        extras_spec,
        reward_specs,
        step_discount_specs,
        obs_specs,  # next_observation
        extras_spec,
    ]

    return tree.map_structure_with_path(
        base.spec_like_to_tensor_spec, tuple(transition_spec)
    )