Ejemplo n.º 1
0
def _flatten_tree(tree, leaves=None):
    """Collect every leaf of a nested structure into one flat list.

    A dict contributes its values (obtained through `iteritems`), any other
    object that `util.is_iterable` accepts contributes its elements, and
    everything else is treated as a leaf. Nesting is followed recursively.

    Args:
        tree: a leaf, or a dict / iterable whose children may themselves be
            leaves, dicts or iterables.
        leaves: optional list that receives the leaves in place; a fresh list
            is created when it is None.

    Returns:
        The list holding the leaves (the `leaves` argument itself when one
        was supplied).
    """
    collected = [] if leaves is None else leaves

    if isinstance(tree, dict):
        # Only the values take part in the flattening; keys are ignored.
        children = (value for _, value in iteritems(tree))
    elif util.is_iterable(tree):
        children = tree
    else:
        # Not a container: the object itself is a leaf.
        collected.append(tree)
        return collected

    for child in children:
        _flatten_tree(child, collected)
    return collected
Ejemplo n.º 2
0
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    Starting from `seed_ops`, the walk repeatedly follows `op.inputs` (and
    `op.control_inputs` when requested) towards the producers, one wave at a
    time, until no new op is discovered.

    Args:
        seed_ops: an iterable of operations from which the backward graph
            walk starts. If a list of tensors is given instead, the seed_ops
            are set to be the generators of those tensors. A single
            non-iterable object is treated as a one-element list.
        inclusive: if True the given seed_ops are also part of the result.
        within_ops: an iterable of `gde.Node` within which the search is
            restricted. If `within_ops` is `None`, the search is performed
            within the whole graph.
        within_ops_fn: if provided, a function on ops that should return True
            iff the op is within the graph traversal. This can be used along
            within_ops, in which case an op is within if it is also in
            within_ops.
        stop_at_ts: an iterable of tensors at which the graph walk stops.
        control_inputs: if True, control inputs will be used while moving
            backward.

    Returns:
        A list (not a set) of the `gde.Node` reached by the walk. An empty
        list is returned when `seed_ops` is empty.

    Raises:
        TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
            of `gde.Node`.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    elif not isinstance(seed_ops, (list, tuple)):
        # Sets, generators and other non-indexable iterables would fail on
        # the `seed_ops[0]` probe below; materialize them first.
        seed_ops = list(seed_ops)
    if not seed_ops:
        return []
    # The type of the first element decides whether the seeds are tensors
    # (walk starts at their generating ops) or ops.
    if isinstance(seed_ops[0], Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    # NOTE(review): a falsy `within_ops` (e.g. an empty list) skips this
    # normalization, so the seeds are not filtered, yet `is_within` still
    # tests membership against the empty container. Confirm this is intended.
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(operator):
        return (within_ops is None
                or operator in within_ops) and (within_ops_fn is None
                                                or within_ops_fn(operator))

    result = list(seed_ops)
    # `visited` holds the same ops as `result` but gives constant-time
    # membership tests; scanning the growing list for every input made the
    # walk quadratic in the number of ops.
    # Assumes util.concatenate_unique appends to `result` every element of
    # `new_wave` that is not already there - TODO confirm against util.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    # Do not walk past a tensor the caller asked to stop at.
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result