Example #1
0
def _yield_sorted_items(iterable):
    """Yields the top-level (key, value) pairs of `iterable` deterministically.

    The meaning of each key depends on the container type:

    * Sequences: the int index of the element.
    * Mappings: the dictionary key.
    * Other structures (e.g. namedtuples): the attribute name.

    Order follows `dm_tree`'s flattening order. For sequences this is
    ascending index, and for mappings it is sorted key order.
    NOTE(review): the previous docstring claimed sorted order "in all cases".
    For namedtuples, dm_tree presumably uses field order instead. Confirm.

    Args:
      iterable: a nested structure, i.e. something `is_nested` accepts.

    Yields:
      `(key, value)` tuples, one per direct child of `iterable`.

    Raises:
      ValueError: if `iterable` is not nested. Because this is a generator,
        the error surfaces on the first `next()`, not at call time.
    """
    if not is_nested(iterable):
        raise ValueError(f'{iterable} is not an iterable')

    def _stop_below_root(node):
        # For dm_tree.traverse, returning None means "descend into this node".
        # Any other value replaces the node's sub-structure. Only the root is
        # descended into, so every direct child collapses to a `False` leaf.
        # The result is a one-level-deep copy of the structure's shape.
        if node is iterable:
            return None
        return False

    one_level_shape = dm_tree.traverse(_stop_below_root, iterable)

    # Flattening up to the one-level shape yields one entry per direct child.
    # path[0] is that child's key within `iterable`.
    flat = dm_tree.flatten_with_path_up_to(one_level_shape, iterable)
    for path, value in flat:
        yield path[0], value
Example #2
0
def flatten_with_tuple_paths_up_to(shallow_structure,
                                   input_structure,
                                   check_types=True,
                                   expand_composites=False):
  """Flattens `input_structure` up to `shallow_structure`, keeping paths.

  Delegates to `dm_tree.flatten_with_path_up_to`.

  Args:
    shallow_structure: structure whose shape bounds how far
      `input_structure` is flattened.
    input_structure: the structure to flatten.
    check_types: forwarded unchanged to `dm_tree.flatten_with_path_up_to`.
    expand_composites: must be falsy. Composite expansion is not supported
      in JAX.

  Returns:
    The result of `dm_tree.flatten_with_path_up_to`, i.e. (path, value)
    pairs.

  Raises:
    NotImplementedError: if `expand_composites` is truthy.
  """
  if not expand_composites:
    return dm_tree.flatten_with_path_up_to(
        shallow_structure, input_structure, check_types)
  raise NotImplementedError(
      '`expand_composites=True` is not supported in JAX.')
Example #3
0
def flatten_with_tuple_paths_up_to(shallow_structure,
                                   input_structure,
                                   check_types=True):
  """Flattens `input_structure` up to `shallow_structure`, keeping paths.

  A thin alias that forwards every argument, in order, to
  `flatten_with_path_up_to`. NOTE(review): the "tuple_paths" name presumably
  means the underlying function already produces tuple paths. Confirm.

  Args:
    shallow_structure: structure whose shape bounds how far
      `input_structure` is flattened.
    input_structure: the structure to flatten.
    check_types: forwarded unchanged to `flatten_with_path_up_to`.

  Returns:
    Whatever `flatten_with_path_up_to` returns.
  """
  flat_with_paths = flatten_with_path_up_to(
      shallow_structure, input_structure, check_types)
  return flat_with_paths
Example #4
0
 def get_paths_and_values(shallow_tree, input_tree):
     """Flattens `input_tree` up to `shallow_tree` into parallel lists.

     Returns:
       A `(paths, values)` tuple of two lists. `paths[i]` is the path of
       `values[i]`, in the order `tree.flatten_with_path_up_to` yields them.
     """
     paths, values = [], []
     for path, value in tree.flatten_with_path_up_to(shallow_tree,
                                                     input_tree):
         paths.append(path)
         values.append(value)
     return paths, values