def __init__(
        self,
        sgv,
        dst_graph,  # type: Graph
        dst_scope,
        src_scope):
    """Set up the temporary bookkeeping used while transforming `sgv`.

    Attributes with a trailing underscore (`graph_`, `scope_`) hold the
    destination side; their un-suffixed counterparts (`graph`, `scope`)
    hold the source side.

    Args:
      sgv: the object being transformed; its `inputs`, `ops` and `graph`
        attributes are read here (presumably a subgraph view -- confirm
        against the caller).
      dst_graph: the destination graph, stored as `graph_`.
      dst_scope: the destination scope, stored as `scope_`.
      src_scope: the source scope, stored as `scope`.
    """
    # Source side: the view itself plus frozen snapshots of its inputs/ops.
    self.sgv = sgv
    self.sgv_inputs_set = frozenset(sgv.inputs)
    self.ops = frozenset(sgv.ops)

    src_graph = sgv.graph
    self.control_outputs = util.ControlOutputs(src_graph)
    self.graph = src_graph  # type: Graph
    self.scope = src_scope

    # Destination side.
    self.graph_ = dst_graph
    self.scope_ = dst_scope

    # Mappings filled in as the transformation progresses.
    self.transformed_ops = {}
    self.transformed_ts = {}

    # Snapshot of every collection of the source graph, keyed by name.
    self.collections = {
        name: src_graph.get_collection_by_name(name)
        for name in src_graph.get_all_collection_keys()
    }

    self.cyclic_ops = []
    self.transform_original_op_handler = transform_op_if_inside_handler

    # Ops are transformed one at a time, following the order in which the
    # original ops were created. Cycles (i.e. while loops) can make that
    # impossible: a new op may need inputs which have not been transformed
    # yet. In that case temporary placeholders are created and recorded in
    # `tmp_cyclic_ts`; a second pass then swaps them for the proper
    # transformed tensors (see the function `_finalize_cycles`).
    self.tmp_cyclic_ts = []
def graph_replace(target_ts,
                  replacement_ts,
                  dst_scope="",
                  src_scope="",
                  reuse_dst_scope=False):
    """Recompute `target_ts` with some upstream tensors swapped out.

    The ops lying between the replaced tensors and the targets are copied,
    with the copies reading from the replacement tensors instead of the
    originals.

    Args:
      target_ts: a single gde.Tensor or an iterable of gde.Tensor.
      replacement_ts: dictionary mapping original tensors to the tensors
        which replace them.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True, `dst_scope` is re-used when it already
        exists. Otherwise (the default) a unique scope name is derived from
        it by appending an underscore followed by a digit.

    Returns:
      A single gde.Tensor or a list of gde.Tensor, mirroring the structure
      of `target_ts`. The returned tensors are recomputed from the tensors
      in `replacement_ts`; a target with no transformed counterpart is
      returned unchanged.

    Raises:
      ValueError: if the targets are not connected to replacement_ts.
    """
    flat_targets = _flatten_tree(target_ts)

    # Build the forward control-dependency edges first, so that the walk
    # below can traverse control dependencies as well as data edges.
    src_graph = util.get_unique_graph(flat_targets, check_types=(Tensor, ))
    ctrl_outputs = util.ControlOutputs(src_graph)

    # The ops which will change are those reached both by a forward walk
    # from the replaced tensors and by a backward walk from the targets.
    replaced_ts = list(iterkeys(replacement_ts))
    affected_ops = select.get_walks_intersection_ops(
        replaced_ts, flat_targets, control_ios=ctrl_outputs)
    if not affected_ops:
        raise ValueError("Targets and replacements are not connected!")

    # Complete the op selection to avoid malformed control flow.
    # TODO(fkp): Consider moving this function deeper (in the transformer?).
    _add_control_flow_ops(affected_ops, ctrl_outputs)

    # Copy the relevant subgraph; only the transformation info is needed.
    _, transform_info = copy_with_input_replacements(
        affected_ops, replacement_ts, None, dst_scope, src_scope,
        reuse_dst_scope)

    def _keep_original(original_t):
        # Fallback for targets which have no transformed counterpart.
        return original_t

    return transform_info.transformed(target_ts, _keep_original)