Esempio n. 1
0
def to_tree_arrays(list_of_trees):
    """Convert a list of pytrees into a pytree of stacked jnp.arrays.

    Every tree in ``list_of_trees`` must share the structure of the first
    one. Corresponding leaves are stacked along a new leading axis, so a
    leaf of shape ``s`` in each input becomes a leaf of shape
    ``(len(list_of_trees), *s)`` in the output.

    Args:
      list_of_trees: A list of pytrees whose leaves are numbers or arrays.

    Returns:
      A pytree of jnp.arrays having the same structure as the elements of
      ``list_of_trees``. An empty input is returned unchanged.

    Example:
      >>> out = to_tree_arrays([
      ...     (1, {"a": jnp.array([1, 2])}),
      ...     (2, {"a": jnp.array([3, 4])}),
      ... ])
      >>> out[0].tolist()
      [1, 2]
      >>> out[1]["a"].tolist()
      [[1, 2], [3, 4]]
    """
    if not list_of_trees:
        # Nothing to stack and no structure to build a result from.
        return list_of_trees

    # Map over all trees at once: each call receives the matching leaf from
    # every tree and stacks them. jax.tree_util is used because the top-level
    # aliases (jax.tree_multimap / jax.tree_transpose / jax.tree_structure)
    # have been removed from recent JAX releases.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves), *list_of_trees)
Esempio n. 2
0
 def test_tree_transpose(self):
   """tree_transpose turns a FlatMap of 2-lists into a 2-list of FlatMaps."""
   # jax.tree_util is used because the top-level jax.tree_structure /
   # jax.tree_transpose aliases have been removed from recent JAX releases.
   outerdef = jax.tree_util.tree_structure(FlatMap({"a": 1, "b": 2}))
   innerdef = jax.tree_util.tree_structure([1, 2])
   # The input below is outerdef (FlatMap) composed with innerdef (2-list).
   self.assertEqual(
       [FlatMap({"a": 3, "b": 5}), FlatMap({"a": 4, "b": 6})],
       jax.tree_util.tree_transpose(
           outerdef, innerdef, FlatMap({"a": [3, 4], "b": [5, 6]})))
Esempio n. 3
0
def point_project_tree(tree_point_cloud, ts, n_basis, basis):
    """Project a sequence of pytrees leaf-wise and split the result per basis.

    ``point_project_array`` (defined elsewhere) is called once per leaf
    position with the tuple of that leaf taken from every tree in
    ``tree_point_cloud``. Its result must yield exactly ``n_basis`` items
    for the transpose below to succeed.

    Args:
      tree_point_cloud: Non-empty sequence of pytrees sharing one structure.
      ts: Forwarded unchanged to ``point_project_array``.
      n_basis: Number of items each projection yields; also forwarded.
      basis: Forwarded unchanged to ``point_project_array``.

    Returns:
      A list of ``n_basis`` pytrees, each with the structure of
      ``tree_point_cloud[0]``.
    """
    def point_project_list(*args):
        # Materialize as a list so the result is an n_basis-long list node.
        return list(point_project_array(args, ts, n_basis, basis))

    # jax.tree_util replaces the removed jax.tree_multimap / jax.tree_structure
    # / jax.tree_transpose top-level aliases.
    out = jax.tree_util.tree_map(point_project_list, *tree_point_cloud)
    original_struct = jax.tree_util.tree_structure(tree_point_cloud[0])
    mapped_struct = jax.tree_util.tree_structure(list(range(n_basis)))
    # `out` is original_struct with an n_basis-long list at every leaf;
    # transpose it into a list of n_basis trees of original_struct.
    return jax.tree_util.tree_transpose(original_struct, mapped_struct, out)
Esempio n. 4
0
def function_project_tree(source_params, source_basis, target_basis, n_basis):
    """Project a sequence of parameter pytrees leaf-wise between two bases.

    ``function_project_array`` (defined elsewhere) is called once per leaf
    position with the tuple of that leaf taken from every tree in
    ``source_params``. Its result must yield exactly ``n_basis`` items for
    the transpose below to succeed.

    Args:
      source_params: Non-empty sequence of pytrees sharing one structure.
      source_basis: Forwarded unchanged to ``function_project_array``.
      target_basis: Forwarded unchanged to ``function_project_array``.
      n_basis: Number of items each projection yields; also forwarded.

    Returns:
      A list of ``n_basis`` pytrees, each with the structure of
      ``source_params[0]``.
    """
    def function_project_list(*args):
        # Materialize as a list so the result is an n_basis-long list node.
        return list(
            function_project_array(args, source_basis, target_basis, n_basis))

    # jax.tree_util replaces the removed jax.tree_multimap / jax.tree_structure
    # / jax.tree_transpose top-level aliases.
    out = jax.tree_util.tree_map(function_project_list, *source_params)
    original_struct = jax.tree_util.tree_structure(source_params[0])
    mapped_struct = jax.tree_util.tree_structure(list(range(n_basis)))
    # `out` is original_struct with an n_basis-long list at every leaf;
    # transpose it into a list of n_basis trees of original_struct.
    return jax.tree_util.tree_transpose(original_struct, mapped_struct, out)
Esempio n. 5
0
def tree_vmap(f, lst):
    """Apply ``f`` to every pytree in ``lst`` with a single ``jax.vmap`` call.

    The trees in ``lst`` are stacked leaf-wise along a new leading axis.
    ``f`` is vmapped over that axis, and the batched output is split back
    into one result per input tree.

    Args:
      f: Function taking one pytree shaped like an element of ``lst`` and
        returning a pytree of arrays.
      lst: Sequence of pytrees that all share one structure.

    Returns:
      A list with ``len(lst)`` entries, where entry ``i`` is the output of
      ``f`` for ``lst[i]``. An empty ``lst`` yields ``[]``.
    """
    if not lst:
        # Nothing to batch; vmap cannot run over an empty stack.
        return []

    # Stack the i-th leaf of every tree together. The variadic form is
    # required: mapping over `lst` alone visits each leaf in isolation and
    # never stacks across the list.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *lst)
    out_stacked = jax.vmap(f)(stacked)

    # Replace each batched output leaf with a Python list of len(lst) slices.
    # The result is out_stacked's structure with a list at every leaf, so
    # out_stacked supplies the *outer* treedef and the list the *inner* one.
    # The list uses 0 placeholders because None is an empty pytree node with
    # no leaves, which would make the leaf counts mismatch.
    outer_treedef = jax.tree_util.tree_structure(out_stacked)
    inner_treedef = jax.tree_util.tree_structure([0] * len(lst))
    out_unstacked_transposed = jax.tree_util.tree_map(list, out_stacked)
    return jax.tree_util.tree_transpose(
        outer_treedef, inner_treedef, out_unstacked_transposed)
def write_gradient_histogram(writer, step, *, grads=None, updates=None):
    """Log computed gradients and/or updates histograms.

    Args:
      writer: Object exposing ``write_histograms(step, histograms)``.
      step: Step value passed through to ``writer.write_histograms``.
      grads: Optional pytree of gradients, logged under the "grad" key.
      updates: Optional pytree of transformed updates, logged under the
        "update" key. When both are given they must share one structure.

    Returns:
      None. Nothing is written when both ``grads`` and ``updates`` are None.
    """
    histograms = {"grad": grads, "update": updates}
    histograms = {k: v for k, v in histograms.items() if v is not None}
    if not histograms:
        return

    # Transpose a histograms dict from
    # {"grad": {"param1": Tensor}, "update": {"param1": Tensor}} to
    # {"param1": {"grad": Tensor, "update": Tensor}} such that the gradient and
    # the transformed updates appear next to each other in tensorboard.
    # jax.tree_util replaces the removed jax.tree_transpose /
    # jax.tree_structure top-level aliases.
    histograms = jax.tree_util.tree_transpose(
        jax.tree_util.tree_structure({k: 0 for k in histograms}),
        jax.tree_util.tree_structure(next(iter(histograms.values()))),
        histograms)

    # Presumably yields "param1.grad"-style keys; see utils.flatten_dict.
    histograms = utils.flatten_dict(histograms, sep=".")
    writer.write_histograms(step, histograms)