Example #1
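Converts a TFLite FULLY_CONNECTED operator into an FcParameters node. In this variant the filter and bias constants are loaded directly through load_filter_parameters (or load_dequantized_filter_parameters when the load_dequantized option is set) rather than being wired into the graph as constant-input edges.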
    def _common(cls, node, **kwargs):
        node_opts = node.get_options(FullyConnectedOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]

        x = inputs[0]
        x_shape = x[2].shape
        x_known_shape = x[2].known_shape
        inp_sz = np.prod(np.array(x_known_shape))
        weights = inputs[1]
        weights_shape = weights[2].shape
        out_c = weights_shape[0]

        filt_dim = FcFilterDim(weights_shape[0], *x_known_shape)
        node.input[1].used = True
        check(filt_dim.sz == inp_sz, "filter doesn't match input size")

        if len(node.input) > 2:
            node.input[2].used = True

        keep_dims = node_opts.KeepNumDims()

        in_hint = [str(i) for i in range(len(x_known_shape) - 1)] + ['c']
        out_hint = in_hint.copy() if keep_dims else ['c']

        params = FcParameters(node.name,
                              filt=filt_dim,
                              has_bias=True,
                              in_dims_hint=SparseList([in_hint]),
                              out_dims_hint=SparseList([out_hint]),
                              constant_store=G.constant_store,
                              keep_dims=keep_dims)

        if opts.get('load_dequantized'):
            cls.load_dequantized_filter_parameters(params, node.input)
        else:
            cls.load_filter_parameters(G, params, node.input, node.output,
                                       opts)

        if x_shape[0] is None:
            out_shape = x_shape[:-1] + [out_c] if keep_dims else [x_shape[0], out_c]
        else:
            out_shape = x_known_shape[:-1] + [out_c] if keep_dims else [out_c]
        pout_dims = ProvisionalDim(out_shape)

        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
        aparams = cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (aparams, 0, pout_dims)
        return params
Example #2
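Converts a TFLite DEPTHWISE_CONV_2D operator. Weights and biases arrive as graph edges from constant nodes, a missing bias is synthesised as zeros, and a depthwise convolution over a single input channel with a single group is rewritten as an ordinary convolution.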
    def _common(cls, node: TFLiteNode, **kwargs):
        node_opts = node.get_options(DepthwiseConv2DOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]

        x = inputs[0]
        x = cls.remove_known_batch_dimension(G, x, node)
        x_shape = x[2].shape
        in_b, h, w, in_c = tuple(x_shape)

        filt = inputs[1]
        weights_node = filt[0]
        filt_shape = filt[2].shape
        # ['in_c', 'h', 'w', 'out_c']
        filt_in_c, filt_h, filt_w, filt_out_c = tuple(filt_shape)

        # get filter dimensions
        if filt_h > h or filt_w > w:
            LOG.warning(
                "Filter %s of shape [%dx%d] is bigger than input of shape [%dx%d]",
                node.name, filt_h, filt_w, h, w)

        filt_dim = Conv2DFilterDim(filt_h, filt_w, filt_out_c, in_c=filt_in_c)
        filt_dim = filt_dim.impose_order(cls.TF_LITE_DW_FILTER_ORDER)

        # multiplier should match filter
        check(filt_dim.out_c == node_opts.DepthMultiplier() * in_c,
              "invalid multiplier")

        groups = filt_dim.out_c // node_opts.DepthMultiplier()

        # compute padding
        pad = cls.get_tf_padding(node_opts.Padding())

        # does it have biases
        if len(inputs) > 2:
            bias = inputs[2]
            bias_node = bias[0]
        else:
            bias_node = ConstantInputParameters(
                f'{node.name}_bias',
                dims=Dim.unnamed([filt_out_c]),
                value=np.zeros([filt_out_c], dtype=np.float32))  # TODO - check

        # TFLITE produces single channel input DW convolutions with the
        # multiplier equal to the number of out channels. This is just
        # a normal convolution and since we don't handle the channel
        # multiplier at present (but can) just convert them to normal
        # convolutions
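        # worked example (assumed values): in_c == 1 with DepthMultiplier() == 8
        # gives filt_out_c == 8 and groups == 8 // 8 == 1, i.e. a plain convolution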
        convert_to_conv = in_c == 1 and groups == 1

        if convert_to_conv:
            filt_dim.impose_order(cls.TF_LITE_FILTER_ORDER)
            params = Conv2DParameters(
                node.name,
                filt=filt_dim,
                stride=StrideDim(node_opts.StrideH(), node_opts.StrideW()),
                dilation=DilationDim(node_opts.DilationHFactor(),
                                     node_opts.DilationWFactor()),
                padding=pad,
                has_bias=True,
                in_dims_hint=[['h', 'w', 'c'],
                              cls.TF_LITE_FILTER_ORDER.copy(), ['out_c']],
                out_dims_hint=[['h', 'w', 'c']])
        else:
            filt_dim.impose_order(cls.TF_LITE_DW_FILTER_ORDER)
            params = Conv2DParameters(
                node.name,
                filt=filt_dim,
                stride=StrideDim(node_opts.StrideH(), node_opts.StrideW()),
                dilation=DilationDim(node_opts.DilationHFactor(),
                                     node_opts.DilationWFactor()),
                padding=pad,
                groups=groups,
                multiplier=node_opts.DepthMultiplier(),
                has_bias=True,
                tf_depthwise=True,
                in_dims_hint=[['h', 'w', 'c'],
                              cls.TF_LITE_DW_FILTER_ORDER.copy(), ['out_c']],
                out_dims_hint=[['h', 'w', 'c']])

        G.add_edge(NNEdge(from_node=weights_node, to_node=params, to_idx=1))
        G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))
        cls.new_load_filter_parameters(G,
                                       params,
                                       params.filter.actual_shape,
                                       params.filter.get_order_idx('out_c'),
                                       node.input[0],
                                       weights_node,
                                       bias_node,
                                       node.output[0],
                                       opts,
                                       dw_to_pw=convert_to_conv)

        in_dim = Dim.named_ordered(h=h, w=w, c=in_c)
        out_dims = params.get_output_size(
            [in_dim,
             Dim.unnamed(filt_dim.shape),
             Dim.unnamed([filt_out_c])])
        pout_dims = ProvisionalDim([in_b] + out_dims[0].shape)
        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
        params = cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (params, 0, pout_dims)
        return params
Example #3
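A FULLY_CONNECTED handler that, unlike Example #1, wires the weights and bias in as ConstantInputParameters edges and loads quantization through new_load_filter_parameters; the commented-out block below is the superseded loading path.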
    def _common(cls, node, **kwargs):
        node_opts = node.get_options(FullyConnectedOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]

        x = inputs[0]
        x_shape = x[2].shape
        x_known_shape = x[2].known_shape
        inp_sz = np.prod(np.array(x_known_shape))
        weights = inputs[1]
        weights_node = weights[0]
        weights_shape = weights[2].shape
        out_c = weights_shape[0]

        filt_dim = FcFilterDim(weights_shape[0], *x_known_shape)
        node.input[1].used = True
        check(filt_dim.sz == inp_sz, "filter doesn't match input size")

        if len(inputs) > 2:
            bias = inputs[2]
            bias_node = bias[0]
        else:
            bias_node = ConstantInputParameters(
                f'{node.name}_bias',
                dims=Dim.unnamed([out_c]),
                value=np.zeros([out_c], dtype=np.float32))  # TODO - check

        keep_dims = node_opts.KeepNumDims()

        in_hint = [str(i) for i in range(len(x_known_shape) - 1)] + ['c']
        out_hint = in_hint.copy() if keep_dims else ['c']

        params = FcParameters(node.name,
                              filt=filt_dim,
                              has_bias=True,
                              in_dims_hint=SparseList(
                                  [in_hint, ['out_c', 'in_c'], ['out_c']]),
                              out_dims_hint=SparseList([out_hint]),
                              constant_store=G.constant_store,
                              keep_dims=keep_dims)

        G.add_edge(NNEdge(from_node=weights_node, to_node=params, to_idx=1))
        G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))

        cls.new_load_filter_parameters(G, params, node.input[0], weights_node,
                                       bias_node, node.output[0], opts)

        # if opts.get('load_dequantized'):
        #     weights_node.value, bias_node.value = cls.load_dequantized_filter_parameters(
        #         node.input, bias_node.value)
        # else:
        #     qrec, weights_node.value, bias_node.value = cls.load_filter_parameters(
        #         G, params, node.input, bias_node.value, node.output, opts)
        #     if qrec:
        #         G.quantization[NodeId(weights_node)].out_qs[0] = qrec.in_qs[1]
        #         G.quantization[NodeId(bias_node)].out_qs[0] = qrec.in_qs[2]

        if x_shape[0] is None:
            out_shape = x_shape[:-1] + [out_c] if keep_dims else [x_shape[0], out_c]
        else:
            out_shape = x_known_shape[:-1] + [out_c] if keep_dims else [out_c]
        pout_dims = ProvisionalDim(out_shape)

        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
        aparams = cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (aparams, 0, pout_dims)
        return params
Example #4
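Lowers a TFLite PACK operator. If every input is constant, the pack is folded into a single ConstantInputParameters node with np.stack; otherwise each input gets a Reshape that inserts the pack axis, followed by a Concat along that axis.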
    def _common(cls, node: TFLiteNode, **kwargs):
        node_opts = node.get_options(PackOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        inp_shapes = [inp[2].shape for inp in inputs]

        values_count = node_opts.ValuesCount()
        check(
            len(inputs) == values_count,
            "invalid tflite file - values count not equal to inputs")

        buffer_idxes = [tensor.buffer_idx for tensor in node.input]
        check(
            len(set(buffer_idxes)) == len(buffer_idxes),
            "packs with multiple versions of the same input are not supported. This is normally a graph design problem."
        )

        axis = node_opts.Axis()
        dimension_size = len(inp_shapes)
        if axis < 0:
            axis += dimension_size

        check(all(shape == inp_shapes[0] for shape in inp_shapes[1:]),
              "invalid tflite file - pack inputs not the same")

        # prepare shapes of all tensors
        pconcat_out_shape = inp_shapes[0].copy()
        pconcat_out_shape.insert(axis, values_count)

        pconcat_in_shape = inp_shapes[0].copy()
        pconcat_in_shape.insert(axis, 1)

        preshape_in_shape = inp_shapes[0].copy()

        # remove nones from constants
        cls.remove_none_from_constants(inputs, preshape_in_shape)

        # remove nones from reshape shapes
        reshape_in_shape = cls.remove_unspecified_dim(preshape_in_shape)
        concat_in_shape = cls.remove_unspecified_dim(pconcat_in_shape)

        if all(cls.is_constant(inp) for inp in inputs):
            LOG.info("reducing %s to a constant", node.name)
            value = np.stack([cls.get_constant(inp) for inp in inputs],
                             axis=axis)
            params = ConstantInputParameters(node.name,
                                             value=value,
                                             constant_store=G.constant_store)
        else:
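            # the axis indexes the provisional shape, which still contains
            # unknown (None) dims; shift it left by the Nones that precede it
            # since those are dropped from the real concat shape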
            axis -= sum(1 for dim in pconcat_out_shape[:axis] if dim is None)
            params = ConcatParameters(node.name, axis=axis, axis_hint=None)

            # insert reshapes on each input to add concat axis
            for idx, inp in enumerate(inputs):
                rparams = ReshapeParameters(node.name + "_%s" % idx,
                                            old_shape=reshape_in_shape,
                                            shape=concat_in_shape)
                G.add_edge(
                    NNEdge(from_node=inp[0], to_node=rparams, from_idx=inp[1]))
                G.add_edge(NNEdge(from_node=rparams, to_node=params))
                if opts.get('load_quantization'):
                    G.quantization[NodeId(rparams)] = cls.load_tf_quantization(
                        [node.input[idx]], [node.input[idx]])

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                node.input, node.output)

        all_nodes[node.output[0]] = (params, 0,
                                     ProvisionalDim(pconcat_out_shape))
        return params
Example #5
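A FULLY_CONNECTED handler with batch support: when the flattened input holds more than one row of the weight matrix, the layer is lowered to a Reshape plus a MatMul with input/output transposes instead of an FcParameters node.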
    def _common(cls, node, **kwargs):
        node_opts = node.get_options(FullyConnectedOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]

        x = inputs[0]
        x_shape = x[2].shape
        x_known_shape = x[2].known_shape
        inp_sz = np.prod(np.array(x_known_shape))
        weights = inputs[1]
        weights_node = weights[0]
        weights_shape = weights[2].shape
        assert len(weights_shape) == 2, \
            f'bad filter shape {weights_shape} in {node.name}'
        out_c = weights_shape[0]
        batch_size = inp_sz // weights_shape[1]
        if batch_size > 1:
            filt_dim = FcFilterDim(weights_shape[0], weights_shape[1])
        else:
            filt_dim = FcFilterDim(weights_shape[0], *x_known_shape)

        node.input[1].used = True
        check(filt_dim.sz * batch_size == inp_sz,
              "filter doesn't match input size")

        if len(inputs) > 2:
            bias = inputs[2]
            bias_node = bias[0]
        else:
            bias_node = ConstantInputParameters(
                f'{node.name}_bias',
                dims=Dim.unnamed([out_c]),
                value=np.zeros([out_c], dtype=np.float32))  # TODO - check

        keep_dims = node_opts.KeepNumDims()

        if batch_size > 1:
            if keep_dims:
                raise ValueError(
                    f'keep dims on Fully Connected {node.name} with batch size > 1 is not supported'
                )

            # add a reshape to force the size of the input to batch * in_c
            input_shape = (batch_size, weights_shape[1])
            if x_known_shape != input_shape:
                rparams = ReshapeParameters(
                    G.unique_name(f'{node.name}_batch'),
                    old_shape=Dim.unnamed(x_known_shape),
                    shape=Dim.unnamed(input_shape))
                G.add_edge(
                    NNEdge(from_node=x[0],
                           to_node=rparams,
                           from_idx=x[1],
                           to_idx=0))
                link = (rparams, 0)
            else:
                link = x

            # the batched linear is transpose(weights . transpose(input))
            params = MatMulOpParameters(node.name)
            params.transpose_in = [None, (1, 0), None]
            params.transpose_out = [(1, 0)]
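            # shape trace (assumed semantics of transpose_in/transpose_out):
            # x [batch, in_c] feeds input 1 and is transposed to [in_c, batch];
            # weights [out_c, in_c] . that gives [out_c, batch], and
            # transpose_out then yields [batch, out_c]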
            cls.new_load_filter_parameters(G, params, weights_shape, 0,
                                           node.input[0], weights_node,
                                           bias_node, node.output[0], opts)
            G.add_edge(
                NNEdge(from_node=link[0],
                       to_node=params,
                       from_idx=link[1],
                       to_idx=1))
            G.add_edge(NNEdge(from_node=weights_node, to_node=params,
                              to_idx=0))
            G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))
            out_shape = [batch_size, out_c]
        else:
            # in_hint = [[str(i) for i in range(len(x_known_shape) - 1)] + ['c'],
            #            ['out_c', 'in_c'], ['out_c']]
            in_hint = [None, ['out_c', 'in_c'], ['out_c']]
            out_hint = in_hint.copy() if keep_dims else ['c']
            ker_in_order = None
            ker_out_order = None
            link = (x[0], x[1])

            params = FcParameters(node.name,
                                  filt=filt_dim,
                                  has_bias=True,
                                  in_dims_hint=in_hint,
                                  out_dims_hint=[out_hint],
                                  ker_in_order=ker_in_order,
                                  ker_out_order=ker_out_order,
                                  batch_size=batch_size,
                                  constant_store=G.constant_store,
                                  keep_dims=keep_dims)
            cls.new_load_filter_parameters(
                G, params, params.filter.actual_shape,
                params.filter.get_order_idx('out_c'), node.input[0],
                weights_node, bias_node, node.output[0], opts)

            G.add_edge(NNEdge(from_node=weights_node, to_node=params,
                              to_idx=1))
            G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))
            G.add_edge(
                NNEdge(from_node=link[0],
                       to_node=params,
                       from_idx=link[1],
                       to_idx=0))
            # handle keep_dims
            if keep_dims:
                if x_shape[0] is None:
                    out_shape = x_shape[:-1] + [out_c]
                else:
                    out_shape = [None] + x_shape[1:-1] + [out_c]
            else:
                out_shape = [None, out_c]

        pout_dims = ProvisionalDim(out_shape)

        aparams = cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (aparams, 0, pout_dims)
        return params
Example #6
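An earlier DEPTHWISE_CONV_2D variant (compare Example #2): dimension hints are SparseLists, and the filter constants are loaded through load_filter_parameters / load_dequantized_filter_parameters rather than wired in as constant-input edges.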
    def _common(cls, node: TFLiteNode, **kwargs):
        node_opts = node.get_options(DepthwiseConv2DOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]

        x = inputs[0]
        x_shape = x[2].shape
        in_b, h, w, in_c = tuple(x_shape)

        filt = inputs[1]
        filt_tensor = node.input[1]
        filt_shape = filt[2].shape
        # ['in_c', 'h', 'w', 'out_c']
        filt_in_c, filt_h, filt_w, filt_out_c = tuple(filt_shape)

        # get filter dimensions
        filt_tensor.used = True
        if filt_h > h or filt_w > w:
            LOG.warning(
                "Filter %s of shape [%dx%d] is bigger than input of shape [%dx%d]",
                node.name, filt_h, filt_w, h, w)

        filt_dim = Conv2DFilterDim(filt_h, filt_w, filt_out_c, in_c=filt_in_c)
        filt_dim = filt_dim.impose_order(cls.TF_LITE_DW_FILTER_ORDER)

        # multiplier should match filter
        check(filt_dim.out_c == node_opts.DepthMultiplier() * in_c,
              "invalid multiplier")

        groups = filt_dim.out_c // node_opts.DepthMultiplier()

        # compute padding
        pad = cls.get_tf_padding(node_opts.Padding())

        # does it have biases
        has_bias = len(inputs) > 2
        if has_bias:
            node.input[2].used = True

        # TFLITE produces single channel input DW convolutions with the
        # multiplier equal to the number of out channels. This is just
        # a normal convolution and since we don't handle the channel
        # multiplier at present (but can) just convert them to normal
        # convolutions
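        # worked example (assumed values): in_c == 1 with DepthMultiplier() == 8
        # gives filt_out_c == 8 and groups == 8 // 8 == 1, i.e. a plain convolution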
        convert_to_conv = in_c == 1 and groups == 1

        if convert_to_conv:
            filt_dim.impose_order(cls.TF_LITE_FILTER_ORDER)
            params = Conv2DParameters(
                node.name,
                filt=filt_dim,
                stride=StrideDim(node_opts.StrideH(), node_opts.StrideW()),
                dilation=DilationDim(node_opts.DilationHFactor(),
                                     node_opts.DilationWFactor()),
                padding=pad,
                has_bias=has_bias,
                in_dims_hint=SparseList([['h', 'w', 'c']]),
                out_dims_hint=SparseList([['h', 'w', 'c']]),
                constant_store=G.constant_store)
        else:
            filt_dim.impose_order(cls.TF_LITE_DW_FILTER_ORDER)
            params = Conv2DParameters(
                node.name,
                filt=filt_dim,
                stride=StrideDim(node_opts.StrideH(), node_opts.StrideW()),
                dilation=DilationDim(node_opts.DilationHFactor(),
                                     node_opts.DilationWFactor()),
                padding=pad,
                groups=groups,
                multiplier=node_opts.DepthMultiplier(),
                has_bias=has_bias,
                tf_depthwise=True,
                in_dims_hint=SparseList([['h', 'w', 'c']]),
                out_dims_hint=SparseList([['h', 'w', 'c']]),
                constant_store=G.constant_store)

        if opts.get('load_dequantized'):
            cls.load_dequantized_filter_parameters(params,
                                                   node.input,
                                                   convert_to_conv,
                                                   is_dw=True)
        else:
            cls.load_filter_parameters(G,
                                       params,
                                       node.input,
                                       node.output,
                                       opts,
                                       converted_to_conv=convert_to_conv)

        in_dim = Dim.named_ordered(h=h, w=w, c=in_c)
        out_dims = params.get_output_size([in_dim])
        pout_dims = ProvisionalDim([in_b] + out_dims[0].shape)
        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
        params = cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (params, 0, pout_dims)
        return params
Example #7
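A FULLY_CONNECTED variant that supports KeepNumDims: the full output shape is computed up front and, when it differs from the node's natural output shape, a trailing Reshape is appended after the (optionally fused) activation.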
    def _common(cls, node, **kwargs):
        node_opts = node.get_options(FullyConnectedOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        keep_dims = node_opts.KeepNumDims()
        # check(not keep_dims,
        #       f'keep dims on Fully Connected {node.name} is not supported')

        inputs = [all_nodes[t] if t is not None else None for t in node.input]

        x = inputs[0]
        x_shape = x[2]
        x_known_shape = x_shape.known_shape
        inp_sz = np.prod(np.array(x_known_shape))
        weights = inputs[1]
        weights_node = weights[0]
        weights_shape = weights[2].shape
        check(
            len(weights_shape) == 2,
            f'bad filter shape {weights_shape} in {node.name}')
        out_c = weights_shape[0]
        batch_size = inp_sz // weights_shape[1]

        if keep_dims:
            if x_shape.shape[-1] != weights_shape[1]:
                raise ValueError(
                    f'Keep dims set on {node.name} but last input dimension does not match weights'
                )
            out_shape = x_shape.shape.copy()
            out_shape[-1] = out_c
        elif batch_size > 1:
            out_shape = (batch_size, out_c)
        else:
            out_shape = (None, out_c)
        real_out_shape = tuple(dim for dim in out_shape if dim is not None)

        filt_dim = FcFilterDim(weights_shape[0], weights_shape[1])

        node.input[1].used = True
        check(filt_dim.sz * batch_size == inp_sz,
              "filter doesn't match input size")

        if len(inputs) > 2 and inputs[2] is not None:
            bias = inputs[2]
            bias_node = bias[0]
        else:
            bias_node = ConstantInputParameters(f'{node.name}_bias',
                                                dims=Dim.unnamed([out_c]),
                                                value=np.zeros(
                                                    [out_c], dtype=np.float32))

        if batch_size > 1:
            # add a reshape to force the size of the input to batch * in_c
            input_shape = (batch_size, weights_shape[1])
            if x_known_shape != input_shape:
                rparams = ReshapeParameters(
                    G.unique_name(f'{node.name}_batch'),
                    old_shape=Dim.unnamed(x_known_shape),
                    shape=Dim.unnamed(input_shape))
                G.add_edge(
                    NNEdge(from_node=x[0],
                           to_node=rparams,
                           from_idx=x[1],
                           to_idx=0))
                link = (rparams, 0)
            else:
                link = x

            # the batched linear is ([NxM] . [MxK]) + [K]
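            # here N = batch_size, M = in_c (weights_shape[1]) and K = out_c;
            # the TFLite weights arrive as [K, M], hence the transposed variant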
            params = MatMulTransposedParameters(node.name)
            cls.new_load_filter_parameters(G, params, weights_shape, 0,
                                           node.input[0], weights_node,
                                           bias_node, node.output[0], opts)
            trans2 = TransposeParameters(G.unique_name(f'{node.name}_tin2'),
                                         transpose=(1, 0))
            G.add_edge(
                NNEdge(from_node=link[0], to_node=params, from_idx=link[1]))
            G.add_edge(NNEdge(from_node=weights_node, to_node=params,
                              to_idx=1))
            #G.add_edge(NNEdge(from_node=trans2, to_node=params, to_idx=1))
            G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))
            fc_shape = (batch_size, out_c)
        else:
            ker_in_order = None
            ker_out_order = None
            link = (x[0], x[1])

            params = FcParameters(node.name,
                                  filt=filt_dim,
                                  has_bias=True,
                                  ker_in_order=ker_in_order,
                                  ker_out_order=ker_out_order,
                                  batch_size=batch_size,
                                  keep_dims=keep_dims)
            cls.new_load_filter_parameters(
                G, params, params.filter.actual_shape,
                params.filter.get_order_idx('out_c'), node.input[0],
                weights_node, bias_node, node.output[0], opts)

            G.add_edge(NNEdge(from_node=weights_node, to_node=params,
                              to_idx=1))
            G.add_edge(NNEdge(from_node=bias_node, to_node=params, to_idx=2))
            G.add_edge(
                NNEdge(from_node=link[0],
                       to_node=params,
                       from_idx=link[1],
                       to_idx=0))
            fc_shape = (out_c, )

        pout_dims = ProvisionalDim(out_shape)
        aparams = cls.fuse_activation(node_opts, node.name, params, **kwargs)

        if real_out_shape != fc_shape:
            rparams = ReshapeParameters(G.unique_name(f'{node.name}_keepdims'),
                                        old_shape=fc_shape,
                                        shape=real_out_shape)
            G.add_edge(NNEdge(from_node=aparams, to_node=rparams))
            aparams = rparams

        all_nodes[node.output[0]] = (aparams, 0, pout_dims)
        return params