Beispiel #1
0
def connect(sgv0, sgv1, disconnect_first=False):
    """Connect the outputs of sgv0 to the inputs of sgv1.

    The outputs of sgv0 are wrapped in a SubGraphView made only of passthrough
    tensors, and the inputs of sgv1 are rerouted onto them with
    `reroute.reroute_inputs`.

    Args:
      sgv0: the subgraph whose outputs are connected to sgv1. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
        Note that sgv0 is modified in place.
      sgv1: the subgraph whose inputs get connected to the outputs of sgv0.
        This argument is converted to a subgraph using the same rules as the
        function subgraph.make_view.
        Note that sgv1 is modified in place.
      disconnect_first: if True, the current outputs of sgv0 are first
        detached (with detach_outputs) before the new connection is made.
    Returns:
      A tuple `(sgv0, sgv1)` of the now connected subgraphs.
    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
        using the same rules as the function subgraph.make_view.
    """
    sgv0 = subgraph.make_view(sgv0)
    sgv1 = subgraph.make_view(sgv1)
    # Both views must live in the same graph before any rewiring happens.
    util.check_graphs(sgv0, sgv1)
    if disconnect_first:
        detach_outputs(sgv0)
    # A view with no ops whose outputs are exactly sgv0's output tensors; it is
    # used as the "source" side of the input reroute below.
    sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs)
    reroute.reroute_inputs(sgv0_outputs, sgv1)
    return sgv0, sgv1
Beispiel #2
0
def _reroute_sgv_inputs(sgv0, sgv1, mode):
    """Re-route all the inputs of two subgraphs.

    Args:
      sgv0: the first subgraph whose inputs are rerouted. It is converted to a
        subgraph using the same rules as the function subgraph.make_view.
      sgv1: the second subgraph whose inputs are rerouted. It is converted to
        a subgraph using the same rules as the function subgraph.make_view.
      mode: reroute mode, see _reroute_ts(...).
    Returns:
      A tuple `(sgv0, sgv1)` of subgraph views with their inputs rerouted.
        Note that the function arguments sgv0 and sgv1 are also modified in
        place.
    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
        using the same rules as the function subgraph.make_view.
    """
    view0 = _subgraph.make_view(sgv0)
    view1 = _subgraph.make_view(sgv1)
    _util.check_graphs(view0, view1)

    # Ops that the reroute is allowed to touch: the ops of both subgraphs plus
    # the ops consuming either subgraph's passthrough tensors.
    can_modify = view0.ops + view1.ops
    for view in (view0, view1):
        can_modify += _util.get_consuming_ops(view.passthroughs)

    _reroute_ts(view0.inputs, view1.inputs, mode, can_modify=can_modify)
    _reroute_sgv_remap(view0, view1, mode)
    return view0, view1
Beispiel #3
0
def detach_inputs(sgv, control_inputs=False):
    """Detach the inputs of a subgraph view.

    Each input tensor of the subgraph is replaced by a freshly created
    placeholder of the same dtype, created in the subgraph's graph.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
        Note that sgv is modified in place.
      control_inputs: if True, the external control inputs are detached too.
    Returns:
      A tuple `(sgv, input_placeholders)` where
        `sgv` is a new subgraph view of the detached subgraph;
        `input_placeholders` is a list of the created input placeholders.
    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    sgv = subgraph.make_view(sgv)

    input_placeholders = []
    with sgv.graph.as_default():
        for tensor in sgv.inputs:
            placeholder = tf_array_ops.placeholder(
                dtype=tensor.dtype, name=util.placeholder_name(tensor))
            input_placeholders.append(placeholder)

    reroute.swap_inputs(sgv, input_placeholders)
    if control_inputs:
        detach_control_inputs(sgv)
    return sgv, input_placeholders
Beispiel #4
0
def copy_with_input_replacements(sgv,
                                 replacement_ts,
                                 dst_graph=None,
                                 dst_scope="",
                                 src_scope="",
                                 reuse_dst_scope=False):
    """Copy a subgraph, replacing some of its inputs.

    Note a replacement only happens if the tensor to be replaced
    is an input of the given subgraph. The inputs of a subgraph can
    be queried using sgv.inputs.

    Args:
      sgv: the source subgraph-view. This argument is converted to a subgraph
        using the same rules as the function subgraph.make_view.
      replacement_ts: dictionary mapping from original tensors to their
        replacements.
      dst_graph: the destination graph (defaults to the graph of sgv).
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if dst_graph is not a tf.Graph.
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    sgv = subgraph.make_view(sgv)
    dst_graph = sgv.graph if dst_graph is None else dst_graph
    if not isinstance(dst_graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    def _replace_or_keep(info, t):
        # Use the caller-supplied replacement when there is one; otherwise
        # fall back to the default keep-if-possible policy.
        if t not in replacement_ts:
            return keep_t_if_possible_handler(info, t)
        return replacement_ts[t]

    copier = Transformer()
    copier.transform_external_input_handler = _replace_or_keep
    return copier(sgv, dst_graph, dst_scope, src_scope,
                  reuse_dst_scope=reuse_dst_scope)
Beispiel #5
0
def detach_control_inputs(sgv):
    """Detach all the external control inputs of the subgraph sgv.

    For every op of the subgraph, the control inputs that come from ops
    outside the subgraph are removed with reroute.remove_control_inputs.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
    """
    sgv = subgraph.make_view(sgv)
    # Build the membership set once instead of linearly scanning sgv.ops for
    # every control input of every op.
    internal_ops = frozenset(sgv.ops)
    for op in sgv.ops:
        external_cops = [
            cop for cop in op.control_inputs if cop not in internal_ops
        ]
        reroute.remove_control_inputs(op, external_cops)
Beispiel #6
0
    def __call__(self,
                 sgv,
                 dst_graph,
                 dst_scope,
                 src_scope="",
                 reuse_dst_scope=False):
        """Execute the transformation.

        Args:
          sgv: the source subgraph-view. It is converted to a subgraph using
            the same rules as the function subgraph.make_view.
          dst_graph: the destination graph.
          dst_scope: the destination scope.
          src_scope: the source scope, which specifies the path from which the
            relative path of the transformed nodes is computed. For instance,
            if src_scope is a/ and dst_scope is b/, then the node a/x/y will
            have a relative path of x/y and will be transformed into b/x/y.
          reuse_dst_scope: if True the dst_scope is re-used if it already
            exists. Otherwise, the scope is given a unique name based on the
            one given by appending an underscore followed by a digit
            (default).
        Returns:
          A tuple `(sgv, info)` where:
            `sgv` is the transformed subgraph view;
            `info` is an instance of TransformerInfo containing
            information about the transform, including mapping between
            original and transformed tensors and operations.
        Raises:
          TypeError: if dst_graph is not a tf.Graph.
          StandardError: if sgv cannot be converted to a SubGraphView using
            the same rules as the function subgraph.make_view.
        """
        sgv = subgraph.make_view(sgv)
        if not isinstance(dst_graph, tf_ops.Graph):
            raise TypeError("Expected a tf.Graph, got: {}".format(
                type(dst_graph)))

        src_scope = util.scope_finalize(src_scope)
        dst_scope = util.scope_finalize(dst_scope)

        # Potentially create a new scope if reuse_dst_scope is False. The
        # finalized scope's last character is stripped before asking the
        # graph for a unique name, then the result is finalized again.
        if dst_scope and not reuse_dst_scope:
            dst_scope = util.scope_finalize(
                dst_graph.unique_name(dst_scope[:-1]))

        # Create temporary info used during this transform call.
        info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)

        self._copy_ops(info)
        self._finalize_cycles(info)
        self._connect_control_inputs(info)

        # Compute information about the transformation.
        res_info = TransformerInfo(info)
        sgv_ = self._transform_sgv(info, sgv)
        return sgv_, res_info
Beispiel #7
0
def _reroute_sgv_outputs(sgv0, sgv1, mode):
    """Re-route all the outputs of two subgraphs.

    Args:
      sgv0: the first subgraph to have its outputs rerouted. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view.
      sgv1: the second subgraph to have its outputs rerouted. This argument
        is converted to a subgraph using the same rules as the function
        subgraph.make_view.
      mode: reroute mode, see _reroute_ts(...).
    Returns:
      A tuple `(sgv0, sgv1)` of subgraph views with their outputs rerouted.
        Note that the function arguments sgv0 and sgv1 are also modified in
        place.
    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
        using the same rules as the function subgraph.make_view.
    """
    sgv0 = _subgraph.make_view(sgv0)
    sgv1 = _subgraph.make_view(sgv1)
    _util.check_graphs(sgv0, sgv1)
    # The ops belonging to either subgraph are excluded from modification;
    # only consumers outside them are rewired by _reroute_ts.
    cannot_modify = sgv0.ops + sgv1.ops
    _reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify)
    return sgv0, sgv1
Beispiel #8
0
def detach_control_outputs(sgv, control_outputs):
    """Detach all the external control outputs of the subgraph sgv.

    For every op of the subgraph, each op outside the subgraph that has it as
    a control input gets that control input removed.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
      control_outputs: a util.ControlOutputs instance. Its update() method is
        called before it is queried.
    Raises:
      TypeError: if control_outputs is not a util.ControlOutputs instance.
    """
    if not isinstance(control_outputs, util.ControlOutputs):
        # The message is now actually formatted; previously the type was
        # passed as a second exception argument and "{}" stayed literal.
        raise TypeError("Expected a util.ControlOutputs, got: {}".format(
            type(control_outputs)))
    control_outputs.update()
    sgv = subgraph.make_view(sgv)
    # Build the membership set once rather than scanning sgv.ops per lookup.
    internal_ops = frozenset(sgv.ops)
    for op in sgv.ops:
        for cop in control_outputs.get(op):
            if cop not in internal_ops:
                reroute.remove_control_inputs(cop, op)
Beispiel #9
0
def bypass(sgv):
    """Bypass the given subgraph by connecting its inputs to its outputs.

    The subgraph's inputs are first detached (replaced by placeholders), then
    the original input tensors are rerouted to feed the consumers of the
    subgraph's outputs.

    Args:
      sgv: the subgraph view to be bypassed. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
        Note that sgv is modified in place.
    Returns:
      A tuple `(sgv, detached_inputs)` where:
        `sgv` is a new subgraph view of the bypassed subgraph;
        `detached_inputs` is a list of the created input placeholders.
    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
    view = subgraph.make_view(sgv)
    # Snapshot the original input tensors before detach_inputs swaps them for
    # placeholders.
    original_inputs = list(view.inputs)
    view, detached_inputs = detach_inputs(view)
    reroute.reroute_ts(original_inputs, view.outputs)
    return view, detached_inputs
Beispiel #10
0
def detach_outputs(sgv, control_outputs=None):
    """Detach the outputs of a subgraph view.

    Only the outputs that have at least one consumer are kept. They are
    swapped with newly created placeholders, built with
    util.make_placeholder_from_tensor inside the subgraph's graph.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
        Note that sgv is modified in place.
      control_outputs: a util.ControlOutputs instance or None. If not None the
        control outputs are also detached (via detach_control_outputs).
    Returns:
      A tuple `(sgv, output_placeholders)` where
        `sgv` is a new subgraph view of the detached subgraph, remapped to
          only the outputs that had consumers;
        `output_placeholders` is a list of the created output placeholders.
    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    sgv = subgraph.make_view(sgv)
    # Only select outputs with consumers; outputs without any consumer are
    # dropped from sgv_.
    sgv_ = sgv.remap_outputs([
        output_id for output_id, output_t in enumerate(sgv.outputs)
        if output_t.consumers()
    ])
    # Create a view over the consumer ops, then restrict its inputs to the
    # tensors that are outputs of sgv_.
    consumers_sgv = subgraph.SubGraphView(sgv_.consumers())
    consumers_sgv = consumers_sgv.remap_inputs([
        input_id for input_id, input_t in enumerate(consumers_sgv.inputs)
        if input_t in sgv_.outputs
    ])

    # One placeholder per remaining consumer input, created in sgv_'s graph.
    # NOTE(review): placeholders follow the order of consumers_sgv.inputs,
    # while swap_outputs below pairs them with sgv_.outputs. This assumes both
    # sequences list the same tensors in the same order -- confirm.
    with sgv_.graph.as_default():
        output_placeholders = [
            util.make_placeholder_from_tensor(input_t)
            for input_t in consumers_sgv.inputs
        ]

    reroute.swap_outputs(sgv_, output_placeholders)
    if control_outputs is not None:
        detach_control_outputs(sgv_, control_outputs)
    return sgv_, output_placeholders
Beispiel #11
0
def copy(sgv,
         dst_graph=None,
         dst_scope="",
         src_scope="",
         reuse_dst_scope=False):
    """Copy a subgraph.

    Args:
      sgv: the source subgraph-view. This argument is converted to a subgraph
        using the same rules as the function subgraph.make_view.
      dst_graph: the destination graph (defaults to the graph of sgv).
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    sgv = subgraph.make_view(sgv)
    target_graph = sgv.graph if dst_graph is None else dst_graph
    if not isinstance(target_graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(target_graph)))

    # A default Transformer performs a plain copy.
    return Transformer()(sgv,
                         target_graph,
                         dst_scope,
                         src_scope,
                         reuse_dst_scope=reuse_dst_scope)