def _check_ts_compatibility(ts0, ts1):
  """Check that two tensor lists agree pairwise on dtype and shape.

  Both arguments are normalised to lists of `tf.Tensor` and then compared
  position by position: the i-th tensor of `ts0` against the i-th of `ts1`.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
  Raises:
    ValueError: if the two lists have different lengths, or if any pair of
      tensors at the same index has incompatible dtypes or incompatible
      shapes.
  """
  first = _util.make_list_of_t(ts0)
  second = _util.make_list_of_t(ts1)
  n_first, n_second = len(first), len(second)
  if n_first != n_second:
    raise ValueError("ts0 and ts1 have different sizes: {} != {}".format(
        n_first, n_second))
  for lhs, rhs in zip(first, second):
    # For each pair the dtype is validated before the shape.
    lhs_dtype = lhs.dtype
    rhs_dtype = rhs.dtype
    if not lhs_dtype.is_compatible_with(rhs_dtype):
      raise ValueError("Dtypes {} and {} are not compatible.".format(
          lhs_dtype, rhs_dtype))
    lhs_shape = lhs.get_shape()
    rhs_shape = rhs.get_shape()
    if not lhs_shape.is_compatible_with(rhs_shape):
      raise ValueError("Shapes {} and {} are not compatible.".format(
          lhs_shape, rhs_shape))
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          stop_at_ts=(),
                          control_inputs=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors. Whether the seeds are tensors is
      decided from the type of the first element. Any iterable (list, tuple,
      set, generator, ...) is accepted, as is a single op or tensor.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within the
      whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
  Returns:
    A Python list of all the tf.Operation behind seed_ops, without duplicates,
    in the order they were reached (seeds first, then wave by wave). An empty
    list is returned when seed_ops is empty.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize the iterable: sets and generators cannot be indexed with
  # seed_ops[0], and a generator is always truthy.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_generating_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  seed_ops = frozenset(util.make_list_of_op(seed_ops))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  # `result` keeps discovery order; `visited` mirrors it for O(1) membership
  # tests (testing membership on the list made the walk quadratic).
  result = list(seed_ops)
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.inputs:
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in new_wave is unvisited by construction, so extending keeps
    # `result` duplicate-free.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
def _check_ts_compatibility(ts0, ts1):
  """Verify that two tensor lists are pairwise dtype- and shape-compatible.

  Each argument is converted to a list of `tf.Tensor`; tensors sharing an
  index are then compared, dtype first and shape second.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
  Raises:
    ValueError: if the lists differ in length, or if a pair of tensors at the
      same index has incompatible dtypes or incompatible shapes.
  """
  left_ts = _util.make_list_of_t(ts0)
  right_ts = _util.make_list_of_t(ts1)
  if len(left_ts) != len(right_ts):
    raise ValueError("ts0 and ts1 have different sizes: {} != {}".format(
        len(left_ts), len(right_ts)))
  for left, right in zip(left_ts, right_ts):
    dt_left, dt_right = left.dtype, right.dtype
    if not dt_left.is_compatible_with(dt_right):
      raise ValueError("Dtypes {} and {} are not compatible.".format(
          dt_left, dt_right))
    sh_left, sh_right = left.get_shape(), right.get_shape()
    if not sh_left.is_compatible_with(sh_right):
      raise ValueError("Shapes {} and {} are not compatible.".format(
          sh_left, sh_right))
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          stop_at_ts=(),
                          control_inputs=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors. Whether the seeds are tensors is
      decided from the type of the first element. Any iterable (list, tuple,
      set, generator, ...) is accepted, as is a single op or tensor.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within the
      whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
  Returns:
    A Python list of all the tf.Operation behind seed_ops, without duplicates,
    in the order they were reached (seeds first, then wave by wave). An empty
    list is returned when seed_ops is empty.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize the iterable: sets and generators cannot be indexed with
  # seed_ops[0], and a generator is always truthy.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_generating_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  seed_ops = frozenset(util.make_list_of_op(seed_ops))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  # `result` keeps discovery order; `visited` mirrors it for O(1) membership
  # tests (testing membership on the list made the walk quadratic).
  result = list(seed_ops)
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.inputs:
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in new_wave is unvisited by construction, so extending keeps
    # `result` duplicate-free.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
def __init__(self, inside_ops=(), passthrough_ts=()):
  """Create a subgraph containing the given ops and the "passthrough" tensors.

  Args:
    inside_ops: an object convertible to a list of tf.Operation. This list
      defines all the operations in the subgraph.
    passthrough_ts: an object convertible to a list of tf.Tensor. This list
      defines all the "passthrough" tensors. A passthrough tensor is a tensor
      which goes directly from the input of the subgraph to its output,
      without any intermediate operations. Candidates that are already an
      input, output or inside tensor of `inside_ops` are not passthrough
      tensors and are silently ignored.
  Raises:
    TypeError: if inside_ops cannot be converted to a list of tf.Operation or
      if passthrough_ts cannot be converted to a list of tf.Tensor.
  """
  inside_ops = util.make_list_of_op(inside_ops)
  passthrough_ts = util.make_list_of_t(passthrough_ts)

  # With neither ops nor tensors there is no graph to attach to.
  everything = inside_ops + passthrough_ts
  self._graph = util.get_unique_graph(everything) if everything else None
  self._ops = inside_ops

  # Boundary (outside) tensors of the op set, plus the tensors internal to it.
  outside_in, outside_out, internal_ts = select.compute_boundary_ts(inside_ops)

  # Keep only candidates not already connected to the inside ops.
  connected = frozenset(outside_in + outside_out + list(internal_ts))
  self._passthrough_ts = [
      tensor for tensor in passthrough_ts if tensor not in connected
  ]

  # Passthrough tensors count both as inputs and as outputs of the view.
  self._input_ts = outside_in + self._passthrough_ts
  self._output_ts = outside_out + self._passthrough_ts
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         control_outputs=True):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors. Whether the seeds are tensors is
      decided from the type of the first element. Any iterable (list, tuple,
      set, generator, ...) is accepted, as is a single op or tensor.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within the
      whole graph.
    control_outputs: an object convertible to a control output dictionary
      (see function util.convert_to_control_outputs for more details).
      If the dictionary can be created, it will be used while walking the graph
      forward.
  Returns:
    A Python set of all the tf.Operation ahead of seed_ops (an empty set when
    seed_ops is empty).
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize the iterable: sets (e.g. this function's own result) and
  # generators cannot be indexed with seed_ops[0].
  seed_ops = list(seed_ops)
  if not seed_ops:
    return set()
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  control_outputs = util.convert_to_control_outputs(seed_ops, control_outputs)

  seed_ops = frozenset(seed_ops)
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  result = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        for new_op in new_t.consumers():
          if new_op not in result and is_within(new_op):
            new_wave.add(new_op)
      # control_outputs maps an op to the ops that depend on it through
      # control edges; ops without an entry are skipped.
      if control_outputs is not None and op in control_outputs:
        for new_op in control_outputs[op]:
          if new_op not in result and is_within(new_op):
            new_wave.add(new_op)
    result.update(new_wave)
    wave = new_wave
  if not inclusive:
    result.difference_update(seed_ops)
  return result
def get_generating_ops(ts):
  """Map each tensor in `ts` to the tf.Operation that produces it.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of the generating tf.Operation of the tensors in ts: one entry per
    tensor, in the same order (repeated producers are kept).
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  tensors = util.make_list_of_t(ts, allow_graph=False)
  producers = [tensor.op for tensor in tensors]
  return producers
def get_generating_ops(ts):
  """Map each tensor in `ts` to the tf.Operation that produces it.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of the generating tf.Operation of the tensors in ts: one entry per
    tensor, in the same order (repeated producers are kept).
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  tensors = util.make_list_of_t(ts, allow_graph=False)
  producers = [tensor.op for tensor in tensors]
  return producers
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of all the consuming tf.Operation of the tensors in ts, without
    duplicates, in first-seen order (tensors in `ts` order, then each
    tensor's `consumers()` order).
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  ts = util.make_list_of_t(ts, allow_graph=False)
  ops = []
  # `seen` mirrors `ops` for O(1) duplicate checks; testing membership on the
  # list made this quadratic in the number of consumers.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of all the consuming tf.Operation of the tensors in ts, without
    duplicates, in first-seen order (tensors in `ts` order, then each
    tensor's `consumers()` order).
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  ts = util.make_list_of_t(ts, allow_graph=False)
  ops = []
  # `seen` mirrors `ops` for O(1) duplicate checks; testing membership on the
  # list made this quadratic in the number of consumers.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
def __init__(self, inside_ops=(), passthrough_ts=()):
  """Create a subgraph containing the given ops and the "passthrough" tensors.

  Args:
    inside_ops: an object convertible to a list of `tf.Operation`. This list
      defines all the operations in the subgraph.
    passthrough_ts: an object convertible to a list of `tf.Tensor`. This list
      defines all the "passthrough" tensors. A passthrough tensor is a tensor
      which goes directly from the input of the subgraph to its output,
      without any intermediate operations. Candidates that are already an
      input, output or inside tensor of `inside_ops` are not passthrough
      tensors and are silently ignored.
  Raises:
    TypeError: if inside_ops cannot be converted to a list of `tf.Operation`
      or if `passthrough_ts` cannot be converted to a list of `tf.Tensor`.
  """
  inside_ops = util.make_list_of_op(inside_ops)
  passthrough_ts = util.make_list_of_t(passthrough_ts)
  ops_and_ts = inside_ops + passthrough_ts

  if not ops_and_ts:
    # Empty view: no graph, no ops and no boundary tensors.
    self._graph = None
    self._passthrough_ts = []
    self._input_ts = []
    self._output_ts = []
    self._ops = []
    return

  self._graph = util.get_unique_graph(ops_and_ts)
  self._ops = inside_ops

  # Boundary (outside) tensors of the op set, plus the tensors internal to it.
  outer_inputs, outer_outputs, inner_ts = select.compute_boundary_ts(inside_ops)

  # A candidate already wired to the inside ops is not a passthrough tensor.
  wired = frozenset(outer_inputs + outer_outputs + list(inner_ts))
  self._passthrough_ts = [cand for cand in passthrough_ts if cand not in wired]

  # Passthrough tensors are reported both as inputs and as outputs.
  self._input_ts = outer_inputs + self._passthrough_ts
  self._output_ts = outer_outputs + self._passthrough_ts
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors. Whether the seeds are tensors is
      decided from the type of the first element. Any iterable (list, tuple,
      set, generator, ...) is accepted, as is a single op or tensor.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_outputs: a `util.ControlOutputs` instance or None.
      If not `None`, it will be used while walking the graph forward.
  Returns:
    A Python list of all the `tf.Operation` ahead of `seed_ops`, without
    duplicates, in the order they were reached (seeds first, then wave by
    wave). An empty list is returned when `seed_ops` is empty.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize the iterable: sets and generators cannot be indexed with
  # seed_ops[0], and a generator is always truthy.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  # `result` keeps discovery order; `visited` mirrors it for O(1) membership
  # tests (testing membership on the list made the walk quadratic).
  result = list(seed_ops)
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in new_wave is unvisited by construction, so extending keeps
    # `result` duplicate-free.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
  """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1.

  This function is the back-bone of the Graph-Editor. It is essentially a thin
  wrapper on top of the tf.Operation._update_input.

  Given a pair of tensors t0, t1 in ts0 x ts1, this function re-routes the
  end of t0 and t1 in three possible ways:
  1) The reroute mode is "a<->b" or "b<->a": the tensors' ends are swapped.
  After this operation, the previous consumers of t0 are now consumers of t1
  and vice-versa.
  2) The reroute mode is "a->b": the end of t1 is re-routed to t0, leaving t1
  dangling. After this operation, the previous consumers of t0 are still
  consuming t0, and the previous consumers of t1 are now also consuming t0.
  The tensor t1 has no consumer left (except consumers protected by
  `can_modify` / `cannot_modify`).
  3) The reroute mode is "b->a": this mode is the symmetric of the "a->b"
  mode.

  Note that this function is re-routing the end of two tensors, not the start.
  Re-routing the start of two tensors is not supported by this library. The
  reason for that is the following: TensorFlow, by design, creates a strong bond
  between an op and its output tensor. This Graph editor follows this design and
  treats an operation A and its generating tensors {t_i} as an entity which
  cannot be broken. In other words, an op cannot be detached from any of its
  output tensors, ever. But it is possible to detach an op from its input
  tensors, which is what this function concerns itself with.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    mode: what to do with those tensors: "a<->b" or "b<->a" for swapping and
      "a->b" or "b->a" for one direction re-routing.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if `ts0` or `ts1` cannot be converted to a list of `tf.Tensor`.
    TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be
      converted to a list of `tf.Operation`.
    ValueError: if `ts0` and `ts1` differ in length or contain a pair of
      tensors with incompatible dtypes or shapes (from
      `_check_ts_compatibility`).
  """
  a2b, b2a = _RerouteMode.check(mode)
  ts0 = _util.make_list_of_t(ts0)
  ts1 = _util.make_list_of_t(ts1)
  _check_ts_compatibility(ts0, ts1)
  if cannot_modify is not None:
    cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
  if can_modify is not None:
    can_modify = frozenset(_util.make_list_of_op(can_modify))
  nb_update_inputs = 0
  # Snapshot every consumer set before any rerouting modifies the graph, so
  # that tensors repeated across pairs do not see already-rerouted consumers.
  precomputed_consumers = []
  for t0, t1 in zip(ts0, ts1):
    consumers0 = set(t0.consumers())
    consumers1 = set(t1.consumers())
    precomputed_consumers.append((consumers0, consumers1))
  for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers):
    if t0 is t1:
      continue  # Silently ignore identical tensors.
    consumers0, consumers1 = consumers
    if a2b:
      nb_update_inputs += _reroute_t(t0, t1, consumers1, can_modify,
                                     cannot_modify)
    if b2a:
      nb_update_inputs += _reroute_t(t1, t0, consumers0, can_modify,
                                     cannot_modify)
  return nb_update_inputs
def select_ts(*args, **kwargs):
  """Helper to select tensors.

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      tf.Tensor. tf.Operation instances are silently ignored.
    **kwargs: 'graph': tf.Graph in which to perform the regex query. This is
      required when using regex.
      'positive_filter': an elem is selected only if positive_filter(elem) is
      True. This is optional.
      'restrict_regex': a regular expression is ignored if it doesn't start
      with the substring "(?#ts)".
  Returns:
    A list of tf.Tensor without duplicates, in the order they were first
    selected.
  Raises:
    TypeError: if the optional keyword argument graph is not a tf.Graph
      or if an argument in args is not an (array of) tf.Tensor
      or an (array of) tf.Operation (silently ignored) or a string
      or a regular expression.
    ValueError: if one of the keyword arguments is unexpected or if a regular
      expression is used without passing a graph as a keyword argument.
  """
  # get keywords arguments
  graph = None
  positive_filter = None
  restrict_regex = False
  # dict.items() works on Python 2 and 3; dict.iteritems() is Python-2-only
  # and raises AttributeError on Python 3.
  for k, v in kwargs.items():
    if k == "graph":
      graph = v
      if graph is not None and not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got {}".format(type(graph)))
    elif k == "positive_filter":
      positive_filter = v
    elif k == "restrict_regex":
      restrict_regex = v
    else:
      raise ValueError("Wrong keywords argument: {}.".format(k))

  ts = []
  # Mirrors `ts` for O(1) duplicate checks, applied uniformly to both the regex
  # and the explicit-tensor branches.
  selected = set()

  def add_candidates(candidates):
    """Append candidates not yet selected that pass positive_filter."""
    for t in candidates:
      if t in selected:
        continue
      if positive_filter is not None and not positive_filter(t):
        continue
      selected.add(t)
      ts.append(t)

  for arg in args:
    if _can_be_regex(arg):
      if graph is None:
        raise ValueError("Use the keyword argument 'graph' to use regex.")
      regex = _make_regex(arg)
      if regex.pattern.startswith("(?#ops)"):
        continue
      if restrict_regex and not regex.pattern.startswith("(?#ts)"):
        continue
      add_candidates(filter_ts_from_regex(graph, regex))
    else:
      add_candidates(util.make_list_of_t(arg, ignore_ops=True))
  return ts
def select_ts(*args, **kwargs):
  """Helper to select tensors.

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      tf.Tensor. tf.Operation instances are silently ignored.
    **kwargs: 'graph': tf.Graph in which to perform the regex query. This is
      required when using regex.
      'positive_filter': an elem is selected only if positive_filter(elem) is
      True. This is optional.
      'restrict_regex': a regular expression is ignored if it doesn't start
      with the substring "(?#ts)".
  Returns:
    A list of tf.Tensor without duplicates, in the order they were first
    selected.
  Raises:
    TypeError: if the optional keyword argument graph is not a tf.Graph
      or if an argument in args is not an (array of) tf.Tensor
      or an (array of) tf.Operation (silently ignored) or a string
      or a regular expression.
    ValueError: if one of the keyword arguments is unexpected or if a regular
      expression is used without passing a graph as a keyword argument.
  """
  # get keywords arguments
  graph = None
  positive_filter = None
  restrict_regex = False
  # dict.items() works on Python 2 and 3; dict.iteritems() is Python-2-only
  # and raises AttributeError on Python 3.
  for k, v in kwargs.items():
    if k == "graph":
      graph = v
      if graph is not None and not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got {}".format(type(graph)))
    elif k == "positive_filter":
      positive_filter = v
    elif k == "restrict_regex":
      restrict_regex = v
    else:
      raise ValueError("Wrong keywords argument: {}.".format(k))

  ts = []
  # Mirrors `ts` for O(1) duplicate checks, applied uniformly to both the regex
  # and the explicit-tensor branches.
  selected = set()

  def add_candidates(candidates):
    """Append candidates not yet selected that pass positive_filter."""
    for t in candidates:
      if t in selected:
        continue
      if positive_filter is not None and not positive_filter(t):
        continue
      selected.add(t)
      ts.append(t)

  for arg in args:
    if _can_be_regex(arg):
      if graph is None:
        raise ValueError("Use the keyword argument 'graph' to use regex.")
      regex = _make_regex(arg)
      if regex.pattern.startswith("(?#ops)"):
        continue
      if restrict_regex and not regex.pattern.startswith("(?#ts)"):
        continue
      add_candidates(filter_ts_from_regex(graph, regex))
    else:
      add_candidates(util.make_list_of_t(arg, ignore_ops=True))
  return ts
def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
  """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1.

  This function is the back-bone of the Graph-Editor. It is essentially a thin
  wrapper on top of the tf.Operation._update_input.

  Given a pair of tensors t0, t1 in ts0 x ts1, this function re-routes the
  end of t0 and t1 in three possible ways:
  1) The reroute mode is "a<->b" or "b<->a": the tensors' ends are swapped.
  After this operation, the previous consumers of t0 are now consumers of t1
  and vice-versa.
  2) The reroute mode is "a->b": the end of t1 is re-routed to t0, leaving t1
  dangling. After this operation, the previous consumers of t0 are still
  consuming t0, and the previous consumers of t1 are now also consuming t0.
  The tensor t1 has no consumer left (except consumers protected by
  `can_modify` / `cannot_modify`).
  3) The reroute mode is "b->a": this mode is the symmetric of the "a->b"
  mode.

  Note that this function is re-routing the end of two tensors, not the start.
  Re-routing the start of two tensors is not supported by this library. The
  reason for that is the following: TensorFlow, by design, creates a strong bond
  between an op and its output tensor. This Graph editor follows this design and
  treats an operation A and its generating tensors {t_i} as an entity which
  cannot be broken. In other words, an op cannot be detached from any of its
  output tensors, ever. But it is possible to detach an op from its input
  tensors, which is what this function concerns itself with.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    mode: what to do with those tensors: "a<->b" or "b<->a" for swapping and
      "a->b" or "b->a" for one direction re-routing.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if `ts0` or `ts1` cannot be converted to a list of `tf.Tensor`.
    TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be
      converted to a list of `tf.Operation`.
    ValueError: if `ts0` and `ts1` differ in length or contain a pair of
      tensors with incompatible dtypes or shapes (from
      `_check_ts_compatibility`).
  """
  a2b, b2a = _RerouteMode.check(mode)
  ts0 = _util.make_list_of_t(ts0)
  ts1 = _util.make_list_of_t(ts1)
  _check_ts_compatibility(ts0, ts1)
  if cannot_modify is not None:
    cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
  if can_modify is not None:
    can_modify = frozenset(_util.make_list_of_op(can_modify))
  nb_update_inputs = 0
  # Snapshot every consumer set before any rerouting modifies the graph, so
  # that tensors repeated across pairs do not see already-rerouted consumers.
  precomputed_consumers = []
  for t0, t1 in zip(ts0, ts1):
    consumers0 = set(t0.consumers())
    consumers1 = set(t1.consumers())
    precomputed_consumers.append((consumers0, consumers1))
  for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers):
    if t0 is t1:
      continue  # Silently ignore identical tensors.
    consumers0, consumers1 = consumers
    if a2b:
      nb_update_inputs += _reroute_t(t0, t1, consumers1, can_modify,
                                     cannot_modify)
    if b2a:
      nb_update_inputs += _reroute_t(t1, t0, consumers0, can_modify,
                                     cannot_modify)
  return nb_update_inputs
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors. Whether the seeds are tensors is
      decided from the type of the first element. Any iterable (list, tuple,
      set, generator, ...) is accepted, as is a single op or tensor.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_outputs: a `util.ControlOutputs` instance or None.
      If not `None`, it will be used while walking the graph forward.
  Returns:
    A Python list of all the `tf.Operation` ahead of `seed_ops`, without
    duplicates, in the order they were reached (seeds first, then wave by
    wave). An empty list is returned when `seed_ops` is empty.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize the iterable: sets and generators cannot be indexed with
  # seed_ops[0], and a generator is always truthy.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  # `result` keeps discovery order; `visited` mirrors it for O(1) membership
  # tests (testing membership on the list made the walk quadratic).
  result = list(seed_ops)
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in new_wave is unvisited by construction, so extending keeps
    # `result` duplicate-free.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result