def connect(sgv0, sgv1, disconnect_first=False):
    """Wire the outputs of sgv0 into the inputs of sgv1.

    Both arguments are first turned into subgraph views and validated with
    util.check_graphs. A passthrough view is then built from the outputs of
    sgv0 and its tensors are rerouted into the inputs of sgv1.

    Args:
      sgv0: the subgraph whose outputs are the source of the connection. This
        argument is converted to a subgraph using the same rules as the
        function subgraph.make_view. Note that sgv0 is modified in place.
      sgv1: the subgraph whose inputs receive the connection. This argument is
        converted to a subgraph using the same rules as the function
        subgraph.make_view. Note that sgv1 is modified in place.
      disconnect_first: if True, detach_outputs is called on sgv0 before the
        connection is made, so its current outputs are disconnected first.
    Returns:
      A tuple `(sgv0, sgv1)` of the now connected subgraph views.
    Raises:
      StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView
        using the same rules as the function subgraph.make_view.
    """
    source_view = subgraph.make_view(sgv0)
    target_view = subgraph.make_view(sgv1)
    util.check_graphs(source_view, target_view)

    if disconnect_first:
        # The returned (view, placeholders) pair is not needed here.
        detach_outputs(source_view)

    # A view made only of passthrough tensors: its "inputs" are exactly the
    # outputs of the source view, which is what reroute_inputs consumes.
    passthrough_view = subgraph.SubGraphView(
        passthrough_ts=source_view.outputs)
    reroute.reroute_inputs(passthrough_view, target_view)
    return source_view, target_view
def detach_outputs(sgv, control_outputs=None):
    """Detach the outputs of a subgraph view from their consumers.

    Only outputs that have at least one consumer are considered. For each
    consumer input fed by those outputs a placeholder is created, and the
    outputs are then swapped with the placeholders.

    Args:
      sgv: the subgraph view to be detached. This argument is converted to a
        subgraph using the same rules as the function subgraph.make_view.
        Note that sgv is modified in place.
      control_outputs: a util.ControlOutputs instance or None. If not None the
        control outputs are also detached (via detach_control_outputs).
    Returns:
      A tuple `(sgv, output_placeholders)` where
        `sgv` is a new subgraph view of the detached subgraph, restricted to
          the outputs that had consumers;
        `output_placeholders` is a list of the created output placeholders.
    Raises:
      StandardError: if sgv cannot be converted to a SubGraphView using
        the same rules as the function subgraph.make_view.
    """
    view = subgraph.make_view(sgv)

    # Keep only the outputs that something actually consumes.
    consumed_output_ids = [
        index
        for index, tensor in enumerate(view.outputs)
        if tensor.consumers()
    ]
    detached_view = view.remap_outputs(consumed_output_ids)

    # View over the consumers, narrowed to the inputs fed by the kept outputs.
    consumers_view = subgraph.SubGraphView(detached_view.consumers())
    fed_input_ids = [
        index
        for index, tensor in enumerate(consumers_view.inputs)
        if tensor in detached_view.outputs
    ]
    consumers_view = consumers_view.remap_inputs(fed_input_ids)

    # Create the placeholders with the view's graph as the default graph.
    with detached_view.graph.as_default():
        placeholders = list(
            map(util.make_placeholder_from_tensor, consumers_view.inputs))

    reroute.swap_outputs(detached_view, placeholders)
    if control_outputs is not None:
        detach_control_outputs(detached_view, control_outputs)
    return detached_view, placeholders
def _transform_sgv(self, info, sgv):
    """Build the subgraph view corresponding to a transformed subgraph.

    For convenience, a transform operation returns a subgraph view of the
    transformed graph. The view is created from the ops recorded in
    `info.transformed_ops`, then its inputs and outputs are re-ordered to
    follow the order of `sgv.inputs` and `sgv.outputs`.

    Args:
      info: Temporary information for this transform call. Its
        `transformed_ops` and `transformed_ts` mappings are read here.
      sgv: the subgraph to be transformed.
    Returns:
      The transformed subgraph view, remapped so that its inputs/outputs
      appear in the same order as their originals in `sgv`.
    """
    new_ops = [new_op for _key, new_op in iteritems(info.transformed_ops)]
    new_view = subgraph.SubGraphView(new_ops)
    new_inputs = new_view.inputs
    new_outputs = new_view.outputs

    def _ordered_indices(original_ts, candidates, index_of):
        """Indices in the new view of the transformed originals, in order."""
        indices = []
        for tensor in original_ts:
            # Skip tensors that were never transformed ...
            if tensor in info.transformed_ts:
                counterpart = info.transformed_ts[tensor]
                # ... or whose counterpart is not exposed by the new view.
                if counterpart in candidates:
                    indices.append(index_of(counterpart))
        return indices

    # Re-order inputs first, then outputs.
    input_map = _ordered_indices(sgv.inputs, new_inputs, new_view.input_index)
    output_map = _ordered_indices(
        sgv.outputs, new_outputs, new_view.output_index)
    return new_view.remap(input_map, output_map)