def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                          stop_at_ts=(), control_inputs=False,
                          within_ops_fn=None):
    """Do a backward graph walk and return all the visited ops.

    The walk proceeds wave by wave: starting from the seeds, each wave
    collects the ops producing the inputs (and, optionally, the control
    inputs) of the previous wave.

    Args:
      seed_ops: an iterable of operations from which the backward graph walk
        starts. If a list of tensors is given instead, the seed_ops are set to
        be the generators of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within the
        whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_inputs: if True, control inputs will be used while moving
        backward.
      within_ops_fn: optional predicate on an op. When given, an op discovered
        during the walk is only visited if the predicate returns a truthy
        value (in addition to the within_ops restriction). The seed ops
        themselves are not filtered by it.
    Returns:
      A list (without duplicates) of the tf.Operation visited by the walk.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    # The type of the first element decides whether seeds are tensors or ops.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return ((within_ops is None or op in within_ops) and
                (within_ops_fn is None or within_ops_fn(op)))

    result = list(seed_ops)
    # Mirror of `result` as a set: membership tests in the inner loops stay
    # constant-time instead of scanning the growing result list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                          stop_at_ts=(), control_inputs=False,
                          within_ops_fn=None):
    """Do a backward graph walk and return all the visited ops.

    The walk proceeds wave by wave: starting from the seeds, each wave
    collects the ops producing the inputs (and, optionally, the control
    inputs) of the previous wave.

    Args:
      seed_ops: an iterable of operations from which the backward graph walk
        starts. If a list of tensors is given instead, the seed_ops are set to
        be the generators of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within the
        whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_inputs: if True, control inputs will be used while moving
        backward.
      within_ops_fn: optional predicate on an op. When given, an op discovered
        during the walk is only visited if the predicate returns a truthy
        value (in addition to the within_ops restriction). The seed ops
        themselves are not filtered by it.
    Returns:
      A list (without duplicates) of the tf.Operation visited by the walk.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    # The type of the first element decides whether seeds are tensors or ops.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return ((within_ops is None or op in within_ops) and
                (within_ops_fn is None or within_ops_fn(op)))

    result = list(seed_ops)
    # Mirror of `result` as a set: membership tests in the inner loops stay
    # constant-time instead of scanning the growing result list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
def get_ops_ios(ops, control_inputs=False, control_outputs=None,
                control_ios=None):
    """Collect the operations directly connected to any op in `ops`.

    For every op, in order, the result gathers: the ops producing its input
    tensors, the consumers of each of its output tensors, its control outputs
    (when enabled) and its control inputs (when enabled).

    Args:
      ops: an object convertible to a list of tf.Operation.
      control_inputs: A boolean indicating whether control inputs are enabled.
      control_outputs: An instance of util.ControlOutputs or None. If not None,
        control outputs are enabled.
      control_ios: An instance of util.ControlOutputs or None. If not None,
        both control inputs and control outputs are enabled. This is
        equivalent to setting control_inputs to True and control_outputs to
        the util.ControlOutputs instance.
    Returns:
      A list of the tf.Operation surrounding the given ops.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    control_inputs, control_outputs = check_cios(
        control_inputs, control_outputs, control_ios)
    neighbours = []
    for current in util.make_list_of_op(ops):
        # Data-flow predecessors: the ops generating this op's inputs.
        producers = [tensor.op for tensor in current.inputs]
        util.concatenate_unique(neighbours, producers)
        # Data-flow successors: whoever consumes each output tensor.
        for tensor in current.outputs:
            util.concatenate_unique(neighbours, tensor.consumers())
        if control_outputs is not None:
            util.concatenate_unique(neighbours, control_outputs.get(current))
        if control_inputs:
            util.concatenate_unique(neighbours, current.control_inputs)
    return neighbours
def consumers(self):
    """Return the consumers of this subgraph view.

    A consumer of a subgraph view is a tf.Operation which is a consumer of one
    of the output tensors and is not in the subgraph.

    Returns:
      A list (not a set) of `tf.Operation` which are the consumers of this
      subgraph view.
    """
    internal_ops = frozenset(self._ops)
    external = []
    for tensor in self._output_ts:
        # Keep only the consumers living outside of the subgraph.
        outside = []
        for candidate in tensor.consumers():
            if candidate not in internal_ops:
                outside.append(candidate)
        util.concatenate_unique(external, outside)
    return external
def consumers(self):
    """Return the consumers of this subgraph view.

    A consumer of a subgraph view is a tf.Operation which is a consumer of one
    of the output tensors and is not in the subgraph.

    Returns:
      A list (not a set) of `tf.Operation` which are the consumers of this
      subgraph view.
    """
    internal_ops = frozenset(self._ops)
    external = []
    for tensor in self._output_ts:
        # Keep only the consumers living outside of the subgraph.
        outside = []
        for candidate in tensor.consumers():
            if candidate not in internal_ops:
                outside.append(candidate)
        util.concatenate_unique(external, outside)
    return external
def filter_ts(ops, positive_filter):
    """Get all the tensors which are input or output of an op in ops.

    Args:
      ops: an object convertible to a list of tf.Operation.
      positive_filter: a function deciding whether to keep a tensor or not.
        If it is the literal True, no filtering happens and all the tensors
        are returned.
    Returns:
      A list of tf.Tensor.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    op_list = util.make_list_of_op(ops)
    tensors = _get_input_ts(op_list)
    # Extends `tensors` in place with the outputs.
    util.concatenate_unique(tensors, _get_output_ts(op_list))
    if positive_filter is True:
        return tensors
    return [tensor for tensor in tensors if positive_filter(tensor)]
def get_walks_union_ops(forward_seed_ops, backward_seed_ops,
                        forward_inclusive=True, backward_inclusive=True,
                        within_ops=None, within_ops_fn=None,
                        control_inputs=False, control_outputs=None,
                        control_ios=None):
    """Return the union of a forward and a backward walk.

    Args:
      forward_seed_ops: an iterable of operations from which the forward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the consumers of those tensors.
      backward_seed_ops: an iterable of operations from which the backward
        graph walk starts. If a list of tensors is given instead, the seed_ops
        are set to be the generators of those tensors.
      forward_inclusive: if True the given forward_seed_ops are also part of
        the result.
      backward_inclusive: if True the given backward_seed_ops are also part of
        the result.
      within_ops: restrict the search within those operations. If within_ops
        is None, the search is done within the whole graph.
      within_ops_fn: if provided, a function on ops that should return True
        iff the op is within the graph traversal. This can be used along
        within_ops, in which case an op is within if it is also in within_ops.
        It is forwarded to both walk functions, which must then accept a
        `within_ops_fn` keyword.
      control_inputs: A boolean indicating whether control inputs are enabled.
      control_outputs: An instance of util.ControlOutputs or None. If not None,
        control outputs are enabled.
      control_ios: An instance of util.ControlOutputs or None. If not None,
        both control inputs and control outputs are enabled. This is
        equivalent to setting control_inputs to True and control_outputs to
        the util.ControlOutputs instance.
    Returns:
      The value returned by `util.concatenate_unique(forward_ops,
      backward_ops)`: expected to be the forward-walk list extended with the
      backward-walk ops it did not already contain (a list, not a set).
    Raises:
      TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot
        be converted to a list of tf.Operation.
    """
    control_inputs, control_outputs = check_cios(
        control_inputs, control_outputs, control_ios)
    # Forward the optional predicate only when one was supplied, so that the
    # default call path does not require the walk functions to know about the
    # `within_ops_fn` keyword.
    restriction = {"within_ops": within_ops}
    if within_ops_fn is not None:
        restriction["within_ops_fn"] = within_ops_fn
    forward_ops = get_forward_walk_ops(
        forward_seed_ops,
        inclusive=forward_inclusive,
        control_outputs=control_outputs,
        **restriction)
    backward_ops = get_backward_walk_ops(
        backward_seed_ops,
        inclusive=backward_inclusive,
        control_inputs=control_inputs,
        **restriction)
    return util.concatenate_unique(forward_ops, backward_ops)
def get_walks_union_ops(forward_seed_ops, backward_seed_ops,
                        forward_inclusive=True, backward_inclusive=True,
                        within_ops=None, within_ops_fn=None,
                        control_inputs=False, control_outputs=None,
                        control_ios=None):
    """Return the union of a forward and a backward walk.

    Args:
      forward_seed_ops: an iterable of operations from which the forward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the consumers of those tensors.
      backward_seed_ops: an iterable of operations from which the backward
        graph walk starts. If a list of tensors is given instead, the seed_ops
        are set to be the generators of those tensors.
      forward_inclusive: if True the given forward_seed_ops are also part of
        the result.
      backward_inclusive: if True the given backward_seed_ops are also part of
        the result.
      within_ops: restrict the search within those operations. If within_ops
        is None, the search is done within the whole graph.
      within_ops_fn: if provided, a function on ops that should return True
        iff the op is within the graph traversal. This can be used along
        within_ops, in which case an op is within if it is also in within_ops.
        It is forwarded to both walk functions, which must then accept a
        `within_ops_fn` keyword.
      control_inputs: A boolean indicating whether control inputs are enabled.
      control_outputs: An instance of util.ControlOutputs or None. If not None,
        control outputs are enabled.
      control_ios: An instance of util.ControlOutputs or None. If not None,
        both control inputs and control outputs are enabled. This is
        equivalent to setting control_inputs to True and control_outputs to
        the util.ControlOutputs instance.
    Returns:
      The value returned by `util.concatenate_unique(forward_ops,
      backward_ops)`: expected to be the forward-walk list extended with the
      backward-walk ops it did not already contain (a list, not a set).
    Raises:
      TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot
        be converted to a list of tf.Operation.
    """
    control_inputs, control_outputs = check_cios(
        control_inputs, control_outputs, control_ios)
    # Forward the optional predicate only when one was supplied, so that the
    # default call path does not require the walk functions to know about the
    # `within_ops_fn` keyword.
    restriction = {"within_ops": within_ops}
    if within_ops_fn is not None:
        restriction["within_ops_fn"] = within_ops_fn
    forward_ops = get_forward_walk_ops(
        forward_seed_ops,
        inclusive=forward_inclusive,
        control_outputs=control_outputs,
        **restriction)
    backward_ops = get_backward_walk_ops(
        backward_seed_ops,
        inclusive=backward_inclusive,
        control_inputs=control_inputs,
        **restriction)
    return util.concatenate_unique(forward_ops, backward_ops)
def consumers(self):
    """Return the consumers of the output tensors of this subgraph view.

    Returns:
      A list (not a set) accumulating, through util.concatenate_unique, what
      `consumers()` returns for each output tensor of the view.
    """
    collected = []
    per_output = (tensor.consumers() for tensor in self._output_ts)
    for consuming_ops in per_output:
        util.concatenate_unique(collected, consuming_ops)
    return collected
def _remap_outputs_make_unique(self):
    """Remap the outputs in place so that every tensor appears only once."""
    # Snapshot the current outputs, then rebuild the list without duplicates.
    snapshot = list(self._output_ts)
    deduplicated = []
    util.concatenate_unique(deduplicated, snapshot)
    self._output_ts = deduplicated
def _remap_outputs_make_unique(self):
    """Remap the outputs in place so that every tensor appears only once."""
    # Snapshot the current outputs, then rebuild the list without duplicates.
    snapshot = list(self._output_ts)
    deduplicated = []
    util.concatenate_unique(deduplicated, snapshot)
    self._output_ts = deduplicated
def consumers(self):
    """Return the consumers of the output tensors of this subgraph view.

    Returns:
      A list (not a set) accumulating, through util.concatenate_unique, what
      `consumers()` returns for each output tensor of the view.
    """
    collected = []
    per_output = (tensor.consumers() for tensor in self._output_ts)
    for consuming_ops in per_output:
        util.concatenate_unique(collected, consuming_ops)
    return collected
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         stop_at_ts=(), control_outputs=None,
                         within_ops_fn=None):
    """Do a forward graph walk and return all the visited ops.

    The walk proceeds wave by wave: starting from the seeds, each wave
    collects the consumers of the outputs (and, optionally, the control
    outputs) of the previous wave.

    Args:
      seed_ops: an iterable of operations from which the forward graph walk
        starts. If a list of tensors is given instead, the seed_ops are set to
        be the consumers of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_outputs: a `util.ControlOutputs` instance or None. If not
        `None`, it will be used while walking the graph forward.
      within_ops_fn: optional predicate on an op. When given, an op discovered
        during the walk is only visited if the predicate returns a truthy
        value (in addition to the `within_ops` restriction). The seed ops
        themselves are not filtered by it.
    Returns:
      A list (without duplicates) of the `tf.Operation` visited by the walk.
    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    _, control_outputs = check_cios(False, control_outputs)
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    # The type of the first element decides whether seeds are tensors or ops.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    seed_ops = frozenset(seed_ops)
    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return ((within_ops is None or op in within_ops) and
                (within_ops_fn is None or within_ops_fn(op)))

    result = list(seed_ops)
    # Mirror of `result` as a set: membership tests in the inner loops stay
    # constant-time instead of scanning the growing result list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.outputs:
                if new_t in stop_at_ts:
                    continue
                for new_op in new_t.consumers():
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
            if control_outputs is not None:
                for new_op in control_outputs.get(op):
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         stop_at_ts=(), control_outputs=None,
                         within_ops_fn=None):
    """Do a forward graph walk and return all the visited ops.

    The walk proceeds wave by wave: starting from the seeds, each wave
    collects the consumers of the outputs (and, optionally, the control
    outputs) of the previous wave.

    Args:
      seed_ops: an iterable of operations from which the forward graph walk
        starts. If a list of tensors is given instead, the seed_ops are set to
        be the consumers of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_outputs: a `util.ControlOutputs` instance or None. If not
        `None`, it will be used while walking the graph forward.
      within_ops_fn: optional predicate on an op. When given, an op discovered
        during the walk is only visited if the predicate returns a truthy
        value (in addition to the `within_ops` restriction). The seed ops
        themselves are not filtered by it.
    Returns:
      A list (without duplicates) of the `tf.Operation` visited by the walk.
    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    _, control_outputs = check_cios(False, control_outputs)
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    # The type of the first element decides whether seeds are tensors or ops.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    seed_ops = frozenset(seed_ops)
    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return ((within_ops is None or op in within_ops) and
                (within_ops_fn is None or within_ops_fn(op)))

    result = list(seed_ops)
    # Mirror of `result` as a set: membership tests in the inner loops stay
    # constant-time instead of scanning the growing result list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.outputs:
                if new_t in stop_at_ts:
                    continue
                for new_op in new_t.consumers():
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
            if control_outputs is not None:
                for new_op in control_outputs.get(op):
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result