def __init__(self, inside_ops=(), passthrough_ts=()):
    """Build a subgraph view from some ops plus optional "passthrough" tensors.

    Args:
      inside_ops: anything convertible to a list of tf.Operation. These are
        all the operations that make up the subgraph.
      passthrough_ts: anything convertible to a list of tf.Tensor. These are
        the candidate "passthrough" tensors. A passthrough tensor goes
        directly from the input of the subgraph to its output, without any
        intermediate operation. Candidates that turn out not to be
        passthrough tensors are dropped silently.

    Raises:
      TypeError: if inside_ops cannot be converted to a list of tf.Operation
        or if passthrough_ts cannot be converted to a list of tf.Tensor.
    """
    ops = util.make_list_of_op(inside_ops)
    candidates = util.make_list_of_t(passthrough_ts)

    # Only ask for the owning graph when there is something to inspect.
    everything = ops + candidates
    self._graph = util.get_unique_graph(everything) if everything else None
    self._ops = ops

    boundary_in, boundary_out, interior = select.compute_boundary_ts(ops)

    # A candidate already on the boundary or inside the subgraph is not a
    # passthrough tensor, so it is filtered out here.
    seen = frozenset(boundary_in + boundary_out + list(interior))
    self._passthrough_ts = [t for t in candidates if t not in seen]

    # Passthrough tensors appear on both sides of the view.
    self._input_ts = boundary_in + self._passthrough_ts
    self._output_ts = boundary_out + self._passthrough_ts
def _remap_default(self, remove_input_map=True, remove_output_map=True): """Remap in the place the inputs and/or outputs to the default mapping. Args: remove_input_map: if True the input map is reset to the default one. remove_output_map: if True the output map is reset to the default one. """ if not remove_input_map and not remove_output_map: return # Compute inside and outside tensor inputs, outputs, _ = select.compute_boundary_ts(self._ops) if remove_input_map: self._input_ts = list(inputs) + self._passthrough_ts if remove_output_map: self._output_ts = list(outputs) + self._passthrough_ts
def __init__(self, inside_ops=(), passthrough_ts=()):
    """Build a subgraph view from some ops plus optional "passthrough" tensors.

    Args:
      inside_ops: anything convertible to a list of `tf.Operation`. These are
        all the operations that make up the subgraph.
      passthrough_ts: anything convertible to a list of `tf.Tensor`. These
        are the candidate "passthrough" tensors. A passthrough tensor goes
        directly from the input of the subgraph to its output, without any
        intermediate operation. Candidates that turn out not to be
        passthrough tensors are dropped silently.

    Raises:
      TypeError: if inside_ops cannot be converted to a list of
        `tf.Operation` or if `passthrough_ts` cannot be converted to a list
        of `tf.Tensor`.
    """
    ops = util.make_list_of_op(inside_ops)
    candidates = util.make_list_of_t(passthrough_ts)
    combined = ops + candidates

    if not combined:
        # Empty view: no graph, no ops, nothing flowing in or out.
        self._graph = None
        self._passthrough_ts = []
        self._input_ts = []
        self._output_ts = []
        self._ops = []
        return

    self._graph = util.get_unique_graph(combined)
    self._ops = ops

    boundary_in, boundary_out, interior = select.compute_boundary_ts(ops)

    # Candidates already on the boundary or inside the subgraph are not
    # passthrough tensors and are dropped.
    known = frozenset(boundary_in + boundary_out + list(interior))
    self._passthrough_ts = [t for t in candidates if t not in known]

    # Passthrough tensors appear on both sides of the view.
    self._input_ts = boundary_in + self._passthrough_ts
    self._output_ts = boundary_out + self._passthrough_ts
def unmap(self, remove_input_map=True, remove_output_map=True):
    """Unmap existing input and/or output mapping.

    The receiver is left untouched; the reset is applied to a copy obtained
    from `self.copy()`, which is returned.

    Args:
      remove_input_map: if True the input map is reset to identity.
      remove_output_map: if True the output map is reset to identity.

    Returns:
      A new modified instance of the original subgraph view with its input
      and/or output mapping reset to identity.
    """
    res = self.copy()
    if not remove_input_map and not remove_output_map:
        return res

    # Compute inside and outside tensor (same call and keep_order flag as
    # before the fix).
    inputs, outputs, _ = select.compute_boundary_ts(self._ops, keep_order=True)

    # The reset must land on the returned copy. Assigning to `self` here
    # (the previous behavior) mutated the receiver and returned an
    # unmodified copy, contradicting the documented contract.
    if remove_input_map:
        res._input_ts = list(inputs) + self._passthrough_ts
    if remove_output_map:
        res._output_ts = list(outputs) + self._passthrough_ts
    return res