def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         control_outputs=True):
    """Do a forward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the forward graph walk
        starts. If a list of tensors is given instead, the seed_ops are set
        to be the consumers of those tensors. Any iterable is accepted
        (list, tuple, set, generator); it is materialized once up front.
      inclusive: if True the given seed_ops are also part of the resulting set.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within the
        whole graph. An empty iterable restricts the search to nothing, so the
        result is then empty.
      control_outputs: an object convertible to a control output dictionary
        (see function util.convert_to_control_outputs for more details).
        If the dictionary can be created, it will be used while walking the
        graph forward.
    Returns:
      A Python set of all the tf.Operation ahead of seed_ops.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    else:
        # Materialize so that iterables which are not indexable (sets,
        # generators) can be inspected with seed_ops[0] below.
        seed_ops = list(seed_ops)
    if not seed_ops:
        return set()

    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    control_outputs = util.convert_to_control_outputs(seed_ops, control_outputs)

    seed_ops = frozenset(seed_ops)
    # Test against None (not truthiness) so this agrees with is_within():
    # an empty within_ops must restrict the seeds as well as the expansion.
    if within_ops is not None:
        within_ops = frozenset(
            util.make_list_of_op(within_ops, allow_graph=False))
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            # Data edges: every consumer of every output tensor.
            for new_t in op.outputs:
                for new_op in new_t.consumers():
                    if new_op not in result and is_within(new_op):
                        new_wave.add(new_op)
            # Control edges, when a control-output mapping is available.
            if control_outputs is not None and op in control_outputs:
                for new_op in control_outputs[op]:
                    if new_op not in result and is_within(new_op):
                        new_wave.add(new_op)
        result.update(new_wave)
        wave = new_wave

    if not inclusive:
        result.difference_update(seed_ops)
    return result
def get_within_boundary_ops(ops, seed_ops, boundary_ops, inclusive=True,
                            control_outputs=True):
    """Return all the tf.Operation within the given boundary.

    Starting from seed_ops, the walk repeatedly expands to every op connected
    to the current wave (inputs, consumers and, if available, control
    outputs; see get_ops_ios). Only ops belonging to `ops` are considered.
    Expansion never continues through a boundary op.

    Args:
      ops: an object convertible to a list of tf.Operation. Those ops define
        the set in which to perform the operation (if a tf.Graph is given, it
        will be converted to the list of all its operations). Ops outside this
        set are neither added to the result nor expanded.
      seed_ops: the operations from which to start expanding.
      boundary_ops: the ops forming the boundary.
      inclusive: if True, the result will also include the boundary ops that
        were reached.
      control_outputs: an object convertible to a control output dictionary
        (or None). If the dictionary can be created, it will be used while
        expanding.
    Returns:
      A Python set of the tf.Operation reachable from seed_ops without
      crossing boundary_ops. The seed ops themselves are always included.
    Raises:
      TypeError: if ops, seed_ops or boundary_ops cannot be converted to a
        list of tf.Operation.
      ValueError: if the boundary is intersecting with the seeds.
    """
    ops = util.make_list_of_op(ops)
    control_outputs = util.convert_to_control_outputs(ops, control_outputs)
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
    boundary_ops = set(util.make_list_of_op(boundary_ops))
    # Constant-time membership test for the search space defined by `ops`.
    search_space = frozenset(ops)

    res = set(seed_ops)
    if boundary_ops & res:
        raise ValueError("Boundary is intersecting with the seeds.")

    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in get_ops_ios(wave, control_outputs):
            if op in res or op not in search_space:
                continue
            if op in boundary_ops:
                # Boundary ops stop the expansion; optionally report them.
                if inclusive:
                    res.add(op)
            else:
                new_wave.add(op)
        res.update(new_wave)
        wave = new_wave
    return res
def get_ops_ios(ops, control_outputs=True):
    """Return all the tf.Operation which are connected to an op in ops.

    Args:
      ops: an object convertible to a list of tf.Operation.
      control_outputs: an object convertible to a control output dictionary
        (or None). If the dictionary can be created, it will be used to
        determine the surrounding ops (in addition to the regular inputs and
        outputs).
    Returns:
      A Python set of all the tf.Operation surrounding the given ops: the
      producers of their input tensors, the consumers of their output tensors
      and, when available, their control outputs.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    # Normalize `ops` before building the control-output mapping, matching
    # the order used by the other graph walkers. This also prevents a one-shot
    # iterable from being consumed by the converter before it is listed.
    ops = util.make_list_of_op(ops)
    control_outputs = util.convert_to_control_outputs(ops, control_outputs)
    res = set()
    for op in ops:
        res.update(t.op for t in op.inputs)
        for t in op.outputs:
            res.update(t.consumers())
        if control_outputs is not None and op in control_outputs:
            res.update(control_outputs[op])
    return res