def graph_replace(target_ts, replacement_ts, dst_scope="", src_scope="",
                  reuse_dst_scope=False):
    """Build a copy of the subgraph that computes `target_ts` from replacements.

    The operations lying between the keys of `replacement_ts` and `target_ts`
    are copied. In the copy, each key tensor is swapped for its mapped value,
    so the copied targets are recomputed from the replacement tensors.

    Args:
      target_ts: a single gde.Tensor or an iterable of gde.Tensor.
      replacement_ts: dictionary mapping original tensors to the tensors that
        should replace them.
      dst_scope: the destination scope.
      src_scope: the source scope.
      reuse_dst_scope: if True, `dst_scope` is re-used when it already exists.
        Otherwise (the default), the scope is given a unique name derived from
        the given one by appending an underscore followed by a digit.

    Returns:
      A single gde.Tensor or a list of gde.Tensor, matching the form of
      `target_ts`. The returned tensors are recomputed using the tensors from
      `replacement_ts`. A target with no transformed counterpart is returned
      as the original tensor.

    Raises:
      ValueError: if the targets are not connected to replacement_ts.
    """
    flat_targets = _flatten_tree(target_ts)

    # Build forward control-dependency edges so that the walk intersection
    # below can also traverse control dependencies.
    graph = util.get_unique_graph(flat_targets, check_types=(Tensor, ))
    control_ios = util.ControlOutputs(graph)

    # The ops that will change are those found both by walking forward from
    # the tensors being replaced and by walking backward from the targets.
    sources = list(iterkeys(replacement_ts))
    affected_ops = select.get_walks_intersection_ops(
        sources, flat_targets, control_ios=control_ios)
    if not affected_ops:
        raise ValueError("Targets and replacements are not connected!")

    # Complete the op set so the copied control flow is not malformed.
    # The return value is unused, so this presumably extends `affected_ops`
    # in place.
    # TODO(fkp): Consider moving this function deeper (in the transformer?).
    _add_control_flow_ops(affected_ops, control_ios)

    # Copy the relevant subgraph; only the transform info is needed here.
    _, info = copy_with_input_replacements(
        affected_ops, replacement_ts, None, dst_scope, src_scope,
        reuse_dst_scope)

    # Fall back to the original tensor for any target whose transformed
    # counterpart cannot be found.
    return info.transformed(target_ts, lambda original_t: original_t)
def __init__(self, inside_ops=(), passthrough_ts=()):
    """Create a subgraph from the given ops plus "passthrough" tensors.

    Args:
      inside_ops: an object convertible to a list of `gde.Node`. This list
        defines all the operations belonging to the subgraph.
      passthrough_ts: an object convertible to a list of `gde.Tensor`. A
        passthrough tensor goes directly from the input of the subgraph to its
        output, without any intermediate operation. A tensor given here that
        is already an input, output or inside tensor of `inside_ops` is
        silently ignored.

    Raises:
      TypeError: if `inside_ops` cannot be converted to a list of `gde.Node`
        or if `passthrough_ts` cannot be converted to a list of `gde.Tensor`.
    """
    op_list = util.make_list_of_op(inside_ops)
    t_list = util.make_list_of_t(passthrough_ts)
    members = op_list + t_list

    if not members:
        # Nothing was given: the view has no graph, ops or boundary tensors.
        self._graph = None
        self._passthrough_ts = []
        self._input_ts = []
        self._output_ts = []
        self._ops = []
        return

    self._graph = util.get_unique_graph(members)
    self._ops = op_list

    # Boundary tensors (inputs/outputs) and inside tensors of the op set.
    inputs, outputs, insides = select.compute_boundary_ts(op_list)

    # A tensor already attached to the ops cannot be a passthrough tensor,
    # so such tensors are dropped without complaint.
    attached = frozenset(inputs + outputs + list(insides))
    self._passthrough_ts = [t for t in t_list if t not in attached]

    # Passthrough tensors appear on both sides of the view.
    self._input_ts = inputs + self._passthrough_ts
    self._output_ts = outputs + self._passthrough_ts