Example #1
 def adjust_in_out_chw(self, G, node, names):
     self.verify_chw(node, names)
     trans = self.get_trans(names, ['c', 'h', 'w'])
     in_dim = node.in_dims[0]
     if in_dim.c != 1:
         self.apply_input_trans(node, trans, index=0)
     else:
         reshape = ReshapeParameters(f'{node.name}_r_chw',
                                     old_shape=in_dim.clone(),
                                     shape=Dim.named_ordered(c=in_dim.c,
                                                             h=in_dim.h,
                                                             w=in_dim.w))
         G.insert_node_before(reshape, node, edge_class=NNEdge)
         self.check_quantization(G, node, reshape)
     out_dim = node.out_dims[0]
     if out_dim.c != 1:
         self.apply_output_trans(node, self.invert(trans), index=0)
     else:
         reshape = ReshapeParameters(f'{node.name}_r_{"".join(names)}',
                                     old_shape=Dim.named_ordered(
                                         c=out_dim.c,
                                         h=out_dim.h,
                                         w=out_dim.w),
                                     shape=out_dim.clone())
         G.insert_node_after(node, reshape, edge_class=NNEdge)
         self.check_quantization(G, node, reshape, dir='out')
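A minimal NumPy sketch (my own illustration, not part of the adjuster) of why the c == 1 branch above inserts a reshape rather than a transpose: moving a unit channel axis does not reorder any data, assuming the usual HWC-to-CHW case.

import numpy as np

# hypothetical HWC input with a single channel
hwc = np.arange(12).reshape(4, 3, 1)          # (h, w, c) with c == 1
chw_via_transpose = hwc.transpose(2, 0, 1)    # what a real transpose would produce
chw_via_reshape = hwc.reshape(1, 4, 3)        # what the inserted ReshapeParameters produces
print(np.array_equal(chw_via_transpose, chw_via_reshape))  # True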
Example #2
    def _common1_11(cls, node, **kwargs):
        axis = node.attrs.get('axis', 1)
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        valid_name = kwargs['valid_name']
        x = all_nodes[node.input[0]]
        x_shape = x[2].shape
        if axis < 0:
            axis += len(x_shape)
        old_shape = cls._get_real_dim(x_shape)
        # Versions 1 and 11 work differently from v13: in v1 and v11 the input
        # [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] with axis k is collapsed into
        # the 2D tensor [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}] and the softmax
        # is applied along its second dimension.
        new_pshape = [condense(x_shape[:axis:]), condense(x_shape[axis::])]
        new_shape = cls._get_real_dim(new_pshape)
        reshape_1 = ReshapeParameters(valid_name + "_reshape1",
                                      old_shape=old_shape,
                                      shape=new_shape)
        G.add_edge(
            NNEdge(from_node=x[0], to_node=reshape_1, from_idx=x[1], to_idx=0))
        # operation axis will either be 1 or 0
        softmax = SoftMaxParameters(valid_name, axis=len(new_shape) - 1)
        G.add_edge(NNEdge(from_node=reshape_1, to_node=softmax))
        reshape_2 = ReshapeParameters(valid_name + "_reshape2",
                                      old_shape=new_shape,
                                      shape=old_shape)
        G.add_edge(NNEdge(from_node=softmax, to_node=reshape_2))

        all_nodes[node.output[0]] = (reshape_2, 0, ProvisionalDim(x_shape))
        return softmax
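For reference, a small NumPy model (an illustration only, assuming nothing beyond the comment in the handler) of the v1/v11 semantics that the reshape/softmax/reshape chain implements:

import numpy as np

def softmax_v1_11(x, axis=1):
    if axis < 0:
        axis += x.ndim
    old_shape = x.shape
    # [a_0, ..., a_{n-1}] with axis k -> [a_0*...*a_{k-1}, a_k*...*a_{n-1}]
    two_d = x.reshape(int(np.prod(old_shape[:axis], dtype=np.int64)), -1)
    e = np.exp(two_d - two_d.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)).reshape(old_shape)

print(softmax_v1_11(np.arange(24, dtype=np.float32).reshape(2, 3, 4), axis=2).shape)  # (2, 3, 4)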
Example #3
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        has_modified_graph = False
        # materialise the node list up front since the graph is modified while iterating
        for node in list(G.nodes(node_classes=StridedSliceParameters)):
            if node.slice_shape != tuple(node.in_dims[0].shape):
                continue
            has_modified_graph = True
            nid = NodeId(node)
            if node.slice_shape == node.out_shape:
                LOG.info(
                    f'removing strided slice {node.name} that does nothing')
                G.remove_and_reconnect(node, edge_class=NNEdge)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
            else:
                reshape = ReshapeParameters(
                    G.unique_name(f'{node.name}_reshape'),
                    old_shape=node.slice_shape,
                    shape=node.out_shape)
                LOG.info(
                    f'replacing strided slice {node.name} with reshape {reshape.name}'
                )
                G.replace_node(node, reshape)
                if G.quantization and nid in G.quantization:
                    G.quantization[NodeId(reshape)] = G.quantization[nid]
                    del G.quantization[nid]

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
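A quick NumPy check (illustrative only) of the replacement rule above: a strided slice that covers the whole input and only squeezes unit axes is equivalent to a reshape.

import numpy as np

x = np.arange(6).reshape(3, 1, 2)
sliced = x[0:3:1, 0, 0:2:1]                      # full-range slice, unit axis 1 shrunk
print(np.array_equal(sliced, x.reshape(3, 2)))   # True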
Example #4
    def get_state(cls, G, inputs, idx, name, hidden_size, num_directions=1):
        if not inputs[idx]:
            state = np.zeros((num_directions, hidden_size))
        elif cls.is_constant(inputs[idx]):
            state = cls.get_constant(inputs[idx])
        else:
            state_inp = inputs[idx]
            if num_directions == 2:
                act_slices = (((0, 1, 1), (0, hidden_size, 1)),
                              ((1, 2, 1), (0, hidden_size, 1)))
                out_shapes = ((1, hidden_size), (1, hidden_size))
                split = SplitParameters(G.unique_name(f'{name}_split'),
                                        act_slices=act_slices,
                                        out_shapes=out_shapes,
                                        axis=0)
                G.add_edge(
                    NNEdge(from_node=state_inp[0],
                           to_node=split,
                           from_idx=state_inp[1]))
                return {
                    'forward': {
                        name: (split, 0)
                    },
                    # the backward direction takes the second split output
                    'backward': {
                        name: (split, 1)
                    },
                }
            else:
                reshape = ReshapeParameters(G.unique_name(f'{name}_reshape'),
                                            old_shape=(1, hidden_size),
                                            shape=(hidden_size, ))
                G.add_edge(
                    NNEdge(from_node=state_inp[0],
                           to_node=reshape,
                           from_idx=state_inp[1]))
                return {
                    'forward': {
                        name: (reshape, 0)
                    },
                }

        state = cls.get_constant(inputs[idx]) if inputs[idx] else np.zeros(
            (num_directions, hidden_size))
        return {
            'forward' if dir == 0 else 'backward': {
                name: dir_arr.reshape((hidden_size, ))
            }
            for dir, dir_arr in enumerate(
                np.split(state, num_directions, axis=0))
        }
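The constant branch at the end can be pictured with plain NumPy (hypothetical sizes, illustration only): the (num_directions, hidden_size) state is split per direction and flattened to (hidden_size,).

import numpy as np

hidden_size, num_directions = 4, 2
state = np.zeros((num_directions, hidden_size))
per_dir = {
    'forward' if d == 0 else 'backward': arr.reshape((hidden_size,))
    for d, arr in enumerate(np.split(state, num_directions, axis=0))
}
print({k: v.shape for k, v in per_dir.items()})  # {'forward': (4,), 'backward': (4,)}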
Example #5
 def adjust_in_out_order(self, G, node, names, order):
     self.verify_chw(node, names)
     trans = self.get_trans(names, order)
     in_dim = node.in_dims[0]
     if in_dim.c != 1:
         self.apply_input_trans(node, trans, index=0)
     else:
         new_shape = {k: getattr(in_dim, k) for k in order}
         reshape = ReshapeParameters(
             f'{node.name}_r_{"".join(in_dim.order)}_{"".join(order)}',
             old_shape=in_dim.clone(),
             shape=Dim.named_ordered(**new_shape)
         )
         G.insert_node_before(
             reshape,
             node,
             edge_class=NNEdge
         )
         node.in_dims_hint[0] = order
         self.check_quantization(G, node, reshape)
     out_dim = node.out_dims[0]
     if out_dim.c != 1:
         self.apply_output_trans(node, self.invert(trans), index=0)
     else:
         old_shape = {k: getattr(out_dim, k) for k in order}
         node.out_dims_hint[0] = order
         reshape = ReshapeParameters(
             f'{node.name}_r_{"".join(names)}',
             old_shape=Dim.named_ordered(**old_shape),
             shape=out_dim.clone()
         )
         G.insert_node_after(
             node,
             reshape,
             edge_class=NNEdge
         )
         self.check_quantization(G, node, reshape, direction='out')
Example #6
    def _execute(self, node, G):
        info(f"{self}")
        direction = self.direction
        if self.reshape_from is not None:
            params = ReshapeParameters(G.unique_name(f'{node.name}_reshape'),
                                       old_shape=Dim.unnamed(self.reshape_from),
                                       shape=Dim.unnamed(self.reshape_to))
            self.do_insert(node, G, params, direction=direction)
            node = params
            direction = "out"

        params = TransposeParameters(G.unique_name(f'{node.name}_trans'),
                                     transpose=self.transpose)
        self.do_insert(node, G, params, direction=direction)
Example #7
 def _maybe_insert_reshape(cls, G, inp, inp_shape, pout_shape):
     out_dim_none = tuple(
         [idx for idx, dim in enumerate(pout_shape) if dim is None])
     in_dim_none = tuple(
         [idx for idx, dim in enumerate(inp_shape) if dim is None])
     if out_dim_none == in_dim_none:
         return inp[0], inp[1]
     old_shape = [dim for dim in inp_shape if dim is not None]
     shape = [
         dim for idx, dim in enumerate(inp_shape)
         if dim is not None and idx not in out_dim_none
     ]
     rparams = ReshapeParameters(G.unique_name(f'{inp[0]}_reshape'),
                                 old_shape=old_shape,
                                 shape=shape)
     G.add_edge(NNEdge(from_node=inp[0], to_node=rparams, from_idx=inp[1]))
     return rparams, 0
Example #8
 def remove_known_batch_dimension(cls, G, x, node, batch_axis=0):
     x_shape = x[2].shape
     if x_shape[batch_axis] is not None:
          if x_shape[batch_axis] > 1:
             raise ValueError(
                 f'multi batch (n={x_shape[batch_axis]}) operations are not supported by {node.name}')
         rparams = ReshapeParameters(
             f'{node.name}_batch',
             old_shape=Dim.unnamed(x_shape),
             shape=Dim.unnamed(x_shape[0:batch_axis:]+x_shape[batch_axis+1::]))
         if G.quantization:
             qrec = G.quantization[NodeId(x[0])]
             G.quantization[NodeId(rparams)] = QRec.copy_ktype(
                 qrec,
                 in_qs=[qrec.out_qs[0]],
                 out_qs=[qrec.out_qs[0]])
         G.add_edge(
             NNEdge(from_node=x[0], to_node=rparams, from_idx=x[1], to_idx=0))
         return (rparams, 0, ProvisionalDim(x_shape[0:batch_axis:]+[None]+x_shape[batch_axis+1::]))
     else:
         return x
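Shape arithmetic behind the reshape above, with hypothetical values (illustration only): the known unit batch axis is dropped from the real shape but kept as None in the provisional shape.

x_shape, batch_axis = [1, 3, 224, 224], 0
new_shape = x_shape[0:batch_axis:] + x_shape[batch_axis + 1::]            # [3, 224, 224]
prov_shape = x_shape[0:batch_axis:] + [None] + x_shape[batch_axis + 1::]  # [None, 3, 224, 224]
print(new_shape, prov_shape)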
Example #9
    def _execute(self, node, G):
        info(f"{self}")
        if node.name not in G:
            return
        if self.reshape:
            reshape_node = ReshapeParameters(
                G.unique_name(f"{node.name}_reshape"),
                old_shape=self.reshape[0],
                shape=self.reshape[1])
            G.replace_node(node.name, reshape_node)
            if G.quantization:
                G.quantization.copy_qrec(node, 'in', 0, reshape_node)
            LOG.info(
                f'transpose {node.name} replaced with reshape {reshape_node.name}'
            )
        else:
            G.remove_and_reconnect(node, edge_class=NNEdge)
            nid = NodeId(node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]

            LOG.info(f'transpose {node.name} removed')
Example #10
 def _maybe_insert_reshape(cls, G, inp, inp_shape, pout_shape):
     out_dim_none = tuple(
         [idx for idx, dim in enumerate(pout_shape) if dim is None])
     in_dim_none = tuple(
         [idx for idx, dim in enumerate(inp_shape) if dim is None])
     if out_dim_none == in_dim_none:
         return inp[0], inp[1]
     old_shape = [dim for dim in inp_shape if dim is not None]
     new_shape = [
         dim for idx, dim in enumerate(inp_shape)
         if dim is not None and idx not in out_dim_none
     ]
     if cls.is_constant(inp):
         val = np.reshape(cls.get_constant(inp), new_shape)
         params = ConstantInputParameters(G.unique_name(inp[0].name),
                                          value=val,
                                          dims=Dim.unnamed(val.shape))
     else:
         params = ReshapeParameters(G.unique_name(f'{inp[0].name}_reshape'),
                                    old_shape=old_shape,
                                    shape=new_shape)
         G.add_edge(
             NNEdge(from_node=inp[0], to_node=params, from_idx=inp[1]))
     return params, 0
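The shape juggling above can be traced with plain lists (hypothetical shapes, illustration only): the input drops its own None dims and also any concrete dim sitting at an index where the output shape has a None.

inp_shape = [1, None, 3, 4]
pout_shape = [None, None, 3, 4]
out_dim_none = tuple(i for i, d in enumerate(pout_shape) if d is None)
old_shape = [d for d in inp_shape if d is not None]
new_shape = [d for i, d in enumerate(inp_shape)
             if d is not None and i not in out_dim_none]
print(old_shape, '->', new_shape)  # [1, 3, 4] -> [3, 4]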
Example #11
 def _execute(self, node, G):
     LOGL("%s", str(self))
     if isinstance(node, TransposeParameters):
         direction = "in"
     else:
         direction = self.direction
     transpose_name = "transpose_" + direction
     trans = getattr(node, transpose_name)
     if direction == 'in' and isinstance(node, Broadcastable):
         node.delete_transpose(self.idx, trans[self.idx])
     trans[self.idx] = None
     if all(t is None for t in trans):
         setattr(node, transpose_name, None)
     if self.reshape:
         reshape_node = ReshapeParameters(
             G.unique_name(f"{node.name}_reshape"),
             old_shape=self.reshape[0],
             shape=self.reshape[1])
         if direction == "in":
             in_edge = G.indexed_in_edges(node.name)[self.idx]
             G.insert_node_at_edge(reshape_node, in_edge, edge_class=NNEdge)
             if G.quantization:
                 G.quantization.copy_qrec(in_edge.to_node, 'in', self.idx,
                                          reshape_node)
         else:
             G.insert_node_after(node,
                                 reshape_node,
                                 from_idx=self.idx,
                                 edge_class=NNEdge)
             if G.quantization:
                 G.quantization.copy_qrec(node, 'out', self.idx,
                                          reshape_node)
     if isinstance(node, TransposeParameters) and node.transpose_in is None:
         LOGL("remove null transpose %s", node.name)
         G.remove_and_reconnect(node, edge_class=NNEdge)
     LOG.info('transpose is now %s', getattr(node, transpose_name))
Example #12
    def _common(cls, node: TFLiteNode, **kwargs):
        node_opts = node.get_options(PackOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        inp_shapes = [input[2].shape for input in inputs]

        values_count = node_opts.ValuesCount()
        check(
            len(inputs) == values_count,
            "invalid tflite file - values count not equal to inputs")

        buffer_idxes = [tensor.buffer_idx for tensor in node.input]
        check(
            len(set(buffer_idxes)) == len(buffer_idxes),
            "packs with multiple versions of the same input are not supported. This is normally a graph design problem."
        )

        axis = node_opts.Axis()
        dimension_size = len(inp_shapes)
        if axis < 0:
            axis += dimension_size

        check(all(shape == inp_shapes[0] for shape in inp_shapes[1::]),
              "invalid tflite file - pack inputs not the same")

        # prepare shapes of all tensors
        pconcat_out_shape = inp_shapes[0].copy()
        pconcat_out_shape.insert(axis, values_count)

        pconcat_in_shape = inp_shapes[0].copy()
        pconcat_in_shape.insert(axis, 1)

        preshape_in_shape = inp_shapes[0].copy()

        # remove nones from constants
        cls.remove_none_from_constants(inputs, preshape_in_shape)

        # remove nones from reshape shapes
        reshape_in_shape = cls.remove_unspecified_dim(preshape_in_shape)
        concat_in_shape = cls.remove_unspecified_dim(pconcat_in_shape)

        if all(cls.is_constant(inp) for inp in inputs):
            LOG.info("reducing %s to a constant", node.name)
            value = np.stack([cls.get_constant(inp) for inp in inputs],
                             axis=axis)
            params = ConstantInputParameters(node.name,
                                             value=value,
                                             constant_store=G.constant_store)
        else:
            axis -= sum(1 if dim is None else 0
                        for dim in pconcat_out_shape[:axis:])
            params = ConcatParameters(node.name, axis=axis, axis_hint=None)

            # insert reshapes on each input to add concat axis
            for idx, inp in enumerate(inputs):
                rparams = ReshapeParameters(node.name + "_%s" % idx,
                                            old_shape=reshape_in_shape,
                                            shape=concat_in_shape)
                G.add_edge(
                    NNEdge(from_node=inp[0], to_node=rparams, from_idx=inp[1]))
                G.add_edge(NNEdge(from_node=rparams, to_node=params))
                if opts.get('load_quantization'):
                    G.quantization[NodeId(rparams)] = cls.load_tf_quantization(
                        [node.input[idx]], [node.input[idx]])

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                node.input, node.output)

        all_nodes[node.output[0]] = (params, 0,
                                     ProvisionalDim(pconcat_out_shape))
        return params
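A NumPy sanity check (illustration only, not part of the handler) of the Pack lowering above: stacking equals concatenating inputs that were each reshaped to gain a unit axis at the pack axis.

import numpy as np

a, b = np.ones((2, 3)), np.zeros((2, 3))
axis = 1
stacked = np.stack([a, b], axis=axis)                        # shape (2, 2, 3)
via_concat = np.concatenate(
    [t.reshape(t.shape[:axis] + (1,) + t.shape[axis:]) for t in (a, b)],
    axis=axis)
print(np.array_equal(stacked, via_concat))  # True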
Example #13
 def _execute(self, node, G):
     info(f"{self}")
     params = ReshapeParameters(G.unique_name(f'{node.name}_reshape'),
                                old_shape=self.in_shape,
                                shape=self.out_shape)
     self.do_insert(node, G, params)
Example #14
    def _common(cls, node, **kwargs):
        node_opts = node.get_options(FullyConnectedOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]

        x = inputs[0]
        x_shape = x[2].shape
        x_known_shape = x[2].known_shape
        inp_sz = np.prod(np.array(x_known_shape))
        weights = inputs[1]
        weights_node = weights[0]
        weights_shape = weights[2].shape
        assert len(weights_shape) == 2, \
            f'bad filter shape {weights_shape} in {node.name}'
        out_c = weights_shape[0]
        batch_size = inp_sz // weights_shape[1]
        if batch_size > 1:
            filt_dim = FcFilterDim(weights_shape[0], weights_shape[1])
        else:
            filt_dim = FcFilterDim(weights_shape[0], *x_known_shape)

        node.input[1].used = True
        check(filt_dim.sz * batch_size == inp_sz,
              "filter doesn't match input size")

        if len(inputs) > 2:
            bias = inputs[2]
            bias_node = bias[0]
        else:
            bias_node = ConstantInputParameters(
                f'{node.name}_bias',
                dims=Dim.unnamed([out_c]),
                value=np.zeros([out_c], dtype=np.float32))  # TODO - check

        keep_dims = node_opts.KeepNumDims()

        if batch_size > 1:
            if keep_dims:
                raise ValueError(
                    f'keep dims on Fully Connected {node.name} with batch size > 1 is not supported'
                )

            # add a reshape to force the size of the input to batch * in_c
            input_shape = (batch_size, weights_shape[1])
            if x_known_shape != input_shape:
                rparams = ReshapeParameters(
                    G.unique_name(f'{node.name}_batch'),
                    old_shape=Dim.unnamed(x_known_shape),
                    shape=Dim.unnamed(input_shape))
                G.add_edge(
                    NNEdge(from_node=x[0],
                           to_node=rparams,
                           from_idx=x[1],
                           to_idx=0))
                link = (rparams, 0)
            else:
                link = x

            # the batched linear is transpose(weights . transpose(input))
            params = MatMulOpParameters(node.name)
            params.transpose_in = [None, (1, 0), None]
            params.transpose_out = [(1, 0)]
            cls.new_load_filter_parameters(G, params, weights_shape, 0,
                                           node.input[0], weights_node,
                                           bias_node, node.output[0], opts)
            G.add_edge(
                NNEdge(from_node=link[0],
                       to_node=params,
                       from_idx=link[1],
                       to_idx=1))
            G.add_edge(NNEdge(from_node=weights_node, to_node=params,
                              to_idx=0))
            G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))
            out_shape = [batch_size, out_c]
        else:
            # in_hint = [[str(i) for i in range(len(x_known_shape) - 1)] + ['c'],
            #            ['out_c', 'in_c'], ['out_c']]
            in_hint = [None, ['out_c', 'in_c'], ['out_c']]
            out_hint = in_hint.copy() if keep_dims else ['c']
            ker_in_order = None
            ker_out_order = None
            link = (x[0], x[1])

            params = FcParameters(node.name,
                                  filt=filt_dim,
                                  has_bias=True,
                                  in_dims_hint=in_hint,
                                  out_dims_hint=[out_hint],
                                  ker_in_order=ker_in_order,
                                  ker_out_order=ker_out_order,
                                  batch_size=batch_size,
                                  constant_store=G.constant_store,
                                  keep_dims=keep_dims)
            cls.new_load_filter_parameters(
                G, params, params.filter.actual_shape,
                params.filter.get_order_idx('out_c'), node.input[0],
                weights_node, bias_node, node.output[0], opts)

            G.add_edge(NNEdge(from_node=weights_node, to_node=params,
                              to_idx=1))
            G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))
            G.add_edge(
                NNEdge(from_node=link[0],
                       to_node=params,
                       from_idx=link[1],
                       to_idx=0))
            # handle keep_dims
            if x_shape[0] is None:
                if keep_dims:
                    out_shape = x_shape[:-1:] + [out_c]
                else:
                    out_shape = [None, out_c]
            else:
                if keep_dims:
                    out_shape = [None] + x_shape[1:-1:] + [out_c]
                else:
                    out_shape = [None, out_c]

        pout_dims = ProvisionalDim(out_shape)

        aparams = cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (aparams, 0, pout_dims)
        return params
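The batched branch relies on the identity noted in the comment above; a small NumPy check (illustration only, hypothetical sizes):

import numpy as np

batch, in_c, out_c = 3, 4, 5
x = np.random.rand(batch, in_c)          # input reshaped to (batch, in_c)
w = np.random.rand(out_c, in_c)          # FC weights (out_c, in_c)
b = np.random.rand(out_c)
fc_out = x @ w.T + b                     # fully connected applied per batch row
matmul_form = (w @ x.T).T + b            # transpose(weights . transpose(input))
print(np.allclose(fc_out, matmul_form))  # True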
Example #15
 def _execute(self, node, G):
     LOGL("%s", str(self))
     params = ReshapeParameters(G.unique_name(f'{node.name}'),
                                old_shape=self.in_shape,
                                shape=self.out_shape)
     G.insert_node_at_edge(params, self.edge, edge_class=NNEdge)
Example #16
    def conv(cls, node, **kwargs):
        all_nodes = kwargs['all_nodes']
        G = kwargs['G']
        valid_name = kwargs['valid_name']
        inputs = [all_nodes[inp] for inp in node.input]
        # input N x C x H x W
        x = inputs[0]
        x_rank = len(x[2].shape)
        x_shape = x[2].shape
        spatial_size = x_rank - 2
        assert spatial_size == 2 or spatial_size == 1, "only 1D and 2D convolutions supported"

        # M x C/group x kH x kW
        weights_node = inputs[1][0]
        weights_node.name = f'{valid_name}_weights'
        weights = cls.get_constant(inputs[1])
        out_c = weights.shape[0]
        group = node.attrs.get("group", 1)
        in_c = x_shape[1]
        filt_in_c = in_c // group
        if in_c != weights.shape[1] * group:
            raise ValueError(
                f'node {valid_name} has incorrect input channel '
                f'dimension {in_c} expecting {weights.shape[1] * group}')
        if spatial_size == 1:
            filt_w = weights.shape[-1]
            filt_h = 1
            # create a new constant node since we are changing the shape
            weights = np.reshape(weights, (out_c, filt_in_c, filt_h, filt_w))
            weights_node = ConstantInputParameters(
                f'{valid_name}_weights',
                value=weights,
                dims=Dim.unnamed(weights.shape),
                constant_store=G.constant_store)
        else:
            filt_h = weights.shape[-2]
            filt_w = weights.shape[-1]
        h = 1 if spatial_size == 1 else x_shape[-2]
        w = x_shape[-1]

        filt_dim = Conv2DFilterDim(filt_h, filt_w, out_c, in_c=filt_in_c)
        filt_dim = filt_dim.impose_order(cls.ONNX_FILTER_ORDER)

        if len(inputs) > 2:
            biases_node = inputs[2][0]
            biases = cls.get_constant(inputs[2])
        else:
            biases = np.zeros([out_c], dtype=np.float32)
            biases_node = ConstantInputParameters(
                f'{valid_name}_biases',
                value=biases,
                dims=Dim.unnamed(biases.shape),
                constant_store=G.constant_store)

        dilations = cls.pad_start_with(node.attrs.get("dilations", []), [1], 2)
        strides = cls.pad_start_with(node.attrs.get("strides", []), [1], 2)
        pad_dim = cls.calc_pad_dim(node, 4)

        params = Conv2DParameters(valid_name,
                                  filt=filt_dim,
                                  stride=StrideDim(strides[0], strides[1]),
                                  dilation=DilationDim(dilations[0],
                                                       dilations[1]),
                                  groups=group,
                                  padding=pad_dim,
                                  has_bias=True,
                                  in_dims_hint=[['c', 'h', 'w'],
                                                cls.ONNX_FILTER_ORDER, ['c']],
                                  out_dims_hint=[['c', 'h', 'w']],
                                  constant_store=G.constant_store)

        in_dim = Dim.named_ordered(c=in_c, h=h, w=w)
        w_dim = Dim.named_ordered(out_c=out_c,
                                  in_c=filt_in_c,
                                  h=filt_h,
                                  w=filt_w)
        b_dim = Dim.named_ordered(c=out_c)
        out_dims = params.get_output_size([in_dim, w_dim, b_dim])
        G.add_edge(
            NNEdge(from_node=weights_node,
                   to_node=params,
                   from_idx=0,
                   to_idx=1))
        G.add_edge(
            NNEdge(from_node=biases_node, to_node=params, from_idx=0,
                   to_idx=2))
        if spatial_size == 1:
            oned_in_shape = [in_c, w]
            twod_in_shape = [in_c, 1, w]
            oned_out_shape = [out_dims[0].c, out_dims[0].w]
            r1_params = ReshapeParameters(f'{valid_name}_reshape2d',
                                          old_shape=Dim.unnamed(oned_in_shape),
                                          shape=Dim.unnamed(twod_in_shape))
            r2_params = ReshapeParameters(f'{valid_name}_reshape1d',
                                          old_shape=out_dims[0],
                                          shape=Dim.unnamed(oned_out_shape))
            G.add_edge(
                NNEdge(from_node=x[0],
                       to_node=r1_params,
                       from_idx=x[1],
                       to_idx=0))
            G.add_edge(
                NNEdge(from_node=r1_params,
                       to_node=params,
                       from_idx=0,
                       to_idx=0))
            G.add_edge(
                NNEdge(from_node=params,
                       to_node=r2_params,
                       from_idx=0,
                       to_idx=0))
            pout_dims = ProvisionalDim([x_shape[0]] + oned_out_shape)
            all_nodes[node.output[0]] = (r2_params, 0, pout_dims)
            return r2_params
        else:
            pout_dims = ProvisionalDim([x_shape[0]] + out_dims[0].shape)
            G.add_edge(
                NNEdge(from_node=x[0], to_node=params, from_idx=x[1],
                       to_idx=0))
            all_nodes[node.output[0]] = (params, 0, pout_dims)
            return params
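For the 1D case, the wrapping above amounts to simple shape bookkeeping (hypothetical sizes, illustration only): a unit H axis is added before the Conv2D and removed again afterwards.

in_c, w, out_c, out_w = 8, 100, 16, 98
oned_in_shape = [in_c, w]              # original 1D input (c, w)
twod_in_shape = [in_c, 1, w]           # reshape2d output fed to the Conv2D
oned_out_shape = [out_c, out_w]        # reshape1d output after the Conv2D
print(oned_in_shape, '->', twod_in_shape, '-> Conv2D ->', oned_out_shape)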