def to_tree_arrays(list_of_trees):
    """Convert a list of pytrees into a pytree of stacked jnp.arrays.

    Args:
      list_of_trees: A list of pytrees that share one structure and contain
        numbers (or arrays) as leaves.

    Returns:
      A pytree of jnp.arrays having the same structure as the elements of
      `list_of_trees`; each leaf is the corresponding leaves of all input
      trees stacked along a new leading axis. An empty input is returned
      unchanged.

    Example:
      >>> to_tree_arrays([(1, {"a": jnp.array([1, 2])}),
      ...                 (2, {"a": jnp.array([3, 4])})])
      (Array([1, 2], dtype=int32), {'a': Array([[1, 2], [3, 4]], dtype=int32)})
    """
    if not list_of_trees:
        return list_of_trees
    # Multi-tree tree_map zips the leaves of every tree position-by-position
    # (using the first tree's structure), so each call stacks one position.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves), *list_of_trees)
def test_tree_transpose(self):
    """tree_transpose turns a FlatMap of lists into a list of FlatMaps."""
    # Outer level: a FlatMap with keys "a" and "b"; inner level: a 2-list.
    outerdef = jax.tree_util.tree_structure(FlatMap({"a": 1, "b": 2}))
    innerdef = jax.tree_util.tree_structure([1, 2])
    self.assertEqual(
        [FlatMap({"a": 3, "b": 5}), FlatMap({"a": 4, "b": 6})],
        jax.tree_util.tree_transpose(
            outerdef, innerdef, FlatMap({"a": [3, 4], "b": [5, 6]})))
def point_project_tree(tree_point_cloud, ts, n_basis, basis):
    """Project the leaves of a sequence of pytrees and regroup per basis index.

    Args:
      tree_point_cloud: Sequence of pytrees sharing one structure. For each
        leaf position, the leaves from every tree are passed together (as a
        tuple) to ``point_project_array``.
      ts: Forwarded unchanged to ``point_project_array``.
      n_basis: Forwarded to ``point_project_array``; also fixes the length of
        the returned list.
      basis: Forwarded unchanged to ``point_project_array``.

    Returns:
      A list of ``n_basis`` pytrees, each structured like
      ``tree_point_cloud[0]``; pytree ``i`` holds item ``i`` of every leaf's
      projection.
    """
    def point_project_list(*args):
        # NOTE(review): assumes point_project_array yields exactly n_basis
        # items, otherwise tree_transpose below raises — confirm.
        return list(point_project_array(args, ts, n_basis, basis))

    out = jax.tree_util.tree_map(point_project_list, *tree_point_cloud)
    original_struct = jax.tree_util.tree_structure(tree_point_cloud[0])
    mapped_struct = jax.tree_util.tree_structure(list(range(n_basis)))
    # `out` is "original structure of n_basis-lists"; flip it into
    # "n_basis-list of original structures".
    return jax.tree_util.tree_transpose(original_struct, mapped_struct, out)
def function_project_tree(source_params, source_basis, target_basis, n_basis):
    """Project the leaves of a sequence of pytrees and regroup per basis index.

    Args:
      source_params: Sequence of pytrees sharing one structure. For each leaf
        position, the leaves from every tree are passed together (as a tuple)
        to ``function_project_array``.
      source_basis: Forwarded unchanged to ``function_project_array``.
      target_basis: Forwarded unchanged to ``function_project_array``.
      n_basis: Forwarded to ``function_project_array``; also fixes the length
        of the returned list.

    Returns:
      A list of ``n_basis`` pytrees, each structured like
      ``source_params[0]``; pytree ``i`` holds item ``i`` of every leaf's
      projection.
    """
    def function_project_list(*args):
        # NOTE(review): assumes function_project_array yields exactly n_basis
        # items, otherwise tree_transpose below raises — confirm.
        return list(
            function_project_array(args, source_basis, target_basis, n_basis))

    out = jax.tree_util.tree_map(function_project_list, *source_params)
    original_struct = jax.tree_util.tree_structure(source_params[0])
    mapped_struct = jax.tree_util.tree_structure(list(range(n_basis)))
    # `out` is "original structure of n_basis-lists"; flip it into
    # "n_basis-list of original structures".
    return jax.tree_util.tree_transpose(original_struct, mapped_struct, out)
def tree_vmap(f, lst):
    """Apply ``f`` to every element of ``lst`` using a single ``jax.vmap``.

    The elements of ``lst`` are stacked leaf-wise into one batched pytree,
    ``f`` is vmapped over the leading axis, and the batched result is split
    back into one output pytree per input element.

    Args:
      f: Function taking one pytree (structured like an element of ``lst``)
        and returning a pytree of arrays.
      lst: Sequence of pytrees with identical structure and matching leaf
        shapes (corresponding leaves are combined with ``jnp.stack``).

    Returns:
      A list whose ``i``-th entry is ``f`` applied to ``lst[i]``. Returns
      ``[]`` for an empty ``lst``.
    """
    if not lst:
        # Nothing to batch; vmap would reject an input with no mapped leaves.
        return []
    # Zip corresponding leaves of all trees and stack them on a new axis 0.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *lst)
    out_stacked = jax.vmap(f)(stacked)
    # Split each batched output leaf into a list of len(lst) slices. The result
    # is "output structure of lists", so the output structure is the OUTER
    # treedef and the length-len(lst) list is the INNER one.
    out_lists = jax.tree_util.tree_map(list, out_stacked)
    outer_treedef = jax.tree_util.tree_structure(out_stacked)
    # Use 0 rather than None: None is an empty pytree node with no leaves.
    inner_treedef = jax.tree_util.tree_structure([0] * len(lst))
    return jax.tree_util.tree_transpose(outer_treedef, inner_treedef, out_lists)
def write_gradient_histogram(writer, step, *, grads=None, updates=None):
    """Log computed gradients and/or updates histograms.

    Args:
      writer: Object providing ``write_histograms(step, histograms)``.
      step: Forwarded unchanged to ``writer.write_histograms``.
      grads: Optional pytree of gradients, logged under the "grad" key.
      updates: Optional pytree of updates, logged under the "update" key.
        When both are given they must share one tree structure, since they
        are transposed together below.

    Returns None; does nothing when both ``grads`` and ``updates`` are None.
    """
    histograms = {"grad": grads, "update": updates}
    histograms = {k: v for k, v in histograms.items() if v is not None}
    if not histograms:
        return
    # Transpose a histograms dict from
    # {"grad": {"param1": Tensor}, "update": {"param1": Tensor}} to
    # {"param1": {"grad": Tensor, "update": Tensor}} such that the gradient and
    # the transformed updates appear next to each other in tensorboard.
    outer_treedef = jax.tree_util.tree_structure({k: 0 for k in histograms})
    inner_treedef = jax.tree_util.tree_structure(
        next(iter(histograms.values())))
    histograms = jax.tree_util.tree_transpose(
        outer_treedef, inner_treedef, histograms)
    # NOTE(review): presumably flattens nested keys into "param1.grad"-style
    # names joined with "." — confirm against utils.flatten_dict.
    histograms = utils.flatten_dict(histograms, sep=".")
    writer.write_histograms(step, histograms)