コード例 #1
0
def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
  """Flattens ``tree`` into its leaves plus a description of its structure.

  Args:
    tree: the pytree to flatten.
    is_leaf: optional predicate consulted at every step of the flattening.
      Returning True stops the traversal there, so the whole subtree is
      treated as a single leaf; returning False lets flattening descend into
      the current object.

  Returns:
    A pair ``(leaves, treedef)``: ``leaves`` is a list of the leaf values and
    ``treedef`` represents the structure of the flattened tree.
  """
  leaves_and_treedef = pytree.flatten(tree, is_leaf)
  return leaves_and_treedef
コード例 #2
0
def tree_structure(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
  """Returns the treedef describing the structure of ``tree``.

  ``is_leaf`` is forwarded unchanged to ``pytree.flatten``; the leaf values
  that flattening produces are discarded.
  """
  flattened = pytree.flatten(tree, is_leaf)
  # Element 1 of the flatten result is the treedef (element 0 is the leaves).
  return flattened[1]
コード例 #3
0
def tree_leaves(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
  """Returns the leaf values of ``tree``.

  ``is_leaf`` is forwarded unchanged to ``pytree.flatten``; the treedef that
  flattening produces is discarded.
  """
  flattened = pytree.flatten(tree, is_leaf)
  # Element 0 of the flatten result is the leaves (element 1 is the treedef).
  return flattened[0]
コード例 #4
0
def _process_pytree(process_node, tree):
  """Flattens ``tree`` and runs its leaves back through ``treedef.walk``.

  ``process_node`` is passed as the first callback to ``treedef.walk`` and
  ``None`` as the second. NOTE(review): the exact roles of those two callback
  slots are defined by the pytree extension; confirm there.

  Returns:
    A pair of the value produced by ``treedef.walk`` and the treedef obtained
    from flattening ``tree``.
  """
  flat_leaves, structure = pytree.flatten(tree)
  walked = structure.walk(process_node, None, flat_leaves)
  return walked, structure
コード例 #5
0
ファイル: tree_util.py プロジェクト: matthewfeickert/jax
def tree_structure(tree, is_leaf=None):
  """Gets the treedef for a pytree.

  Accepts the same optional ``is_leaf`` argument as the other
  ``tree_structure`` variant, so both definitions share one call signature.

  Args:
    tree: the pytree whose structure is wanted.
    is_leaf: optional predicate forwarded to ``pytree.flatten``. When it is
      omitted (``None``), ``pytree.flatten`` is called with ``tree`` alone,
      exactly as before, so existing callers see no change.

  Returns:
    The treedef, i.e. element 1 of the ``pytree.flatten`` result.
  """
  if is_leaf is None:
    # Preserve the original one-argument call. NOTE(review): this also keeps
    # working if this build's pytree.flatten does not accept is_leaf; confirm.
    flattened = pytree.flatten(tree)
  else:
    flattened = pytree.flatten(tree, is_leaf)
  return flattened[1]
コード例 #6
0
ファイル: tree_util.py プロジェクト: matthewfeickert/jax
def tree_leaves(tree, is_leaf=None):
  """Gets the leaves of a pytree.

  Accepts the same optional ``is_leaf`` argument as the other ``tree_leaves``
  variant, so both definitions share one call signature.

  Args:
    tree: the pytree whose leaves are wanted.
    is_leaf: optional predicate forwarded to ``pytree.flatten``. When it is
      omitted (``None``), ``pytree.flatten`` is called with ``tree`` alone,
      exactly as before, so existing callers see no change.

  Returns:
    The leaves, i.e. element 0 of the ``pytree.flatten`` result.
  """
  if is_leaf is None:
    # Preserve the original one-argument call. NOTE(review): this also keeps
    # working if this build's pytree.flatten does not accept is_leaf; confirm.
    flattened = pytree.flatten(tree)
  else:
    flattened = pytree.flatten(tree, is_leaf)
  return flattened[0]