def concat_obs_if_necessary(obs: TensorStructType):
    """Concat model outs if they are original tuple observations.

    Lists/tuples are concatenated as-is along the last axis.

    Dicts are flattened with `tree.flatten` over their values. Any rank-1 leaf
    is first given a trailing axis (via `tf.expand_dims(..., 1)`) so that it
    can be concatenated with higher-rank leaves along the last axis.

    Any other input is returned unchanged.
    """
    if isinstance(obs, (list, tuple)):
        return tf.concat(obs, axis=-1)

    if isinstance(obs, dict):
        pieces = []
        for leaf in tree.flatten(obs.values()):
            # Rank-1 leaves lack a feature axis; add one at position 1.
            pieces.append(leaf if len(leaf.shape) != 1 else tf.expand_dims(leaf, 1))
        return tf.concat(pieces, axis=-1)

    return obs
def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
    """Recursively unpacks the repeat dimension (max_len).

    Plain leaves are split along axis 1 into a list of `max_len` slices
    (`v[:, idx, ...]`). Dicts and tuples keep their container type, with each
    entry unpacked recursively.

    For a `RepeatedValues`, its `.values` are unpacked first. Then one new
    `RepeatedValues` is built per resulting item, pairing it with the matching
    `lengths[:, idx, ...]` slice and the original `max_len`.

    NOTE(review): if a RepeatedValues' `.values` is a dict or tuple, the
    recursive call returns that container (not a list), so `enumerate` walks
    its keys/entries instead of the `max_len` slices. Confirm this is intended.
    """
    if isinstance(v, dict):
        return {key: _unbatch_helper(sub, max_len) for key, sub in v.items()}

    if isinstance(v, tuple):
        return tuple(_unbatch_helper(sub, max_len) for sub in v)

    if isinstance(v, RepeatedValues):
        slices = _unbatch_helper(v.values, max_len)
        result = []
        for idx, sliced_values in enumerate(slices):
            result.append(
                RepeatedValues(sliced_values, v.lengths[:, idx, ...], v.max_len)
            )
        return result

    return [v[:, idx, ...] for idx in range(max_len)]
def _get_batch_dim_helper(v: TensorStructType) -> int:
    """Tries to find the batch dimension size of v, or None.

    Walks nested dicts, tuples and `RepeatedValues` down to leaf tensors. It
    reads the size of a leaf's first axis (`shape[0]`).

    Returns:
        The leading-axis size of the first leaf, in iteration order, that
        reports a known size. Returns None if the structure has no leaves
        (an empty dict or tuple) or if every visited leaf reports None.
    """
    if isinstance(v, dict):
        children = v.values()
    elif isinstance(v, tuple):
        children = v
    elif isinstance(v, RepeatedValues):
        return _get_batch_dim_helper(v.values)
    else:
        batch_dim = v.shape[0]
        # Some TensorFlow shape dims wrap the integer size in `.value`.
        if hasattr(batch_dim, "value"):
            batch_dim = batch_dim.value
        return batch_dim

    # Keep looking past children whose batch dim is unknown (None). An empty
    # container yields None here instead of raising IndexError, as the old
    # `v[0]` lookup did for empty tuples.
    for child in children:
        batch_dim = _get_batch_dim_helper(child)
        if batch_dim is not None:
            return batch_dim
    return None