def _reroute_sgv_inputs(sgv0, sgv1, mode):
    """Re-route the inputs of two subgraph views.

    Args:
      sgv0: the first subgraph whose inputs are re-routed. It is converted to
        a subgraph view with `_subgraph.make_view`.
      sgv1: the second subgraph whose inputs are re-routed. It is converted to
        a subgraph view with `_subgraph.make_view`.
      mode: reroute mode, forwarded unchanged to `_reroute_ts(...)` and
        `_reroute_sgv_remap(...)`.

    Returns:
      A tuple `(sgv0, sgv1)` of the subgraph views after the re-routing.
      Note that the views are also modified in place.

    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
        using the same rules as the function subgraph.make_view.
    """
    view0 = _subgraph.make_view(sgv0)
    view1 = _subgraph.make_view(sgv1)
    _util.check_graphs(view0, view1)

    # Ops belonging to either view may be modified ...
    modifiable_ops = view0.ops + view1.ops
    # ... and so may the ops consuming the passthrough tensors of either view.
    for view in (view0, view1):
        modifiable_ops += _util.get_consuming_ops(view.passthroughs)

    _reroute_ts(view0.inputs, view1.inputs, mode, can_modify=modifiable_ops)
    _reroute_sgv_remap(view0, view1, mode)
    return view0, view1
def connect(sgv0, sgv1, disconnect_first=False):
    """Connect the outputs of sgv0 to the inputs of sgv1.

    Args:
      sgv0: the subgraph providing the outputs. It is converted to a subgraph
        view with `subgraph.make_view`. Note that sgv0 is modified in place.
      sgv1: the subgraph whose inputs receive those outputs. It is converted
        to a subgraph view with `subgraph.make_view`. Note that sgv1 is
        modified in place.
      disconnect_first: if True, `detach_outputs` is applied to sgv0 before
        the connection is made.

    Returns:
      A tuple `(sgv0, sgv1)` of the now connected subgraph views.

    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
        using the same rules as the function subgraph.make_view.
    """
    src_view = subgraph.make_view(sgv0)
    dst_view = subgraph.make_view(sgv1)
    util.check_graphs(src_view, dst_view)

    if disconnect_first:
        detach_outputs(src_view)

    # Expose the source outputs as a passthrough-only view so they can be
    # re-routed into the destination inputs.
    outputs_view = subgraph.SubGraphView(passthrough_ts=src_view.outputs)
    reroute.reroute_a2b_inputs(outputs_view, dst_view)
    return src_view, dst_view
def __call__(self, sgv, dst_graph, dst_scope, src_scope=""):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view (converted with `subgraph.make_view`).
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative path of the transformed nodes is computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.

    Returns:
      The transformed subgraph view.

    Raises:
      Nothing is raised explicitly here; errors from `subgraph.make_view`,
      `util.scope_finalize` or the transform helpers propagate unchanged.
      NOTE(review): an identical source and destination (same graph, empty
      destination scope) only triggers a warning, not a ValueError.
    """
    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    sgv = subgraph.make_view(sgv)

    if sgv.graph is dst_graph and not dst_scope:
        logging.warning("The source and the destination are the same! "
                        "Beware: in-place transormation are currently "
                        "experimental.")

    # Per-call state; it must not outlive this call, even on failure.
    self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope)
    try:
        # This is a recursive call. In practice, the graph is transformed
        # bottom-up.
        for output_t in self._info.sgv.outputs:
            self._transform_t(output_t)
        return self._transform_sgv(sgv)
    finally:
        # Always drop the temporary info so a failed transform does not leave
        # stale state on the transformer.
        self._info = None
def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
         reuse_dst_scope=False):
    """Copy a subgraph.

    Args:
      sgv: the source subgraph-view. It is converted to a subgraph view with
        `subgraph.make_view`.
      dst_graph: the destination graph. When None, the graph of the source
        view is used.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).

    Returns:
      The result of calling a fresh `Transformer`: a tuple `(sgv, info)`
      where `sgv` is the transformed subgraph view and `info` is an instance
      of TransformerInfo containing information about the transform,
      including mapping between original and transformed tensors and
      operations.

    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)
    target_graph = view.graph if dst_graph is None else dst_graph
    if not isinstance(target_graph, tf_ops.Graph):
        raise TypeError(
            "Expected a tf.Graph, got: {}".format(type(target_graph)))

    transformer = Transformer()
    return transformer(
        view, target_graph, dst_scope, src_scope,
        reuse_dst_scope=reuse_dst_scope)
def detach_outputs(sgv):
    """Detach the outputs of a subgraph view.

    Args:
      sgv: the subgraph view to be detached. It is converted to a subgraph
        view with `subgraph.make_view`.

    Returns:
      The value returned by `swap_outputs` for the remapped view and the
      newly created placeholders. Note that sgv is also modified in place.

    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)

    # Keep only the outputs that actually have consumers.
    consumed_ids = [
        idx for idx, tensor in enumerate(view.outputs) if tensor.consumers()
    ]
    consumed_view = view.remap_outputs(consumed_ids)

    # Build the view of the consumers and keep only the inputs that are fed
    # by the outputs selected above.
    consumers_view = subgraph.SubGraphView(consumed_view.consumers())
    fed_ids = [
        idx for idx, tensor in enumerate(consumers_view.inputs)
        if tensor in consumed_view.outputs
    ]
    consumers_view = consumers_view.remap_inputs(fed_ids)

    with consumed_view.graph.as_default():
        placeholders = []
        for tensor in consumers_view.inputs:
            placeholders.append(util.make_placeholder_from_tensor(tensor))

    return swap_outputs(consumed_view, placeholders)
def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
         reuse_dst_scope=False):
    """Copy a subgraph.

    Args:
      sgv: the source subgraph-view. It is converted to a subgraph view with
        `subgraph.make_view`.
      dst_graph: the destination graph; defaults to the graph of the source
        view when None.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by postfixing an underscore followed by a digit (default).

    Returns:
      Whatever the `Transformer` call returns for the copied subgraph
      (documented upstream as the subgraph view of the copied subgraph).

    Raises:
      TypeError: if dst_graph is not a tf.Graph.
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    source_view = subgraph.make_view(sgv)
    if dst_graph is None:
        dst_graph = source_view.graph

    if isinstance(dst_graph, tf_ops.Graph):
        copier = Transformer()
        return copier(
            source_view, dst_graph, dst_scope, src_scope,
            reuse_dst_scope=reuse_dst_scope)

    raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))
def detach_inputs(sgv, control_inputs=False):
    """Detach the inputs of a subgraph view.

    Every input tensor of the view is replaced by a freshly created
    placeholder of the same dtype.

    Args:
      sgv: the subgraph view to be detached. It is converted to a subgraph
        view with `subgraph.make_view`. Note that sgv is modified in place.
      control_inputs: if True, `detach_control_inputs` is applied as well.

    Returns:
      A tuple `(sgv, input_placeholders)` where `sgv` is the subgraph view of
      the detached subgraph and `input_placeholders` is the list of the
      created input placeholders.

    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)

    placeholders = []
    with view.graph.as_default():
        for tensor in view.inputs:
            placeholders.append(
                tf_array_ops.placeholder(
                    dtype=tensor.dtype, name=util.placeholder_name(tensor)))

    reroute.swap_inputs(view, placeholders)
    if control_inputs:
        detach_control_inputs(view)
    return view, placeholders
def __call__(self, sgv, dst_graph, dst_scope, src_scope="",
             reuse_dst_scope=False):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view (converted with `subgraph.make_view`).
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative path of the transformed nodes is computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).

    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of Transformer.ResultInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.

    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
    """
    sgv = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # Potentially create new scope if reuse_dst_scope is False
    if dst_scope and not reuse_dst_scope:
        dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))

    # Create temporary info used during this transform call
    self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope)
    try:
        # Transform the graph starting from the output tensors.
        for output_t in self._info.sgv.outputs:
            self._transform_t(output_t)

        # Some ops might have been missed by the previous walk, namely, the
        # roots without any outputs. So the walk is now finalized from those
        # roots. The list is built before any root is transformed.
        remaining_roots = [
            op for op in self._info.sgv.ops
            if op not in self._info.transformed_ops and not op.outputs
        ]
        for op in remaining_roots:
            self._transform_op(op)

        sgv_ = self._transform_sgv(sgv)
        res_info = Transformer.ResultInfo(self._info)
        return sgv_, res_info
    finally:
        # Always drop the temporary info so a failed transform does not leave
        # stale state on the transformer.
        self._info = None
def detach_control_inputs(sgv):
    """Detach all the external control inputs of the subgraph sgv.

    For each op of the view, the control inputs that are not themselves ops
    of the view are removed via `reroute.remove_control_inputs`.

    Args:
      sgv: the subgraph view to be detached. It is converted to a subgraph
        view with `subgraph.make_view`.
    """
    view = subgraph.make_view(sgv)
    for op in view.ops:
        external_cops = []
        for cop in op.control_inputs:
            if cop not in view.ops:
                external_cops.append(cop)
        reroute.remove_control_inputs(op, external_cops)
def __call__(self, sgv, dst_graph, dst_scope, src_scope="",
             reuse_dst_scope=False):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view (converted with `subgraph.make_view`).
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative path of the transformed nodes is computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by postfixing an underscore followed by a digit (default).

    Returns:
      A tuple `(sgv, ops_mapping)` where:
        `sgv` is the transformed subgraph view;
        `ops_mapping` is the value of `self._info.create_ops_mapping()`,
        documented as a dictionary mapping the name of the original ops to
        the name of the transformed ops.

    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
    """
    sgv = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # Potentially create new scope if reuse_dst_scope is False
    if dst_scope and not reuse_dst_scope:
        dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))

    # Same graph and no destination scope means an in-place transformation.
    if sgv.graph is dst_graph and not dst_scope:
        logging.warning("The source and the destination are the same! "
                        "Beware: in-place transormation are currently "
                        "experimental.")

    # Per-call state; it must not outlive this call, even on failure.
    self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope)
    try:
        # This is a recursive call. In practice, the graph is transformed
        # bottom-up.
        for output_t in self._info.sgv.outputs:
            self._transform_t(output_t)

        sgv_ = self._transform_sgv(sgv)
        ops_mapping = self._info.create_ops_mapping()
        return sgv_, ops_mapping
    finally:
        # Always drop the temporary info so a failed transform does not leave
        # stale state on the transformer.
        self._info = None
def __call__(self, sgv, dst_graph, dst_scope, src_scope="",
             reuse_dst_scope=False):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view (converted with `subgraph.make_view`).
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative path of the transformed nodes is computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).

    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.

    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
    """
    source_view = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # A non-empty destination scope is made unique unless reuse is requested.
    if dst_scope and not reuse_dst_scope:
        unique_scope = dst_graph.unique_name(dst_scope[:-1])
        dst_scope = util.scope_finalize(unique_scope)

    # Temporary state shared by the steps of this transform call.
    tmp_info = _TmpInfo(source_view, dst_graph, dst_scope, src_scope)

    self._copy_ops(tmp_info)
    self._finalize_cycles(tmp_info)
    self._connect_control_inputs(tmp_info)

    # The result info is built before the view itself is transformed.
    result_info = TransformerInfo(tmp_info)
    transformed_view = self._transform_sgv(tmp_info, source_view)
    return transformed_view, result_info
def _reroute_sgv_outputs(sgv0, sgv1, mode):
    """Re-route the outputs of two subgraph views.

    Args:
      sgv0: the first subgraph whose outputs are re-routed. It is converted
        to a subgraph view with `_subgraph.make_view`.
      sgv1: the second subgraph whose outputs are re-routed. It is converted
        to a subgraph view with `_subgraph.make_view`.
      mode: reroute mode, forwarded unchanged to `_reroute_ts(...)`.

    Returns:
      A tuple `(sgv0, sgv1)` of the subgraph views after the re-routing.
      Note that the views are also modified in place.

    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
        using the same rules as the function subgraph.make_view.
    """
    view0 = _subgraph.make_view(sgv0)
    view1 = _subgraph.make_view(sgv1)
    _util.check_graphs(view0, view1)

    # The ops of both views are passed as ops that must not be modified.
    protected_ops = view0.ops + view1.ops
    _reroute_ts(
        view0.outputs, view1.outputs, mode, cannot_modify=protected_ops)
    return view0, view1
def copy_with_input_replacements(sgv, replacement_ts, dst_graph=None,
                                 dst_scope="", src_scope="",
                                 reuse_dst_scope=False):
    """Copy a subgraph, replacing some of its inputs.

    Note a replacement only happens if the tensor to be replaced
    is an input of the given subgraph. The inputs of a subgraph can
    be queried using sgv.inputs.

    Args:
      sgv: the source subgraph-view. It is converted to a subgraph view with
        `subgraph.make_view`.
      replacement_ts: dictionary mapping from original tensors to the
        replacing ones.
      dst_graph: the destination graph; defaults to the graph of the source
        view when None.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).

    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.

    Raises:
      TypeError: if dst_graph is not a tf.Graph.
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    source_view = subgraph.make_view(sgv)
    target_graph = source_view.graph if dst_graph is None else dst_graph
    if not isinstance(target_graph, tf_ops.Graph):
        raise TypeError(
            "Expected a tf.Graph, got: {}".format(type(target_graph)))

    def replace_t_with_replacement_handler(info, t):
        # Tensors without a registered replacement use the default handler.
        if t not in replacement_ts:
            return keep_t_if_possible_handler(info, t)
        return replacement_ts[t]

    copier = Transformer()
    copier.transform_external_input_handler = replace_t_with_replacement_handler
    return copier(
        source_view, target_graph, dst_scope, src_scope,
        reuse_dst_scope=reuse_dst_scope)
def detach_control_outputs(sgv, control_outputs):
    """Detach all the external control outputs of the subgraph sgv.

    For every op of the view, each op returned by `control_outputs.get(op)`
    that lies outside the view has its control dependency on `op` removed
    via `reroute.remove_control_inputs`.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
      control_outputs: a util.ControlOutputs instance. It is refreshed with
        `update()` before being queried.

    Raises:
      TypeError: if `control_outputs` is not a util.ControlOutputs instance.
    """
    if not isinstance(control_outputs, util.ControlOutputs):
        # The type must be formatted into the message; passing it as a second
        # argument to TypeError would leave the "{}" placeholder unfilled.
        raise TypeError("Expected a util.ControlOutputs, got: {}".format(
            type(control_outputs)))
    control_outputs.update()
    sgv = subgraph.make_view(sgv)
    for op in sgv.ops:
        for cop in control_outputs.get(op):
            if cop not in sgv.ops:
                reroute.remove_control_inputs(cop, op)
def bypass(sgv):
    """Bypass the given subgraph by connecting its inputs to its outputs.

    Args:
      sgv: the subgraph view to be bypassed. It is converted to a subgraph
        view with `subgraph.make_view`.

    Returns:
      A tuple `(sgv, detached_inputs)`: the subgraph view of the bypassed
      subgraph and the second value returned by `detach_inputs`. Note that
      sgv is also modified in place.

    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
    view = subgraph.make_view(sgv)

    # Snapshot the inputs before detaching replaces them.
    original_inputs = list(view.inputs)
    view, placeholders = detach_inputs(view)

    reroute.reroute_a2b_ts(original_inputs, view.outputs)
    return view, placeholders
def bypass(sgv):
    """Short-circuit a subgraph: route its former inputs to its outputs.

    Args:
      sgv: the subgraph view to be bypassed. It is converted to a subgraph
        view with `subgraph.make_view`.

    Returns:
      A tuple `(sgv, detached_inputs)` made of the view of the bypassed
      subgraph and the second value returned by `detach_inputs`. Note that
      sgv is also modified in place.

    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
    bypassed_view = subgraph.make_view(sgv)

    # The input tensors are copied first because detaching swaps them out.
    upstream_ts = list(bypassed_view.inputs)

    bypassed_view, detached_inputs = detach_inputs(bypassed_view)
    downstream_ts = bypassed_view.outputs
    reroute.reroute_a2b_ts(upstream_ts, downstream_ts)

    return bypassed_view, detached_inputs
def remove(sgv, reconnect_after=False):
    """Remove sgv and optionally reconnect its inputs and outputs.

    Args:
      sgv: the subgraph view to be removed. It is converted to a subgraph
        view with `subgraph.make_view`.
      reconnect_after: if False (default), the inputs and outputs of sgv are
        not reconnected after the removal; if True, `connect` is called on
        the detached inputs and outputs.

    Returns:
      The subgraph view of the removed subgraph, as returned by `detach`.
      Note that sgv is also modified in place.

    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)
    util.check_ts_compatibility(view.inputs, view.outputs)

    view, detached_inputs, detached_outputs = detach(view)
    if not reconnect_after:
        return view

    connect(detached_inputs, detached_outputs)
    return view
def detach_outputs(sgv, control_outputs=None):
    """Detach the outputs of a subgraph view.

    Args:
      sgv: the subgraph view to be detached. It is converted to a subgraph
        view with `subgraph.make_view`. Note that sgv is modified in place.
      control_outputs: a util.ControlOutputs instance or None. If not None
        the control outputs are also detached.

    Returns:
      A tuple `(sgv, output_placeholders)` where
        `sgv` is the subgraph view of the detached subgraph (restricted to
        the outputs that had consumers);
        `output_placeholders` is a list of the created output placeholders.

    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)

    # Keep only the outputs that actually have consumers.
    consumed_ids = [
        idx for idx, tensor in enumerate(view.outputs) if tensor.consumers()
    ]
    consumed_view = view.remap_outputs(consumed_ids)

    # Build the view of the consumers and keep only the inputs that are fed
    # by the outputs selected above.
    consumers_view = subgraph.SubGraphView(consumed_view.consumers())
    fed_ids = [
        idx for idx, tensor in enumerate(consumers_view.inputs)
        if tensor in consumed_view.outputs
    ]
    consumers_view = consumers_view.remap_inputs(fed_ids)

    with consumed_view.graph.as_default():
        placeholders = []
        for tensor in consumers_view.inputs:
            placeholders.append(util.make_placeholder_from_tensor(tensor))

    reroute.swap_outputs(consumed_view, placeholders)
    if control_outputs is not None:
        detach_control_outputs(consumed_view, control_outputs)
    return consumed_view, placeholders
def detach_inputs(sgv):
    """Detach the inputs of a subgraph view.

    Every input tensor of the view is swapped for a freshly created
    placeholder of the same dtype.

    Args:
      sgv: the subgraph view to be detached. It is converted to a subgraph
        view with `subgraph.make_view`.

    Returns:
      The value returned by `swap_inputs` for the view and the created
      placeholders. Note that sgv is also modified in place.

    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)

    placeholders = []
    with view.graph.as_default():
        for tensor in view.inputs:
            placeholders.append(
                tf_array_ops.placeholder(
                    dtype=tensor.dtype, name=util.placeholder_name(tensor)))

    return swap_inputs(view, placeholders)
def copy(sgv, dst_graph=None, dst_scope="", src_scope=""):
    """Copy a subgraph.

    Args:
      sgv: the source subgraph-view. It is converted to a subgraph view with
        `subgraph.make_view`.
      dst_graph: the destination graph; defaults to the graph of the source
        view when None.
      dst_scope: the destination scope.
      src_scope: the source scope.

    Returns:
      Whatever the `Transformer` call returns for the copied subgraph
      (documented upstream as the subgraph view of the copied subgraph).

    Raises:
      TypeError: if dst_graph is not a tf.Graph.
      StandardError: if sgv cannot be converted to a SubGraphView using the
        same rules as the function subgraph.make_view.
    """
    source_view = subgraph.make_view(sgv)
    target_graph = source_view.graph if dst_graph is None else dst_graph

    if isinstance(target_graph, tf_ops.Graph):
        transformer = Transformer()
        return transformer(source_view, target_graph, dst_scope, src_scope)

    raise TypeError("Expected a tf.Graph, got: {}".format(type(target_graph)))
def replace_grad_to_guided_grad(g):
    """Apply `_replace_grad` to every op of the view built from `g`.

    The calls are made inside `g.gradient_override_map(_grad_override_map)`.

    Args:
      g: object passed to `subgraph.make_view` and used as the owner of the
        gradient override map -- presumably a tf.Graph; TODO confirm.
    """
    view = subgraph.make_view(g)
    override_scope = g.gradient_override_map(_grad_override_map)
    with override_scope:
        for current_op in view.ops:
            _replace_grad(g, current_op)