def connect(sgv0, sgv1, disconnect_first=False):
    """Connect the outputs of sgv0 to the inputs of sgv1.

    Args:
      sgv0: the subgraph whose outputs are used as the source of the
        connection. This argument is converted to a subgraph using the same
        rules as the function subgraph.make_view.
        Note that sgv0 is modified in place.
      sgv1: the subgraph whose inputs are rerouted. This argument is converted
        to a subgraph using the same rules as the function subgraph.make_view.
        Note that sgv1 is modified in place.
      disconnect_first: if True, detach_outputs is called on sgv0 before the
        connection is made, so its current outputs are disconnected first.
    Returns:
      A tuple `(sgv0, sgv1)` of the now connected subgraph views.
    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    src_view = subgraph.make_view(sgv0)
    dst_view = subgraph.make_view(sgv1)
    util.check_graphs(src_view, dst_view)

    if disconnect_first:
        detach_outputs(src_view)

    # Wrap the source outputs in a passthrough-only view so that they can be
    # handed to reroute_inputs as if they were the inputs of a subgraph.
    src_outputs_view = subgraph.SubGraphView(passthrough_ts=src_view.outputs)
    reroute.reroute_inputs(src_outputs_view, dst_view)
    return src_view, dst_view
def _reroute_sgv_inputs(sgv0, sgv1, mode):
    """Re-route all the inputs of two subgraphs.

    Args:
      sgv0: the first subgraph to have its inputs rerouted. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
      sgv1: the second subgraph to have its inputs rerouted. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
      mode: reroute mode, see _reroute_ts(...).
    Returns:
      A tuple `(sgv0, sgv1)` of subgraph views with their inputs rerouted.
        Note that the function arguments sgv0 and sgv1 are also modified in
        place.
    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    view0 = _subgraph.make_view(sgv0)
    view1 = _subgraph.make_view(sgv1)
    _util.check_graphs(view0, view1)

    # Ops that _reroute_ts is allowed to touch: the ops of both views ...
    modifiable_ops = view0.ops + view1.ops
    # ... plus the consumers of each view's passthrough tensors.
    for view in (view0, view1):
        modifiable_ops += _util.get_consuming_ops(view.passthroughs)

    _reroute_ts(view0.inputs, view1.inputs, mode, can_modify=modifiable_ops)
    _reroute_sgv_remap(view0, view1, mode)
    return view0, view1
def detach_inputs(sgv, control_inputs=False):
    """Detach the inputs of a subgraph view.

    One placeholder is created per input tensor of the view (inside the view's
    graph), and the view's inputs are then swapped with those placeholders.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
        Note that sgv is modified in place.
      control_inputs: if True control_inputs are also detached (by calling
        detach_control_inputs on the view).
    Returns:
      A tuple `(sgv, input_placeholders)` where
        `sgv` is the subgraph view of the detached subgraph;
        `input_placeholders` is a list of the created input placeholders, in
        the same order as the view's inputs.
    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    sgv = subgraph.make_view(sgv)

    # Build the placeholders with the same helper detach_outputs uses, so that
    # both detach paths create their placeholders identically.
    # NOTE(review): the helper presumably carries over the tensor's dtype (and
    # static shape) and derives the name via util.placeholder_name - confirm
    # against util.make_placeholder_from_tensor.
    with sgv.graph.as_default():
        input_placeholders = [
            util.make_placeholder_from_tensor(input_t) for input_t in sgv.inputs
        ]

    reroute.swap_inputs(sgv, input_placeholders)
    if control_inputs:
        detach_control_inputs(sgv)
    return sgv, input_placeholders
def detach_control_inputs(sgv):
    """Detach all the external control inputs of the subgraph sgv.

    For every op of the view, the control inputs that are not themselves ops
    of the view are collected and passed to reroute.remove_control_inputs.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)
    for op in view.ops:
        # Control inputs coming from outside of the view.
        external_cops = []
        for cop in op.control_inputs:
            if cop not in view.ops:
                external_cops.append(cop)
        reroute.remove_control_inputs(op, external_cops)
def copy_with_input_replacements(sgv, replacement_ts,
                                 dst_graph=None, dst_scope="", src_scope="",
                                 reuse_dst_scope=False):
    """Copy a subgraph, replacing some of its inputs.

    Note a replacement only happens if the tensor to be replaced
    is an input of the given subgraph. The inputs of a subgraph can
    be queried using sgv.inputs.

    Args:
      sgv: the source subgraph-view. This argument is converted to a subgraph
        using the same rules as the function subgraph.make_view.
      replacement_ts: dictionary mapping each original tensor to the tensor
        that replaces it.
      dst_graph: the destination graph. Defaults to the graph of sgv when None.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if dst_graph is not a tf.Graph.
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    src_view = subgraph.make_view(sgv)
    target_graph = src_view.graph if dst_graph is None else dst_graph
    if not isinstance(target_graph, tf_ops.Graph):
        raise TypeError(
            "Expected a tf.Graph, got: {}".format(type(target_graph)))

    def _replace_or_keep_handler(info, t):
        # External inputs without a replacement fall back to the
        # keep_t_if_possible_handler behavior.
        if t not in replacement_ts:
            return keep_t_if_possible_handler(info, t)
        return replacement_ts[t]

    transformer = Transformer()
    transformer.transform_external_input_handler = _replace_or_keep_handler
    return transformer(
        src_view, target_graph, dst_scope, src_scope,
        reuse_dst_scope=reuse_dst_scope)
def __call__(self, sgv, dst_graph, dst_scope, src_scope="",
             reuse_dst_scope=False):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view.
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative path of the transformed nodes is computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if dst_graph is not a tf.Graph.
    """
    src_view = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
        raise TypeError(
            "Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # A non-empty destination scope is made unique unless the caller asked
    # to reuse it.
    if dst_scope and not reuse_dst_scope:
        # The last character is dropped before asking the graph for a unique
        # name (presumably the trailing "/" added by scope_finalize - confirm).
        scope_base = dst_scope[:-1]
        dst_scope = util.scope_finalize(dst_graph.unique_name(scope_base))

    # Temporary state shared by the steps of this single transform call.
    tmp_info = _TmpInfo(src_view, dst_graph, dst_scope, src_scope)

    self._copy_ops(tmp_info)
    self._finalize_cycles(tmp_info)
    self._connect_control_inputs(tmp_info)

    # Summarize the transformation, then build the transformed view.
    result_info = TransformerInfo(tmp_info)
    transformed_view = self._transform_sgv(tmp_info, src_view)
    return transformed_view, result_info
def _reroute_sgv_outputs(sgv0, sgv1, mode):
    """Re-route all the outputs of two subgraphs.

    Args:
      sgv0: the first subgraph to have its outputs rerouted. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
      sgv1: the second subgraph to have its outputs rerouted. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
      mode: reroute mode, see _reroute_ts(...).
    Returns:
      A tuple `(sgv0, sgv1)` of subgraph views with their outputs rerouted.
        Note that the function arguments sgv0 and sgv1 are also modified in
        place.
    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    view0 = _subgraph.make_view(sgv0)
    view1 = _subgraph.make_view(sgv1)
    _util.check_graphs(view0, view1)

    # The ops of both views are passed as the set _reroute_ts must not modify.
    protected_ops = view0.ops + view1.ops
    _reroute_ts(
        view0.outputs, view1.outputs, mode, cannot_modify=protected_ops)
    return view0, view1
def detach_control_outputs(sgv, control_outputs):
    """Detach all the external control outputs of the subgraph sgv.

    For every op of the view, each op returned by `control_outputs.get(op)`
    that is not part of the view has `op` removed from its control inputs.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
      control_outputs: a util.ControlOutputs instance. Its `update()` method is
        called before it is queried.
    Raises:
      TypeError: if control_outputs is not a util.ControlOutputs instance.
    """
    if not isinstance(control_outputs, util.ControlOutputs):
        raise TypeError(
            "Expected a util.ControlOutputs, got: %s" % type(control_outputs))
    control_outputs.update()

    view = subgraph.make_view(sgv)
    for op in view.ops:
        for dependent_op in control_outputs.get(op):
            # Dependencies that stay inside the view are left untouched.
            if dependent_op in view.ops:
                continue
            reroute.remove_control_inputs(dependent_op, op)
def bypass(sgv):
    """Bypass the given subgraph by connecting its inputs to its outputs.

    Args:
      sgv: the subgraph view to be bypassed. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
        Note that sgv is modified in place.
    Returns:
      A tuple `(sgv, detached_inputs)` where:
        `sgv` is the subgraph view returned by detach_inputs for the bypassed
        subgraph;
        `detached_inputs` is a list of the created input placeholders.
    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
    view = subgraph.make_view(sgv)

    # Snapshot the inputs before detaching: they are what gets rerouted to the
    # outputs afterwards.
    original_inputs = list(view.inputs)

    detached_view, detached_inputs = detach_inputs(view)
    reroute.reroute_ts(original_inputs, detached_view.outputs)
    return detached_view, detached_inputs
def detach_outputs(sgv, control_outputs=None):
    """Detach the outputs of a subgraph view.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
        Note that sgv is modified in place.
      control_outputs: a util.ControlOutputs instance or None. If not None the
        control outputs are also detached.
    Returns:
      A tuple `(sgv, output_placeholders)` where
        `sgv` is a subgraph view of the detached subgraph, remapped so that it
        only exposes the outputs which had consumers;
        `output_placeholders` is a list of the created output placeholders.
    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)

    # Restrict the view to the outputs that actually have consumers.
    consumed_output_ids = [
        idx for idx, out_t in enumerate(view.outputs) if out_t.consumers()
    ]
    detached_view = view.remap_outputs(consumed_output_ids)

    # Build a view over the consumers and keep only those of its inputs that
    # are outputs of the detached view.
    consumers_view = subgraph.SubGraphView(detached_view.consumers())
    fed_input_ids = [
        idx for idx, in_t in enumerate(consumers_view.inputs)
        if in_t in detached_view.outputs
    ]
    consumers_view = consumers_view.remap_inputs(fed_input_ids)

    # One placeholder per remaining consumer input, created in the view's
    # graph.
    with detached_view.graph.as_default():
        output_placeholders = []
        for consumer_input_t in consumers_view.inputs:
            output_placeholders.append(
                util.make_placeholder_from_tensor(consumer_input_t))

    reroute.swap_outputs(detached_view, output_placeholders)
    if control_outputs is not None:
        detach_control_outputs(detached_view, control_outputs)
    return detached_view, output_placeholders
def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
         reuse_dst_scope=False):
    """Copy a subgraph.

    Args:
      sgv: the source subgraph-view. This argument is converted to a subgraph
        using the same rules as the function subgraph.make_view.
      dst_graph: the destination graph. Defaults to the graph of sgv when None.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    src_view = subgraph.make_view(sgv)
    target_graph = src_view.graph if dst_graph is None else dst_graph
    if not isinstance(target_graph, tf_ops.Graph):
        raise TypeError(
            "Expected a tf.Graph, got: {}".format(type(target_graph)))

    # A default Transformer performs a plain copy.
    transformer = Transformer()
    return transformer(
        src_view, target_graph, dst_scope, src_scope,
        reuse_dst_scope=reuse_dst_scope)