def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                          stop_at_ts=(), control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the backward graph walk
        starts. If a list of tensors is given instead, the seed_ops are set to
        be the generators of those tensors. A single op/tensor is also
        accepted.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within the
        whole graph. An empty iterable restricts the search to nothing (the
        result is then empty).
      stop_at_ts: an iterable of tensors at which the graph walk stops: an
        input tensor listed here is not followed back to its producer.
      control_inputs: if True, control inputs will be used while moving
        backward.
    Returns:
      A Python list of all the tf.Operation behind seed_ops, without
      duplicates: the (restricted) seed ops first, followed by each successive
      wave of newly discovered predecessor ops.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    # Materialize so that sets and generators (valid "iterables") can be
    # emptiness-checked and have their first element inspected.
    seed_ops = list(seed_ops)
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    # Test against None (not truthiness) so an empty within_ops consistently
    # restricts both the seeds and the walk, as documented.
    if within_ops is not None:
        within_ops = frozenset(
            util.make_list_of_op(within_ops, allow_graph=False))
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # O(1) membership companion to `result` (which preserves order).
    visited = set(result)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        # Every member of new_wave was unvisited when added and `visited` is
        # not modified inside the loop above, so a plain extend cannot
        # introduce duplicates.
        result.extend(new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
def _build_dependency_dict(self):
    """Build a dictionary of dependencies among nodes.

    Starting from ``self._seed_ops``, walks forward breadth-first through the
    consumers of each visited op's output tensors. For every visited op it
    records the set of ops it depends on: its control inputs plus the ops
    generating its input tensors. That set is intersected with the ops
    returned by ``ge.get_walks_intersection_ops(seed_ops, grad_ops)``
    (presumably the ops lying on paths from the seeds to the grad ops —
    confirm against the graph_editor helper).

    Returns:
      dict mapping each visited op to its (filtered) set of dependency ops.
    """
    import collections

    # FIFO of ops waiting to be processed. A plain deque replaces the
    # thread-synchronized Queue.Queue, since this traversal is purely local.
    pending = collections.deque()
    # Mirrors the contents of `pending` for O(1) membership checks (scanning
    # the queue itself is O(n) per check).
    pending_set = set()
    closed_set = set()
    dep_dict = {}
    for op in self._seed_ops:
        if op not in pending_set:
            pending.append(op)
            pending_set.add(op)
    reachable_ops = set(
        ge.get_walks_intersection_ops(list(self._seed_ops),
                                      list(self._grad_ops)))

    # traversal in the fw phase
    while pending:
        src_op = pending.popleft()
        pending_set.discard(src_op)
        # Dependencies: control inputs plus producers of the data inputs,
        # restricted to the seed/grad walk intersection.
        dep_ops = set(src_op.control_inputs)
        for t in src_op.inputs:
            dep_ops |= set(util.get_generating_ops(t))
        dep_ops &= reachable_ops
        dep_dict[src_op] = dep_ops

        next_ops = set()
        for t in src_op.outputs:
            next_ops |= set(util.get_consuming_ops(t))
        for op in next_ops:
            if op in closed_set:
                continue
            if op not in pending_set:
                pending.append(op)
                pending_set.add(op)
        closed_set.add(src_op)
    return dep_dict