Beispiel #1
0
    def derive_metrics(self, global_state):
        """Collects the derived metrics of every subquery into one flat dict.

        `self._queries` is traversed together with `global_state` (mapping only
        up to the structure of `self._queries`). For each subquery, its own
        `derive_metrics(subquery_global_state)` is called, and every metric it
        returns is stored under a key formed by joining the subquery's path
        elements and the metric name with '/'.

        NOTE(review): presumably `global_state` mirrors the structure of
        `self._queries` (one state per subquery) — confirm against callers.

        Args:
          global_state: Structure of per-subquery states, matched against
            `self._queries`.

        Returns:
          A `collections.OrderedDict` mapping '/'-joined path strings to metrics,
          in traversal order. A later duplicate key overwrites the earlier value
          but keeps the earlier insertion position.
        """
        flat_metrics = collections.OrderedDict()

        def _collect(path, query, query_state):
            # Compute all of this subquery's metrics first, then record each one
            # under its fully qualified '/'-separated name.
            derived = query.derive_metrics(query_state)
            for metric_name, metric_value in derived.items():
                key_parts = path + (metric_name, )
                flat_metrics['/'.join(str(part) for part in key_parts)] = (
                    metric_value)

        # `_collect` is called only for its side effect on `flat_metrics`; the
        # structure returned by the traversal is discarded.
        tree.map_structure_with_path_up_to(
            self._queries, _collect, self._queries, global_state)

        return flat_metrics
Beispiel #2
0
def map_structure_with_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
  `structure[i]` and `path` is the common path to x[i] in the structures,
  rendered as a string by joining the path elements with '/' (e.g. the tuple
  path `('a', 0)` becomes `'a/0'`). All structures in `structure` must have the
  same arity, and the return value will contain the results with the same
  structure layout.

  Args:
    func: A callable with the signature func(path, *values, **kwargs) that is
      evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process.
    **kwargs: Optional kwargs forwarded to `map_structure_with_path_up_to` and
      from there to func. Special kwarg `check_types` is not passed to func,
      but instead determines whether the types of iterables within the
      structures have to be the same (e.g. `map_structure(func, [1], (1,))`
      raises a `TypeError` exception). To allow this, set this argument to
      `False`.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  if not structure:
    # Checked up front: otherwise `structure[0]` below raises an IndexError
    # instead of the documented ValueError.
    raise ValueError('Cannot map over no sequences')

  def wrapper_func(tuple_path, *inputs, **func_kwargs):
    # Convert the tuple path supplied by the underlying mapper into the
    # '/'-separated string path that `func` expects.
    string_path = '/'.join(str(s) for s in tuple_path)
    return func(string_path, *inputs, **func_kwargs)

  # The first structure doubles as the shallow structure, so the mapping
  # descends all the way to its leaves.
  return map_structure_with_path_up_to(structure[0], wrapper_func,
                                       *structure, **kwargs)
Beispiel #3
0
def map_structure_with_tuple_paths_up_to(shallow_structure, func, *structures,
                                         **kwargs):
  """Validates inputs, then delegates to `dm_tree.map_structure_with_path_up_to`.

  Each input tree is first checked against `shallow_structure` with
  `assert_shallow_structure`, and then the mapping is done by dm_tree.

  Args:
    shallow_structure: Structure that bounds how deep the mapping descends.
    func: Callable applied with the tuple path and the matching values.
    *structures: One or more input trees to map over.
    **kwargs: Forwarded to dm_tree unchanged. `check_types` (default True) is
      also used for the shallow-structure validation.

  Returns:
    Whatever `dm_tree.map_structure_with_path_up_to` returns.

  Raises:
    ValueError: If no input structures are given.
  """
  if not structures:
    raise ValueError('Cannot map over no sequences')

  # `check_types` is only read here, not removed, so dm_tree still receives it.
  strict_types = kwargs.get('check_types', True)
#   kwargs.pop('check_types', None)  # DisableOnExport

  for candidate_tree in structures:
    assert_shallow_structure(shallow_structure, candidate_tree,
                             check_types=strict_types)

  return dm_tree.map_structure_with_path_up_to(
      shallow_structure, func, *structures, **kwargs)
Beispiel #4
0
def map_structure_with_tuple_paths_up_to(shallow_structure, func, *structures,
                                         expand_composites=False, **kwargs):
  """Validates inputs, then delegates to `dm_tree.map_structure_with_path_up_to`.

  Each input tree is checked against `shallow_structure` with
  `assert_shallow_structure` before dm_tree does the mapping.

  Args:
    shallow_structure: Structure that bounds how deep the mapping descends.
    func: Callable applied with the tuple path and the matching values.
    *structures: One or more input trees to map over.
    expand_composites: Must be False; True is rejected in this backend.
    **kwargs: Forwarded to dm_tree unchanged. `check_types` (default True) is
      also used for the shallow-structure validation.

  Returns:
    Whatever `dm_tree.map_structure_with_path_up_to` returns.

  Raises:
    ValueError: If no input structures are given.
    NotImplementedError: If `expand_composites` is True.
  """
  if not structures:
    raise ValueError('Cannot map over no sequences')

  # `check_types` is only read here, not removed, so dm_tree still receives it.
  strict_types = kwargs.get('check_types', True)
#   kwargs.pop('check_types', None)  # DisableOnExport

  if expand_composites:
    raise NotImplementedError(
        '`expand_composites=True` is not supported in JAX.')

  for candidate_tree in structures:
    assert_shallow_structure(shallow_structure, candidate_tree,
                             check_types=strict_types)

  return dm_tree.map_structure_with_path_up_to(
      shallow_structure, func, *structures, **kwargs)
Beispiel #5
0
def map_structure_with_tuple_paths_up_to(func, *structures, **kwargs):
  """Alias that forwards every argument unchanged to `map_structure_with_path_up_to`.

  NOTE(review): despite its name, the first positional argument `func` is passed
  through as the FIRST argument of `map_structure_with_path_up_to`. Elsewhere
  that function is called with the shallow structure first and the mapping
  function second, so callers presumably pass `(shallow_structure, fn, *trees)`
  here — confirm against callers.
  """
  forwarded_args = (func,) + structures
  return map_structure_with_path_up_to(*forwarded_args, **kwargs)