def graph_replace(target_ts,
                  replacement_ts,
                  dst_scope="",
                  src_scope="",
                  reuse_dst_scope=False):
    """Create a new graph which computes the targets from the replaced Tensors.

  Args:
    target_ts: a single gde.Tensor or an iterable of gde.Tensor.
    replacement_ts: dictionary mapping from original tensors to replaced tensors
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A single gde.Tensor or a list of target gde.Tensor, depending on
    the type of the input argument `target_ts`.
    The returned tensors are recomputed using the tensors from replacement_ts.
  Raises:
    ValueError: if the targets are not connected to replacement_ts.
  """
    # Identify operations in the graph that will change.
    # Start forward walk at Tensors that will be replaced, and
    # backward walk at the target output Tensors.
    flatten_target_ts = _flatten_tree(target_ts)
    # Construct the forward control dependencies edges so that
    # the get_walks_intersection_ops can also traverse the
    # control dependencies.
    graph = util.get_unique_graph(flatten_target_ts, check_types=(Tensor, ))
    control_ios = util.ControlOutputs(graph)
    ops = select.get_walks_intersection_ops(list(iterkeys(replacement_ts)),
                                            flatten_target_ts,
                                            control_ios=control_ios)
    if not ops:
        raise ValueError("Targets and replacements are not connected!")

    # Complete ops to avoid malformed control flow.
    # TODO(fkp): Consider moving this function deeper (in the transformer?).
    _add_control_flow_ops(ops, control_ios)

    # Create a copy of the relevant subgraph
    unused_sgv_, info = copy_with_input_replacements(ops, replacement_ts, None,
                                                     dst_scope, src_scope,
                                                     reuse_dst_scope)

    # Return the transformed targets but keep the original if the transformed
    # counterpart cannot be found
    def missing_fn(original_t):
        return original_t

    return info.transformed(target_ts, missing_fn)
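
For orientation, here is a minimal usage sketch of graph_replace. It is not part of the listing above: the gde.Graph constructor, the get_tensor_by_name accessor, and the tensor names are assumptions about the surrounding gde API.

# Minimal usage sketch. The gde.Graph constructor, get_tensor_by_name
# accessor, and all tensor names below are assumptions for illustration.
import tensorflow as tf
import graph_def_editor as gde

tf_g = tf.Graph()
with tf_g.as_default():
    a = tf.constant(1.0, name="a")
    b = tf.constant(2.0, name="b")
    c = tf.add(a, b, name="c")
    new_a = tf.constant(10.0, name="new_a")

g = gde.Graph(tf_g)  # assumed: gde.Graph can wrap a tf.Graph
# Recompute "c" in a copied subgraph where "a" is replaced by "new_a".
new_c = gde.graph_replace(
    g.get_tensor_by_name("c:0"),
    {g.get_tensor_by_name("a:0"): g.get_tensor_by_name("new_a:0")})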
Example #2
    def __init__(self, inside_ops=(), passthrough_ts=()):
        """Create a subgraph containing the given ops and the "passthrough" tensors.

    Args:
      inside_ops: an object convertible to a list of `gde.Node`. This list
        defines all the operations in the subgraph.
      passthrough_ts: an object convertible to a list of `gde.Tensor`. This list
        define all the "passthrough" tensors. A passthrough tensor is a tensor
        which goes directly from the input of the subgraph to it output, without
        any intermediate operations. All the non passthrough tensors are
        silently ignored.
    Raises:
      TypeError: if inside_ops cannot be converted to a list of `gde.Node`
        or if `passthrough_ts` cannot be converted to a list of `gde.Tensor`.
    """

        inside_ops = util.make_list_of_op(inside_ops)
        passthrough_ts = util.make_list_of_t(passthrough_ts)
        ops_and_ts = inside_ops + passthrough_ts
        if ops_and_ts:
            self._graph = util.get_unique_graph(ops_and_ts)
            self._ops = inside_ops

            # Compute the input, output, and inside tensors of the subgraph.
            inputs, outputs, insides = select.compute_boundary_ts(inside_ops)

            # Compute passthrough tensors, silently ignoring the non-passthrough ones.
            all_tensors = frozenset(inputs + outputs + list(insides))
            self._passthrough_ts = [
                t for t in passthrough_ts if t not in all_tensors
            ]

            # Set inputs and outputs.
            self._input_ts = inputs + self._passthrough_ts
            self._output_ts = outputs + self._passthrough_ts
        else:
            self._graph = None
            self._passthrough_ts = []
            self._input_ts = []
            self._output_ts = []
            self._ops = []
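
A short, hypothetical sketch of how this constructor could be exercised. The enclosing class name (SubgraphView here) and the gde.Graph / get_node_by_name calls are assumptions, since neither appears in the listing.

# Hypothetical sketch: the enclosing class name (SubgraphView) and the
# gde.Graph accessors used below are assumptions, not shown in the listing.
import tensorflow as tf
import graph_def_editor as gde

tf_g = tf.Graph()
with tf_g.as_default():
    x = tf.constant(3.0, name="x")
    y = tf.square(x, name="y")

g = gde.Graph(tf_g)  # assumed: gde.Graph wraps a tf.Graph
sgv = SubgraphView(inside_ops=[g.get_node_by_name("y")])
# The boundary computation above should report x:0 as an input tensor
# and y:0 as an output tensor of this one-op subgraph.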