def _flatten_tree(tree, leaves=None):
    """Collect every leaf of a (possibly nested) tree into a flat list.

    A dict contributes its values (its keys are ignored). Any other object for
    which ``util.is_iterable`` returns True is descended into element by
    element. Everything else is a leaf and is appended as-is.

    NOTE(review): whether strings are leaves depends on ``util.is_iterable``.
    If it reports strings as iterable, a non-empty string recurses without
    bound (each character is itself an iterable string) — confirm.

    Args:
      tree: the root object: a dict, an iterable, or a single leaf.
      leaves: optional list that leaves are appended to in place. A fresh list
        is created when this is None.

    Returns:
      The ``leaves`` list (the very object passed in, if any), holding the
      leaves in depth-first, left-to-right order.
    """
    out = [] if leaves is None else leaves
    if isinstance(tree, dict):
        # Only the values are part of the tree structure.
        children = (value for _, value in iteritems(tree))
    elif util.is_iterable(tree):
        children = tree
    else:
        out.append(tree)
        return out
    for child in children:
        _flatten_tree(child, out)
    return out
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    Starting from ``seed_ops``, repeatedly follow each op's input tensors back
    to the ops that produce them (and optionally its control inputs), until no
    new op is reached.

    Args:
      seed_ops: an iterable of operations from which the backward graph walk
        starts. If a list of tensors is given instead, the seed_ops are set to
        be the generators of those tensors. A single op/tensor is also
        accepted.
      inclusive: if True the given seed_ops are also part of the resulting
        set. If False, every seed is removed from the result, including seeds
        that are reachable backward from another seed.
      within_ops: an iterable of `gde.Node` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph. An empty iterable restricts the search to nothing, so
        the result is empty. Seeds outside `within_ops` are dropped.
      within_ops_fn: if provided, a function on ops that should return True
        iff the op is within the graph traversal. This can be used along
        within_ops, in which case an op is within if it is also in
        within_ops. It is applied to ops reached during the walk, not to the
        seeds themselves.
      stop_at_ts: an iterable of tensors at which the graph walk stops: the
        walk never crosses these tensors (their producing op may still be
        reached through some other tensor).
      control_inputs: if True, control inputs will be used while moving
        backward.

    Returns:
      A list of all the `gde.Node` behind `seed_ops`, without duplicates.
      Ops are appended wave by wave (seeds first). The order within a wave is
      not specified.

    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `gde.Node`.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    elif isinstance(seed_ops, (set, frozenset)) or iter(seed_ops) is seed_ops:
        # Sets and one-shot iterators (e.g. generators) are not indexable, and
        # a generator is always truthy. Materialize them so the emptiness
        # check and the `seed_ops[0]` type probe below work.
        seed_ops = list(seed_ops)
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    # `is not None` (not truthiness) so that an empty `within_ops` also
    # filters the seeds. Otherwise seeds outside the (empty) restriction would
    # leak into the result while every other op is rejected by is_within.
    if within_ops is not None:
        within_ops = frozenset(
            util.make_list_of_op(within_ops, allow_graph=False))
        seed_ops &= within_ops

    def is_within(operator):
        return ((within_ops is None or operator in within_ops) and
                (within_ops_fn is None or within_ops_fn(operator)))

    result = list(seed_ops)
    # Set mirror of `result` for O(1) "already visited" checks. Testing
    # membership against the list would make the walk quadratic.
    visited = set(result)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        # Every op in new_wave was checked against `visited` above, so none
        # of them is in `result` yet. A plain extend keeps `result` unique.
        result.extend(new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result