def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          stop_at_ts=(),
                          control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    The walk proceeds wave by wave: starting from the seeds, each wave is made
    of the not-yet-visited producers of the inputs (and, optionally, the
    control inputs) of the ops in the previous wave.

    Args:
      seed_ops: an iterable of operations from which the backward graph walk
        starts. If tensors are given instead (decided by looking at the first
        element), the seed_ops are set to be the generators of those tensors.
        A single non-iterable object is treated as a one-element list.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph. When a non-empty `within_ops` is given, seeds which
        are not part of it are dropped.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_inputs: if True, control inputs will be used while moving
        backward.

    Returns:
      A Python list (without duplicates) of all the `tf.Operation` behind
      `seed_ops`: the seeds first (unless `inclusive` is False), followed by
      the ops discovered by each successive wave. An empty list is returned
      if `seed_ops` is empty.

    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    else:
        # Materialize so that sets, generators and other non-subscriptable
        # iterables can be emptiness-tested and indexed below.
        seed_ops = list(seed_ops)
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # Mirror of `result` used only for O(1) membership tests; `result` itself
    # stays a list so that the discovery order is preserved.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         control_outputs=True):
    """Collect every op reachable by walking the graph forward from the seeds.

    Starting from the seeds, the walk repeatedly follows the consumers of each
    op's output tensors (and, when available, its control outputs) until no
    new op is found.

    Args:
      seed_ops: an iterable of operations from which the forward graph walk
        starts. If tensors are given instead (decided by looking at the first
        element), the seeds are the consumers of those tensors. A single
        non-iterable object is treated as a one-element list.
      inclusive: if True the seeds are kept in the returned set, otherwise
        they are removed from it.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph. When a non-empty `within_ops` is given, seeds which
        are not part of it are dropped.
      control_outputs: an object convertible to a control output dictionary
        (it is passed to `util.convert_to_control_outputs`). If the resulting
        dictionary is not `None`, it is used while walking the graph forward.

    Returns:
      A Python set of all the `tf.Operation` ahead of `seed_ops`. An empty set
      is returned if `seed_ops` is empty.

    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return set()

    # The type of the first element decides how the seeds are interpreted.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        tensors = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = get_consuming_ops(tensors)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    control_outputs = util.convert_to_control_outputs(seed_ops,
                                                      control_outputs)
    seed_ops = frozenset(seed_ops)
    if within_ops:
        within_ops = frozenset(
            util.make_list_of_op(within_ops, allow_graph=False))
        seed_ops &= within_ops

    def _allowed(candidate):
        # No restriction at all when within_ops was left to None.
        if within_ops is None:
            return True
        return candidate in within_ops

    visited = set(seed_ops)
    frontier = set(seed_ops)
    while frontier:
        discovered = set()
        for current in frontier:
            # Data-flow successors first, then control-flow successors.
            successors = [consumer
                          for tensor in current.outputs
                          for consumer in tensor.consumers()]
            if control_outputs is not None and current in control_outputs:
                successors.extend(control_outputs[current])
            discovered.update(
                nxt for nxt in successors
                if nxt not in visited and _allowed(nxt))
        visited |= discovered
        frontier = discovered

    if inclusive:
        return visited
    return visited - seed_ops
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None):
    """Do a forward graph walk and return all the visited ops.

    The walk proceeds wave by wave: starting from the seeds, each wave is made
    of the not-yet-visited consumers of the outputs (and, optionally, the
    control outputs) of the ops in the previous wave.

    Args:
      seed_ops: an iterable of operations from which the forward graph walk
        starts. If tensors are given instead (decided by looking at the first
        element), the seed_ops are set to be the consumers of those tensors.
        A single non-iterable object is treated as a one-element list.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph. When a non-empty `within_ops` is given, seeds which
        are not part of it are dropped.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_outputs: a `util.ControlOutputs` instance or None. The value is
        first passed through `check_cios`; if the result is not `None`, it
        will be used while walking the graph forward.

    Returns:
      A Python list (without duplicates) of all the `tf.Operation` ahead of
      `seed_ops`: the seeds first (unless `inclusive` is False), followed by
      the ops discovered by each successive wave. An empty list is returned
      if `seed_ops` is empty.

    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    _, control_outputs = check_cios(False, control_outputs)
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    else:
        # Materialize so that sets, generators and other non-subscriptable
        # iterables can be emptiness-tested and indexed below.
        seed_ops = list(seed_ops)
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    seed_ops = frozenset(seed_ops)
    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # Mirror of `result` used only for O(1) membership tests; `result` itself
    # stays a list so that the discovery order is preserved.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.outputs:
                if new_t in stop_at_ts:
                    continue
                for new_op in new_t.consumers():
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
            if control_outputs is not None:
                for new_op in control_outputs.get(op):
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result