Example #1
def __init__(
        self,
        sgv,         # SubGraphView to transform
        dst_graph,   # type: Graph
        dst_scope,
        src_scope):
    self.sgv = sgv
    self.sgv_inputs_set = frozenset(sgv.inputs)
    self.ops = frozenset(sgv.ops)
    self.control_outputs = util.ControlOutputs(sgv.graph)
    self.graph = sgv.graph  # type: Graph
    self.scope = src_scope
    self.graph_ = dst_graph
    self.scope_ = dst_scope
    # Maps from original ops/tensors to their transformed counterparts.
    self.transformed_ops = {}
    self.transformed_ts = {}
    self.collections = dict(
        (key, self.graph.get_collection_by_name(key))
        for key in self.graph.get_all_collection_keys())
    self.cyclic_ops = []
    self.transform_original_op_handler = transform_op_if_inside_handler
    # The graph is transformed op by op, in the same order the original ops
    # were created. However, this is sometimes not possible due to cycles
    # (e.g. while loops). So when the transformer creates a new op whose
    # inputs do not exist yet, temporary placeholders are created and stored
    # in this `tmp_cyclic_ts` container. During a second pass, those
    # temporary tensors are replaced by the proper transformed tensors
    # (see the function `_finalize_cycles`).
    self.tmp_cyclic_ts = []
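
The comment above describes a two-pass strategy for handling cycles. Below is a minimal, self-contained sketch of that idea; the toy `Node` class and `copy_in_creation_order` function are illustrative stand-ins, not part of the gde API.

class Node(object):
    """Toy stand-in for an op: a name plus a list of input nodes."""
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)

def copy_in_creation_order(nodes):
    """Copy `nodes` in order, patching cyclic inputs in a second pass."""
    transformed = {}   # original Node -> copied Node
    tmp_cyclic = []    # (copied node, input index, original input)
    for node in nodes:
        copy = Node(node.name)
        for inp in node.inputs:
            if inp in transformed:
                copy.inputs.append(transformed[inp])
            else:
                # Input not copied yet (a cycle): use a placeholder for now.
                copy.inputs.append(Node("placeholder"))
                tmp_cyclic.append((copy, len(copy.inputs) - 1, inp))
        transformed[node] = copy
    # Second pass (the role played by `_finalize_cycles`): swap the
    # placeholders for the proper transformed inputs.
    for copy, index, original in tmp_cyclic:
        copy.inputs[index] = transformed[original]
    return transformed

# A two-node cycle: enter -> body -> enter.
enter = Node("enter")
body = Node("body", [enter])
enter.inputs.append(body)
copies = copy_in_creation_order([enter, body])
assert copies[enter].inputs[0] is copies[body]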
def graph_replace(target_ts,
                  replacement_ts,
                  dst_scope="",
                  src_scope="",
                  reuse_dst_scope=False):
    """Create a new graph which computes the targets from the replaced Tensors.

  Args:
    target_ts: a single gde.Tensor or an iterable of gde.Tensor.
    replacement_ts: dictionary mapping from original tensors to replaced tensors
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A single gde.Tensor or a list of target gde.Tensor, depending on
    the type of the input argument `target_ts`.
    The returned tensors are recomputed using the tensors from replacement_ts.
  Raises:
    ValueError: if the targets are not connected to replacement_ts.
  """
    # Identify operations in the graph that will change.
    # Start forward walk at Tensors that will be replaced, and
    # backward walk at the target output Tensors.
    flatten_target_ts = _flatten_tree(target_ts)
    # Construct the forward control dependency edges so that
    # get_walks_intersection_ops can also traverse the
    # control dependencies.
    graph = util.get_unique_graph(flatten_target_ts, check_types=(Tensor,))
    control_ios = util.ControlOutputs(graph)
    ops = select.get_walks_intersection_ops(list(iterkeys(replacement_ts)),
                                            flatten_target_ts,
                                            control_ios=control_ios)
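    # For example, if c = a + b and `a` is being replaced, the forward walk
    # from `a` and the backward walk from `c` intersect at the `add` op;
    # only ops in this intersection need to be copied.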
    if not ops:
        raise ValueError("Targets and replacements are not connected!")

    # Complete ops to avoid malformed control flow.
    # TODO(fkp): Consider moving this function deeper (in the transformer?).
    _add_control_flow_ops(ops, control_ios)

    # Create a copy of the relevant subgraph
    unused_sgv_, info = copy_with_input_replacements(ops, replacement_ts, None,
                                                     dst_scope, src_scope,
                                                     reuse_dst_scope)

    # Return the transformed targets, but keep the original tensor if a
    # transformed counterpart cannot be found.
    def missing_fn(original_t):
        return original_t

    return info.transformed(target_ts, missing_fn)
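
For context, a minimal sketch of how `graph_replace` might be called. The `gde.Graph(tf_graph)` constructor and the `get_tensor_by_name` accessor are assumptions about the surrounding gde API, and the tensor names are illustrative.

import tensorflow as tf
import graph_def_editor as gde  # assumption: package exporting graph_replace

tf_graph = tf.Graph()
with tf_graph.as_default():
    a = tf.constant(1.0, name="a")
    b = tf.constant(2.0, name="b")
    c = tf.add(a, b, name="c")  # c = a + b

g = gde.Graph(tf_graph)  # assumption: gde.Graph wraps a tf.Graph

# Recompute `c` with `a` replaced by `b`, i.e. the new tensor computes b + b.
c_new = gde.graph_replace(
    g.get_tensor_by_name("c:0"),        # target_ts (a single tensor)
    {g.get_tensor_by_name("a:0"):       # replacement_ts mapping
         g.get_tensor_by_name("b:0")})

Because `target_ts` is a single tensor here, `graph_replace` returns a single transformed tensor; passing an iterable of tensors would return a list instead, as the docstring describes.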