def get_within_boundary_ops(ops,
                            seed_ops,
                            boundary_ops=(),
                            inclusive=True,
                            control_inputs=False,
                            control_outputs=None,
                            control_ios=None):
    """Collect the ops reachable from the seeds without crossing the boundary.

    Starting from `seed_ops`, the graph is explored wave by wave through
    `get_ops_ios`. Boundary ops are never expanded; they are only added to the
    result when `inclusive` is true.

    Args:
        ops: an object convertible to a list of `tf.Operation`. Those ops
            define the set in which to perform the operation (if a `tf.Graph`
            is given, it will be converted to the list of all its operations).
        seed_ops: the operations from which to start expanding.
        boundary_ops: the ops forming the boundary.
        inclusive: if `True`, the result will also include the boundary ops
            that were reached.
        control_inputs: a boolean indicating whether control inputs are
            enabled.
        control_outputs: an instance of `util.ControlOutputs` or `None`. If
            not `None`, control outputs are enabled.
        control_ios: an instance of `util.ControlOutputs` or `None`. If not
            `None`, both control inputs and control outputs are enabled. This
            is equivalent to setting `control_inputs` to `True` and
            `control_outputs` to the `util.ControlOutputs` instance.

    Returns:
        The elements of `ops` that were reached, in the order they appear in
        `ops`.

    Raises:
        TypeError: if `ops` or `seed_ops` cannot be converted to a list of
            `tf.Operation`.
        ValueError: if the boundary is intersecting with the seeds.
    """
    control_inputs, control_outputs = check_cios(
        control_inputs, control_outputs, control_ios)
    ops = util.make_list_of_op(ops)
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
    boundary_ops = set(util.make_list_of_op(boundary_ops))

    reached = set(seed_ops)
    if reached & boundary_ops:
        raise ValueError("Boundary is intersecting with the seeds.")

    frontier = set(seed_ops)
    while frontier:
        neighbours = get_ops_ios(frontier, control_inputs, control_outputs)
        next_frontier = set()
        for neighbour in neighbours:
            if neighbour in reached:
                continue
            if neighbour not in boundary_ops:
                # Regular op: keep expanding from it on the next wave.
                next_frontier.add(neighbour)
            elif inclusive:
                # Boundary op: recorded, but never expanded.
                reached.add(neighbour)
        reached |= next_frontier
        frontier = next_frontier

    # Filter `ops` so that the output keeps the caller's ordering.
    return [op for op in ops if op in reached]
def get_ops_ios(ops, control_inputs=False, control_outputs=None,
                control_ios=None):
    """Return the ops directly connected to any op in `ops`.

    For each op, the following neighbours are gathered (in this order, without
    duplicates): the producers of its input tensors, the consumers of its
    output tensors, its control outputs (when enabled) and its control inputs
    (when enabled).

    Args:
        ops: an object convertible to a list of `tf.Operation`.
        control_inputs: a boolean indicating whether control inputs are
            enabled.
        control_outputs: an instance of `util.ControlOutputs` or `None`. If
            not `None`, control outputs are enabled.
        control_ios: an instance of `util.ControlOutputs` or `None`. If not
            `None`, both control inputs and control outputs are enabled. This
            is equivalent to setting `control_inputs` to `True` and
            `control_outputs` to the `util.ControlOutputs` instance.

    Returns:
        A list of the `tf.Operation` surrounding the given ops.

    Raises:
        TypeError: if `ops` cannot be converted to a list of `tf.Operation`.
    """
    control_inputs, control_outputs = check_cios(
        control_inputs, control_outputs, control_ios)
    neighbours = []
    for op in util.make_list_of_op(ops):
        producers = [tensor.op for tensor in op.inputs]
        util.concatenate_unique(neighbours, producers)
        for tensor in op.outputs:
            util.concatenate_unique(neighbours, tensor.consumers())
        if control_outputs is not None:
            util.concatenate_unique(neighbours, control_outputs.get(op))
        if control_inputs:
            util.concatenate_unique(neighbours, op.control_inputs)
    return neighbours
def compute_boundary_ts(ops):
    """Compute the tensors at the boundary of a set of ops.

    This function looks at all the tensors connected to the given ops (in/out)
    and classifies them into three categories:
    1) input tensors: tensors whose generating operation is not in ops.
    2) output tensors: tensors whose consumer operations are not in ops.
    3) inside tensors: tensors which are neither input nor output tensors.

    Note that a tensor can be both an inside tensor and an output tensor if it
    is consumed by operations both outside and inside of `ops`.

    Args:
        ops: an object convertible to a list of `tf.Operation`.

    Returns:
        A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
            `outside_input_ts` is a Python list of input tensors;
            `outside_output_ts` is a Python list of output tensors;
            `inside_ts` is a Python list of inside tensors.
        Since a tensor can be both an inside tensor and an output tensor,
        `outside_output_ts` and `inside_ts` might intersect.

    Raises:
        TypeError: if ops cannot be converted to a list of `tf.Operation`.
    """
    ops = util.make_list_of_op(ops)
    input_ts = _get_input_ts(ops)
    output_ts = _get_output_ts(ops)
    output_ts_set = frozenset(output_ts)
    ops_set = frozenset(ops)

    # Compute inside tensors.
    inside_ts = []
    only_inside_ts = []
    for t in input_ts:
        # Skip if the input tensor is not also an output tensor.
        if t not in output_ts_set:
            continue
        # Mark as "inside".
        inside_ts.append(t)
        # Mark as "only inside" if the tensor is not both inside and output.
        # `consumers` is a method and has to be called: handing the bound
        # method itself to `frozenset` raises a TypeError.
        consumers = frozenset(t.consumers())
        if consumers - ops_set:
            continue
        only_inside_ts.append(t)

    inside_ts_set = frozenset(inside_ts)
    only_inside_ts_set = frozenset(only_inside_ts)
    # Outputs consumed exclusively inside `ops` are not boundary outputs.
    outside_output_ts = [t for t in output_ts if t not in only_inside_ts_set]
    # Inputs produced by an op of `ops` are not boundary inputs.
    outside_input_ts = [t for t in input_ts if t not in inside_ts_set]
    return outside_input_ts, outside_output_ts, inside_ts
def _get_output_ts(ops):
    """Gather the output tensors of every op in `ops`.

    The outputs are concatenated in op order. No explicit de-duplication is
    performed here.

    Args:
        ops: an object convertible to a list of tf.Operation.

    Returns:
        The list of output tensors of all the ops in `ops`.

    Raises:
        TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    return [
        tensor
        for op in util.make_list_of_op(ops)
        for tensor in op.outputs
    ]
def filter_ops_from_regex(ops, regex):
    """Select the operations whose name matches a regular expression.

    Args:
        ops: an object convertible to a list of `tf.Operation`.
        regex: a regular expression matching the operation's name.
            For example, `"^foo(/.*)?$"` will match all the operations in the
            "foo" scope.

    Returns:
        A list of `tf.Operation`.

    Raises:
        TypeError: if ops cannot be converted to a list of `tf.Operation`.
    """
    op_list = util.make_list_of_op(ops)
    pattern = make_regex(regex)

    def name_matches(op):
        # `search` (not `match`): the pattern may hit anywhere in the name.
        return pattern.search(op.name)

    return filter_ops(op_list, name_matches)
def filter_ops(ops, positive_filter):
    """Keep only the ops accepted by the given filter.

    Args:
        ops: an object convertible to a list of tf.Operation.
        positive_filter: a function deciding whether to keep an operation or
            not. If the literal `True` is passed instead, all the operations
            are returned.

    Returns:
        A list of selected tf.Operation.

    Raises:
        TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    candidates = util.make_list_of_op(ops)
    # Identity test on purpose: only the literal `True` disables filtering.
    if positive_filter is True:  # pylint: disable=g-explicit-bool-comparison
        return candidates
    return [op for op in candidates if positive_filter(op)]
def filter_ts_from_regex(ops, regex):
    r"""Select the tensors linked to `ops` whose name matches a regex.

    Args:
        ops: an object convertible to a list of tf.Operation.
        regex: a regular expression matching the tensors' name.
            For example, "^foo(/.*)?:\d+$" will match all the tensors in the
            "foo" scope.

    Returns:
        A list of tf.Tensor.

    Raises:
        TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    op_list = util.make_list_of_op(ops)
    pattern = make_regex(regex)

    def name_matches(tensor):
        # The filter is applied to tensors, so this is the tensor's name.
        return pattern.search(tensor.name)

    return filter_ts(op_list, positive_filter=name_matches)
def filter_ts(ops, positive_filter):
    """Select the input/output tensors of `ops` accepted by a filter.

    Input tensors come first, followed by the output tensors that were not
    already listed.

    Args:
        ops: an object convertible to a list of `tf.Operation`.
        positive_filter: a function deciding whether to keep a tensor or not.
            If the literal `True` is passed instead, all the tensors are
            returned.

    Returns:
        A list of `tf.Tensor`.

    Raises:
        TypeError: if ops cannot be converted to a list of `tf.Operation`.
    """
    op_list = util.make_list_of_op(ops)
    tensors = _get_input_ts(op_list)
    # Extends `tensors` in place with the outputs not already present.
    util.concatenate_unique(tensors, _get_output_ts(op_list))
    if positive_filter is True:
        return tensors
    return [tensor for tensor in tensors if positive_filter(tensor)]
def _get_input_ts(ops):
    """Gather the unique input tensors of every op in `ops`.

    Tensors are returned in first-seen order; a tensor consumed by several ops
    is listed only once.

    Args:
        ops: an object convertible to a list of `tf.Operation`.

    Returns:
        The list of unique input tensors of all the ops in `ops`.

    Raises:
        TypeError: if `ops` cannot be converted to a list of `tf.Operation`.
    """
    seen = set()
    unique_inputs = []
    for op in util.make_list_of_op(ops):
        for tensor in op.inputs:
            if tensor in seen:
                continue
            seen.add(tensor)
            unique_inputs.append(tensor)
    return unique_inputs
def __init__(self, inside_ops=(), passthrough_ts=()):
    """Create a subgraph containing the given ops and "passthrough" tensors.

    Args:
        inside_ops: an object convertible to a list of `tf.Operation`. This
            list defines all the operations in the subgraph.
        passthrough_ts: an object convertible to a list of `tf.Tensor`. This
            list defines all the "passthrough" tensors. A passthrough tensor
            is a tensor which goes directly from the input of the subgraph to
            its output, without any intermediate operations. Tensors that are
            already an input, output or inside tensor of `inside_ops` are
            silently ignored.

    Raises:
        TypeError: if inside_ops cannot be converted to a list of
            `tf.Operation` or if `passthrough_ts` cannot be converted to a
            list of `tf.Tensor`.
    """
    inside_ops = util.make_list_of_op(inside_ops)
    passthrough_ts = util.make_list_of_t(passthrough_ts)
    ops_and_ts = inside_ops + passthrough_ts

    # Nothing was given: build an empty view that is not bound to any graph.
    if not ops_and_ts:
        self._graph = None
        self._passthrough_ts = []
        self._input_ts = []
        self._output_ts = []
        self._ops = []
        return

    self._graph = util.get_unique_graph(ops_and_ts)
    self._ops = inside_ops

    # Classify the tensors touching the inside ops.
    inputs, outputs, insides = select.compute_boundary_ts(inside_ops)

    # A tensor already connected to the inside ops is not a passthrough.
    connected_ts = frozenset(inputs + outputs + list(insides))
    self._passthrough_ts = [
        tensor for tensor in passthrough_ts if tensor not in connected_ts
    ]

    # Passthrough tensors appear on both sides of the subgraph.
    self._input_ts = inputs + self._passthrough_ts
    self._output_ts = outputs + self._passthrough_ts
def select_ops(*args, **kwargs):
    """Helper to select operations.

    Args:
        *args: list of 1) regular expressions (compiled or not) or 2) (array
            of) `tf.Operation`. `tf.Tensor` instances are silently ignored.
        **kwargs:
            'graph': `tf.Graph` in which to perform the regex query. This is
                required when using regex.
            'positive_filter': an elem is selected only if
                `positive_filter(elem)` is `True`. This is optional.
            'restrict_ops_regex': when true, a regular expression is ignored
                if it doesn't start with the substring "(?#ops)".
            'restrict_ts_regex': accepted and ignored by this function.

    Returns:
        A list of `tf.Operation`, in selection order.

    Raises:
        TypeError: if the optional keyword argument graph is not a `tf.Graph`
            or if an argument in args is not an (array of) `tf.Operation` or
            an (array of) `tf.Tensor` (silently ignored) or a string or a
            regular expression.
        ValueError: if one of the keyword arguments is unexpected or if a
            regular expression is used without passing a graph as a keyword
            argument.
    """
    # Parse the keyword arguments.
    graph = None
    positive_filter = None
    restrict_ops_regex = False
    for key, value in iteritems(kwargs):
        if key == "graph":
            graph = value
            if graph is not None and not isinstance(graph, tf_ops.Graph):
                raise TypeError("Expected a tf.Graph, got: {}".format(
                    type(graph)))
        elif key == "positive_filter":
            positive_filter = value
        elif key == "restrict_ops_regex":
            restrict_ops_regex = value
        elif key != "restrict_ts_regex":
            raise ValueError("Wrong keywords argument: {}.".format(key))

    selected = []
    for arg in args:
        if not can_be_regex(arg):
            # Explicit op(s): tensors are dropped by `ignore_ts=True`.
            candidates = util.make_list_of_op(arg, ignore_ts=True)
            if positive_filter is not None:
                candidates = [op for op in candidates if positive_filter(op)]
            # Materialized before extending, so membership is tested against
            # the selection as it was before this argument.
            fresh = [op for op in candidates if op not in selected]
            selected.extend(fresh)
            continue

        if graph is None:
            raise ValueError(
                "Use the keyword argument 'graph' to use regex.")
        regex = make_regex(arg)
        pattern = regex.pattern
        # Patterns tagged "(?#ts)" are skipped here.
        if pattern.startswith("(?#ts)"):
            continue
        if restrict_ops_regex and not pattern.startswith("(?#ops)"):
            continue
        for matched in filter_ops_from_regex(graph, regex):
            if matched not in selected and (
                    positive_filter is None or positive_filter(matched)):
                selected.append(matched)
    return selected
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    Args:
        seed_ops: an iterable of operations from which the backward graph walk
            starts. If a list of tensors is given instead, the seed_ops are
            set to be the generators of those tensors. A single (non-iterable)
            element is also accepted.
        inclusive: if True the given seed_ops are also part of the resulting
            list.
        within_ops: an iterable of `tf.Operation` within which the search is
            restricted. If `within_ops` is `None`, the search is performed
            within the whole graph.
        within_ops_fn: if provided, a function on ops that should return True
            iff the op is within the graph traversal. This can be used along
            within_ops, in which case an op is within if it is also in
            within_ops.
        stop_at_ts: an iterable of tensors at which the graph walk stops.
        control_inputs: if True, control inputs will be used while moving
            backward.

    Returns:
        A Python list of all the `tf.Operation` behind `seed_ops` (each op
        appears once).

    Raises:
        TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
            of `tf.Operation`.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    elif not isinstance(seed_ops, (list, tuple)):
        # Sets, generators and other iterables support neither a reliable
        # emptiness test nor `[0]`; materialize them first.
        seed_ops = list(seed_ops)
    if not seed_ops:
        return []
    # The type of the first element decides how the seeds are interpreted.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    # NOTE(review): an empty (but not None) `within_ops` skips this branch, so
    # the seeds are not intersected while `is_within` rejects every other op.
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return (within_ops is None or op in within_ops) and (
            within_ops_fn is None or within_ops_fn(op))

    result = list(seed_ops)
    # Mirror of `result` used for O(1) membership tests; `result` itself keeps
    # the discovery order.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        # Every op in `new_wave` is unique and not yet visited, so a plain
        # extend cannot introduce duplicates.
        result.extend(new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result