Пример #1
0
 def _common(cls, node, **kwargs):
     """Import a two-input node, folding it into a constant when possible.

     Required ``kwargs``: ``all_nodes``, ``G``, ``opts`` and, on the
     non-constant path, ``params_class``.  Optional: ``qrec_class``,
     ``params_args`` and ``constant_operation``.

     Returns the parameters node that was created; it is also recorded in
     ``all_nodes`` under ``node.output[0]``.
     """
     all_nodes = kwargs['all_nodes']
     graph = kwargs['G']
     opts = kwargs['opts']
     qrec_class = kwargs.get('qrec_class')
     params_args = kwargs.get('params_args', {})
     fold_op = kwargs.get('constant_operation')

     operands = [all_nodes[tensor] for tensor in node.input]
     assert len(operands) == 2

     can_fold = all(cls.is_constant(operand) for operand in operands) and fold_op
     if can_fold:
         # Every operand is constant and a folding function was supplied:
         # evaluate the operation now and emit a constant node instead.
         LOG.info("reducing %s to a constant", node.name)
         values = [cls.get_constant(operand) for operand in operands]
         output_shapes = cls.implied_broadcast(operands)
         folded = fold_op(*values)
         dims = Dim.unnamed(output_shapes[0].known_shape)
         params = ConstantInputParameters(node.name,
                                          value=folded,
                                          dims=dims,
                                          constant_store=graph.constant_store)
     else:
         params = kwargs['params_class'](node.name, **params_args)
         output_shapes = cls.implied_broadcast(operands)
         known_shapes = []
         for slot, operand in enumerate(operands):
             edge = NNEdge(from_node=operand[0],
                           to_node=params,
                           from_idx=operand[1],
                           to_idx=slot)
             graph.add_edge(edge)
             known_shapes.append(operand[2].known_shape)
         if isinstance(params, Broadcastable):
             params.set_broadcast(known_shapes)

     if opts.get('load_quantization'):
         graph.quantization[NodeId(params)] = cls.load_tf_quantization(
             node.input, node.output, qrec_class=qrec_class)
     all_nodes[node.output[0]] = (params, 0, output_shapes[0])
     return params
Пример #2
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a cast node, folding it into a constant when its input is one.

        The input and output dtypes are looked up in
        ``TFLiteTensorWrapper.TF_TO_NUMPY_TYPE`` from the node's
        ``CastOptions``; both are ``None`` when the node has no options.

        ``kwargs`` must contain ``G`` and ``all_nodes``.  Returns the created
        parameters node, which is also recorded in ``all_nodes`` under
        ``node.output[0]`` with a deep copy of the input's shape entry.
        """
        cast_opts = node.get_options(CastOptions)
        graph = kwargs['G']
        all_nodes = kwargs['all_nodes']

        source = [all_nodes[t] for t in node.input][0]

        in_dtype = out_dtype = None
        if cast_opts:
            to_numpy = TFLiteTensorWrapper.TF_TO_NUMPY_TYPE
            in_dtype = to_numpy[cast_opts.InDataType()]
            out_dtype = to_numpy[cast_opts.OutDataType()]

        if not cls.is_constant(source):
            params = CastParameters(node.name,
                                    in_dtype=in_dtype,
                                    out_dtype=out_dtype)
            edge = NNEdge(from_node=source[0],
                          to_node=params,
                          from_idx=source[1],
                          to_idx=0)
            graph.add_edge(edge)
        else:
            # Constant input: apply the cast at import time.
            LOG.info("reducing %s to a constant", node.name)
            value = cls.get_constant(source)
            if out_dtype:
                value = value.astype(out_dtype)
            params = ConstantInputParameters(node.name, value=value)

        all_nodes[node.output[0]] = (params, 0, deepcopy(source[2]))
        return params
Пример #3
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a reshape node as a ``ReshapeParameters`` node.

        The target shape is taken from the node's ``ReshapeOptions`` when
        they are present, otherwise from the constant second input.  The
        first ``-1`` entry of the target shape is resolved from the input
        size.  When the input shape has an unspecified (``None``) axis, the
        first axis of size 1 in the target shape is treated as that axis.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``.

        Returns:
            The created ``ReshapeParameters``; it is also recorded in
            ``all_nodes`` under ``node.output[0]``.

        Raises:
            ValueError: if no target shape is available, or if the input has
                an unspecified axis and the target shape has no axis of
                size 1 to map it to.
        """
        node_opts = node.get_options(ReshapeOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        x = inputs[0]
        x_shape = x[2].shape

        # TF2 seems to use the second input whereas TF1 uses the opts
        new_shape = None
        if node_opts:
            new_shape = list(node_opts.NewShapeAsNumpy())
        elif len(inputs) > 1:
            set_shape_tensor = list(cls._verify_constant(inputs[1]))
            node.input[1].used = True
            new_shape = list(set_shape_tensor)
        else:
            # The exception used to be constructed but never raised, so the
            # code fell through and failed later on ``-1 in None``.
            raise ValueError(
                f"Cannot asses new_shape for Reshape Parameter: {node.name}")

        if -1 in new_shape:
            # Product of the target shape with every -1 counted as 1.
            new_shape_size = reduce(
                lambda acc, dim: acc if dim == -1 else acc * dim,
                new_shape, 1)
            # Product of the input shape ignoring unspecified axes.
            inp_size = reduce(
                lambda acc, dim: acc * dim if dim is not None else acc,
                x_shape, 1)
            new_shape[new_shape.index(-1)] = inp_size // new_shape_size

        if None in x_shape:
            if 1 in new_shape:
                old_batch_dim = x_shape.index(None)
                new_batch_dim = new_shape.index(1)
                if old_batch_dim != new_batch_dim:
                    LOG.info(
                        "node %s moved batch dimension for axis %s to axis %s",
                        node.name, old_batch_dim, new_batch_dim)
                # Carry the unspecified axis through to the output shape.
                new_shape[new_batch_dim] = None
            else:
                raise ValueError(
                    "unable to determine movement of unspcified axis in node %s"
                    % node.name)

        # Provisional shape (may contain None) recorded for downstream nodes;
        # the reshape itself is built from shapes with those axes removed.
        pnew_shape = ProvisionalDim(new_shape)
        old_shape = Dim.unnamed(cls.remove_unspecified_dim(x_shape),
                                is_ordered=True)
        new_shape = Dim.unnamed(cls.remove_unspecified_dim(new_shape),
                                is_ordered=True)

        params = ReshapeParameters(node.name,
                                   old_shape=old_shape,
                                   shape=new_shape)

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                [node.input[0]], node.output)
        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
        all_nodes[node.output[0]] = (params, 0, pnew_shape)
        return params
Пример #4
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a transpose node, folding it into a constant when possible.

        The permutation is read from the constant second input.  It is
        expressed over the provisional input shape (which may contain
        ``None`` axes); the permutation handed to ``TransposeParameters`` is
        re-indexed so that it only refers to the specified axes.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``.

        Returns:
            The created ``ConstantInputParameters`` or ``TransposeParameters``;
            it is also recorded in ``all_nodes`` under ``node.output[0]``.
        """
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        x = inputs[0]
        x_shape = x[2].shape

        # Index of each specified axis once the None axes are dropped.
        new_axes = {}
        for idx, dim in enumerate(x_shape):
            if dim is not None:
                new_axes[idx] = len(new_axes)
        ptranspose = cls._verify_constant(inputs[1])
        pout_shape = [x_shape[dim] for dim in ptranspose]
        transpose = [new_axes[axis] for axis in ptranspose if x_shape[axis] is not None]
        node.input[1].used = True

        if cls.is_constant(x):
            LOG.info("reducing %s to a constant", node.name)
            # Permute exactly once. The value used to be transposed a second
            # time while ``dims`` described the once-transposed array.
            val = np.transpose(cls.get_constant(x), ptranspose)
            params = ConstantInputParameters(node.name, value=val,
                                             dims=Dim.unnamed(val.shape), constant_store=G.constant_store)
        else:
            params = TransposeParameters(node.name, transpose=transpose)
            # Only a real transpose consumes its input; the folded constant
            # gets no input edge, matching the other constant-folding handlers.
            G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization([node.input[0]], node.output)
        all_nodes[node.output[0]] = (params, 0, ProvisionalDim(pout_shape))
        return params
Пример #5
0
    def common_quantize(cls, in_qtype, out_qtype, node, **kwargs):
        """Import a (de)quantize node, folding it when its input is constant.

        ``kwargs`` must contain ``all_nodes``, ``opts`` and ``G``.  A constant
        input becomes a ``ConstantInputParameters`` carrying ``out_qtype``;
        otherwise a ``QuantizeParameters`` node is created and connected to
        the input.  Returns the created node, which is also recorded in
        ``all_nodes`` under ``node.output[0]``.
        """
        all_nodes = kwargs['all_nodes']
        opts = kwargs['opts']
        graph = kwargs['G']
        source = [all_nodes[t] for t in node.input][0]

        if cls.is_constant(source):
            LOG.info("reducing %s to a constant", node.name)
            value = (source[0].value_as(out_qtype)
                     if out_qtype else cls.get_constant(source))
            params = ConstantInputParameters(node.name,
                                             value=value,
                                             dims=Dim.unnamed(value.shape),
                                             qtype=out_qtype,
                                             constant_store=graph.constant_store)
            # A folded constant is recorded with the output type on both sides.
            record_in = out_qtype
        else:
            params = QuantizeParameters(node.name, from_qtype=in_qtype, to_qtype=out_qtype)
            graph.add_edge(NNEdge(from_node=source[0], to_node=params,
                                  from_idx=source[1], to_idx=0))
            record_in = in_qtype

        if opts.get('load_quantization'):
            graph.quantization[NodeId(params)] = MultQuantizationRecord(
                in_qs=[record_in], out_qs=[out_qtype])
        all_nodes[node.output[0]] = (params, 0, deepcopy(source[2]))
        return params
Пример #6
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a concatenation node, folding it into a constant when possible.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``; the
                whole mapping is also forwarded to ``cls.fuse_activation``.

        Returns:
            The created ``ConstantInputParameters`` or ``ConcatParameters``;
            it is also recorded in ``all_nodes`` under ``node.output[0]``.

        Raises:
            NotImplementedError: if two inputs share the same non-zero
                buffer index.
            ValueError: if the concatenation axis is unspecified (``None``)
                in any input shape.
        """
        node_opts = node.get_options(ConcatenationOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        inp_shapes = [inp[2].shape for inp in inputs]

        # Buffer index 0 is excluded from the duplicate check
        # (presumably the shared empty buffer -- confirm).
        buffer_idxes = [tensor.buffer_idx for tensor in node.input]
        non_zero_idxes = [idx for idx in buffer_idxes if idx != 0]
        if len(set(non_zero_idxes)) != len(non_zero_idxes):
            raise NotImplementedError(
                "concats with multiple versions of the same input are not supported. "
                "This is normally a graph design problem.")

        axis = node_opts.Axis()
        # Normalise a negative axis.  Without this the ``idx != axis`` test
        # below never matches, so the concat dimension was never summed, and
        # the ``pout_shape[:axis]`` slice counted the wrong axes.
        if axis < 0 and inp_shapes:
            axis += len(inp_shapes[0])
        if any(inp_shape[axis] is None for inp_shape in inp_shapes):
            raise ValueError("concat on undefined axis in node %s" % node.name)

        def red_func(x, y):
            # Merge two shapes: sum along the concat axis; any other axis
            # stays specified only if it is specified in both shapes.
            return y.copy() if x is None else [
                (elem if y[idx] is not None and elem is not None else None)
                if idx != axis else elem + y[axis]
                for idx, elem in enumerate(x)
            ]

        pout_shape = reduce(red_func, inp_shapes)

        if all(cls.is_constant(inp) for inp in inputs):
            # cls.remove_none_from_constants(inputs, pout_shape)
            LOG.info("reducing %s to a constant", node.name)
            value = np.concatenate([cls.get_constant(inp) for inp in inputs],
                                   axis=axis)
            params = ConstantInputParameters(node.name,
                                             value=value,
                                             constant_store=G.constant_store)
        else:
            # Re-index the axis over the shape with unspecified axes removed.
            axis -= sum(1 if dim is None else 0 for dim in pout_shape[:axis])
            params = ConcatParameters(node.name, axis=axis, axis_hint=None)

            for idx, inp in enumerate(inputs):
                inp_node, inp_idx = cls._maybe_insert_reshape(
                    G, inp, inp_shapes[idx], pout_shape)
                G.add_edge(
                    NNEdge(from_node=inp_node,
                           to_node=params,
                           from_idx=inp_idx,
                           to_idx=idx))
        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                node.input, node.output)
        # NOTE(review): the return value of fuse_activation is ignored here;
        # confirm whether the fused node should be what is recorded below.
        cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (params, 0, ProvisionalDim(pout_shape))
        return params
Пример #7
0
 def _common(cls, node, **kwargs):
     """Import a two-input node, folding it into a constant when possible.

     Required ``kwargs``: ``all_nodes``, ``G``, ``opts`` and, on the
     non-constant path, ``params_class``.  Optional: ``node_opts``,
     ``params_args`` and ``constant_operation``.

     For ``Broadcastable`` parameters, any input shape longer than the output
     shape has its extra leading axes stripped; those axes must each be
     ``None`` or 1, otherwise ``ValueError`` is raised.

     Returns the node recorded in ``all_nodes`` under ``node.output[0]``:
     the result of ``cls.fuse_activation`` when ``node_opts`` is given,
     otherwise the parameters node that was created.
     """
     all_nodes = kwargs['all_nodes']
     graph = kwargs['G']
     opts = kwargs['opts']
     node_opts = kwargs.get("node_opts", None)
     params_args = kwargs.get('params_args', {})
     fold_op = kwargs.get('constant_operation')

     operands = [all_nodes[tensor] for tensor in node.input]
     assert len(operands) == 2

     can_fold = all(cls.is_constant(operand) for operand in operands) and fold_op
     if can_fold:
         LOG.info("reducing %s to a constant", node.name)
         values = [cls.get_constant(operand) for operand in operands]
         output_shapes = cls.implied_broadcast(operands)
         folded = fold_op(*values)
         dims = Dim.unnamed(output_shapes[0].known_shape)
         params = ConstantInputParameters(node.name,
                                          value=folded,
                                          dims=dims,
                                          constant_store=graph.constant_store)
     else:
         params = kwargs['params_class'](node.name, **params_args)
         output_shapes = cls.implied_broadcast(operands)
         shapes = []
         for slot, operand in enumerate(operands):
             edge = NNEdge(from_node=operand[0],
                           to_node=params,
                           from_idx=operand[1],
                           to_idx=slot)
             graph.add_edge(edge)
             shapes.append(operand[2].known_shape)
         if isinstance(params, Broadcastable):
             out_rank = len(output_shapes[0].known_shape)
             # Iterate over a snapshot because entries are replaced in place.
             for slot, shape in enumerate(list(shapes)):
                 extra = len(shape) - out_rank
                 if extra <= 0:
                     continue
                 if any(dim is not None and dim != 1 for dim in shape[:extra]):
                     in_shapes = ",".join(str(item) for item in shapes)
                     raise ValueError(
                         f'strange broadcast {in_shapes} -> {output_shapes[0].shape}'
                     )
                 shapes[slot] = shape[extra:]
             params.set_broadcast(shapes)

     if opts.get('load_quantization'):
         graph.quantization[NodeId(params)] = cls.load_tf_quantization(
             node.input, node.output)
     if node_opts is not None:
         params = cls.fuse_activation(node_opts, node.name, params, **kwargs)
     all_nodes[node.output[0]] = (params, 0, output_shapes[0])
     return params
Пример #8
0
    def _common(cls, node: TFLiteNode, **kwargs):
        node_opts = node.get_options(GatherOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        x = inputs[0]
        x_shape = x[2].shape
        gather_tensor = list(cls._verify_constant(inputs[1]))
        axis = node_opts.Axis()

        if cls._is_constant(x):
            inp = cls._get_constant(x)
            inp = np.take(inp, gather_tensor, axis=axis)
            pout_shape = inp.shape
            LOG.info("reducing %s to a constant", node.name)
            params = ConstantInputParameters(node.name,
                                             value=inp,
                                             dims=pout_shape)
        else:
            if x_shape[axis] is None:
                raise ValueError(
                    f'GATHER {node.name} on batch axis not supported')
            slices = [sequence_to_slice(elem) for elem in gather_tensor]
            strides = set([
                slices[idx + 1][0] - slice[0]
                for idx, slice in enumerate(slices[:-1:])
            ])
            lengths = set([abs(slice[1] - slice[0]) for slice in slices])
            if len(strides) != 1 or len(lengths) != 1:
                raise ValueError(f'Irregular GATHER {node.name} not supported')
            out_len = sum(len(slice) for slice in slices)
            pout_shape = x_shape.copy()
            pout_shape[axis] = out_len
            axis -= sum(1 if dim is None else 0 for dim in x_shape[:axis:])
            LOG.info("reducing %s to an overlapped copy", node.name)
            params = GatherParameters(node.name,
                                      indices=gather_tensor,
                                      axis=axis)

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                node.input, node.output)
        all_nodes[node.output[0]] = (params, 0, ProvisionalDim(pout_shape))
        return params
Пример #9
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a squeeze node as a reshape, or fold it into a constant.

        With no squeeze dims in the options every axis of size 1 is removed;
        otherwise only the listed axes are removed, and each of them must be
        ``None`` or 1 in the input shape.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``.

        Returns:
            The created ``ConstantInputParameters`` or ``ReshapeParameters``;
            it is also recorded in ``all_nodes`` under ``node.output[0]``.

        Raises:
            ValueError: if a listed squeeze axis is neither ``None`` nor 1.
        """
        node_opts = node.get_options(SqueezeOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        x = inputs[0]
        x_shape = x[2].shape

        if node_opts.SqueezeDimsIsNone():
            new_shape = [dim for dim in x_shape if dim != 1]
        else:
            axes = node_opts.SqueezeDimsAsNumpy()
            axes = np.where(axes < 0, axes + len(x_shape), axes)
            # Check each axis on its own: the previous ``array not in list``
            # test raised an ambiguous-truth error for more than one axis.
            if any(x_shape[int(axis)] not in (None, 1) for axis in axes):
                raise ValueError(f'invalid expand dims > 1 {node.name}')
            new_shape = [
                dim for idx, dim in enumerate(x_shape) if idx not in axes
            ]

        # Needed on both paths below; it used to be assigned only on the
        # reshape path, so folding a constant ended in a NameError.
        pnew_shape = ProvisionalDim(new_shape)

        if cls.is_constant(x):
            LOG.info("reducing %s to a constant", node.name)
            val = np.reshape(cls.get_constant(x), new_shape)
            params = ConstantInputParameters(node.name,
                                             value=val,
                                             dims=Dim.unnamed(val.shape))
        else:
            old_shape = Dim.unnamed(cls.remove_unspecified_dim(x_shape),
                                    is_ordered=True)
            new_shape = Dim.unnamed(cls.remove_unspecified_dim(new_shape),
                                    is_ordered=True)
            params = ReshapeParameters(node.name,
                                       old_shape=old_shape,
                                       shape=new_shape)
            # Only the reshape consumes its input; the folded constant gets
            # no input edge, matching the other constant-folding handlers.
            G.add_edge(
                NNEdge(from_node=x[0], to_node=params, from_idx=x[1],
                       to_idx=0))

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                [node.input[0]], node.output)
        all_nodes[node.output[0]] = (params, 0, pnew_shape)
        return params
Пример #10
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import an expand-dims node as a reshape, or fold it into a constant.

        The axis at which the new size-1 dimension is inserted is read from
        the constant second input.  If that position holds an unspecified
        (``None``) axis, the new dimension is inserted after it.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``.

        Returns:
            The created ``ConstantInputParameters`` or ``ReshapeParameters``;
            it is also recorded in ``all_nodes`` under ``node.output[0]``.
        """
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        x = inputs[0]
        x_shape = x[2].shape

        exp_dim = int(cls._verify_constant(inputs[1]))
        if exp_dim < 0:
            # A negative axis counts from the end of the *output* shape,
            # which has one more axis than the input (so -1 appends).
            exp_dim += len(x_shape) + 1
        # Guard the lookup: exp_dim == len(x_shape) is valid (append).
        if exp_dim < len(x_shape) and x_shape[exp_dim] is None:
            exp_dim += 1
        new_shape = x_shape[:exp_dim] + [1] + x_shape[exp_dim:]

        # Needed on both paths below; it used to be assigned only on the
        # reshape path, so folding a constant ended in a NameError.
        pnew_shape = ProvisionalDim(new_shape)

        if cls.is_constant(x):
            LOG.info("reducing %s to a constant", node.name)
            val = np.reshape(cls.get_constant(x), new_shape)
            params = ConstantInputParameters(node.name,
                                             value=val,
                                             dims=Dim.unnamed(val.shape))
        else:
            old_shape = Dim.unnamed(cls.remove_unspecified_dim(x_shape),
                                    is_ordered=True)
            new_shape = Dim.unnamed(cls.remove_unspecified_dim(new_shape),
                                    is_ordered=True)
            params = ReshapeParameters(node.name,
                                       old_shape=old_shape,
                                       shape=new_shape)
            # Only the reshape consumes its input; the folded constant gets
            # no input edge, matching the other constant-folding handlers.
            G.add_edge(
                NNEdge(from_node=x[0], to_node=params, from_idx=x[1],
                       to_idx=0))

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                [node.input[0]], node.output)
        all_nodes[node.output[0]] = (params, 0, pnew_shape)
        return params
Пример #11
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a strided-slice node.

        Begin, end and stride vectors are read from the constant inputs 1-3
        and combined with the mask options by
        ``StridedSliceParameters.get_slice``.  Depending on the result the
        node becomes a folded constant, a no-op, a reshape or a strided slice.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``.

        Returns:
            The created parameters node; it is also recorded in
            ``all_nodes`` under ``node.output[0]``.

        Raises:
            NotImplementedError: if begin, end or stride is not available as
                a constant.
        """
        node_opts = node.get_options(StridedSliceOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        x = inputs[0]
        x_shape = x[2].shape

        # begin end stride
        raw_vectors = [cls._verify_constant(inputs[i]) for i in range(1, 4)]
        # Test for None before converting: ``list(None)`` raises TypeError,
        # so the check was unreachable when it came after the conversion.
        # (Assumes _verify_constant can return None -- confirm.)
        if any(vec is None for vec in raw_vectors):
            raise NotImplementedError(
                "strided slice with variable begin end or stride is not supported")
        vec_begin, vec_end, vec_stride = (list(vec) for vec in raw_vectors)
        for i in range(1, 4):
            node.input[i].used = True
        spec = zip(vec_begin, vec_end, vec_stride)
        begin_mask = node_opts.BeginMask()
        ellipsis_mask = node_opts.EllipsisMask()
        end_mask = node_opts.EndMask()
        new_axis_mask = node_opts.NewAxisMask()
        shrink_axis_mask = node_opts.ShrinkAxisMask()

        act_slice, out_shape, can_reshape = StridedSliceParameters.get_slice(
            x_shape, spec,
            begin_mask,
            end_mask, ellipsis_mask,
            new_axis_mask, shrink_axis_mask)

        if cls.is_constant(x):
            LOG.info("reducing %s to a constant", node.name)
            x_val = cls.get_constant(x)
            # A temporary slice node is built only to apply the slice to the value.
            params = StridedSliceParameters(node.name, act_slice=act_slice, out_shape=out_shape)
            x_val = params.numpy_slice(x_val)
            params = ConstantInputParameters(node.name, value=x_val)
        else:
            if can_reshape:
                if list(x_shape) == list(out_shape):
                    LOG.info("converting strided slice %s to a noop", node.name)
                    params = NoOPParameters(node.name)
                else:
                    LOG.info("converting strided slice %s to a reshape", node.name)
                    in_shape = Dim.unnamed(x[2].known_shape, is_ordered=True)
                    # NOTE(review): out_shape is rebound to a Dim here, so
                    # infer_mapping below receives a Dim on this path only.
                    out_shape = Dim.unnamed(out_shape, is_ordered=True)
                    params = ReshapeParameters(node.name, old_shape=in_shape, shape=out_shape)
            else:
                params = StridedSliceParameters(node.name, act_slice=act_slice, out_shape=out_shape)
            G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization([node.input[0]], node.output)
        all_nodes[node.output[0]] = (params, 0, x[2].infer_mapping(out_shape, allow_bad_length=True))
        return params
Пример #12
0
    def common_quantize(cls, in_qtype, out_qtype, node, **kwargs):
        """Import a (de)quantize node, simplifying it where it has no effect.

        Both types are first converted with ``make_symmetric_signed()``.
        A constant input is folded into a ``ConstantInputParameters``.
        Otherwise the node becomes a ``NoOPParameters`` when the two types
        are equal or share a dtype, and a ``QuantizeParameters`` in every
        other case.  ``kwargs`` must contain ``all_nodes``, ``opts`` and ``G``.

        Returns the created node, which is also recorded in ``all_nodes``
        under ``node.output[0]``.
        """
        all_nodes = kwargs['all_nodes']
        opts = kwargs['opts']
        graph = kwargs['G']
        source = [all_nodes[t] for t in node.input][0]
        in_qtype = in_qtype.make_symmetric_signed()
        out_qtype = out_qtype.make_symmetric_signed()

        if cls.is_constant(source):
            LOG.info("reducing %s to a constant", node.name)
            value = (source[0].value_as(out_qtype)
                     if out_qtype else cls.get_constant(source))
            params = ConstantInputParameters(node.name,
                                             value=value,
                                             dims=Dim.unnamed(value.shape),
                                             qtype=out_qtype,
                                             constant_store=graph.constant_store)
            # A folded constant is recorded with the output type on both sides.
            record_in = out_qtype
        else:
            if in_qtype == out_qtype:
                LOG.info('removing (de)quantize node %s with no effect',
                         node.name)
                params = NoOPParameters(node.name,
                                        desc="quantize with no effect")
            elif in_qtype.dtype == out_qtype.dtype:
                LOG.info('removing (de)quantize node %s with scale change',
                         node.name)
                params = NoOPParameters(node.name,
                                        desc="quantize with scale change")
                # The no-op passes data through, so the recorded output
                # type becomes the input type.
                out_qtype = in_qtype
            else:
                params = QuantizeParameters(node.name,
                                            from_qtype=in_qtype,
                                            to_qtype=out_qtype)
            edge = NNEdge(from_node=source[0],
                          to_node=params,
                          from_idx=source[1],
                          to_idx=0)
            graph.add_edge(edge)
            record_in = in_qtype

        if opts.get('load_quantization'):
            graph.quantization[NodeId(params)] = QRec.scaled(
                in_qs=[record_in], out_qs=[out_qtype])
        all_nodes[node.output[0]] = (params, 0, deepcopy(source[2]))
        return params
Пример #13
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a pack node as reshapes feeding a concat, or fold it.

        Every input gets a new axis of size 1 at ``axis`` via a reshape node
        and the reshaped inputs are concatenated along that axis.  When all
        inputs are constant the result is computed with ``np.stack`` instead.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``.

        Returns:
            The created ``ConstantInputParameters`` or ``ConcatParameters``;
            it is also recorded in ``all_nodes`` under ``node.output[0]``.
        """
        node_opts = node.get_options(PackOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        inp_shapes = [inp[2].shape for inp in inputs]

        values_count = node_opts.ValuesCount()
        check(
            len(inputs) == values_count,
            "invalid tflite file - values count not equal to inputs")

        buffer_idxes = [tensor.buffer_idx for tensor in node.input]
        check(
            len(set(buffer_idxes)) == len(buffer_idxes),
            "packs with multiple versions of the same input are not supported. This is normally a graph design problem."
        )

        axis = node_opts.Axis()
        if axis < 0:
            # A negative axis counts from the end of the packed output, whose
            # rank is one more than the rank of each input.  The number of
            # inputs was used here before, which is unrelated to the rank.
            axis += len(inp_shapes[0]) + 1

        check(all(shape == inp_shapes[0] for shape in inp_shapes[1:]),
              "invalid tflite file - pack inputs not the same")

        # prepare shapes of all tensors
        pconcat_out_shape = inp_shapes[0].copy()
        pconcat_out_shape.insert(axis, values_count)

        pconcat_in_shape = inp_shapes[0].copy()
        pconcat_in_shape.insert(axis, 1)

        preshape_in_shape = inp_shapes[0].copy()

        # remove nones from constants
        cls.remove_none_from_constants(inputs, preshape_in_shape)

        # remove nones from reshape shapes
        reshape_in_shape = cls.remove_unspecified_dim(preshape_in_shape)
        concat_in_shape = cls.remove_unspecified_dim(pconcat_in_shape)

        if all(cls.is_constant(inp) for inp in inputs):
            LOG.info("reducing %s to a constant", node.name)
            value = np.stack([cls.get_constant(inp) for inp in inputs],
                             axis=axis)
            params = ConstantInputParameters(node.name,
                                             value=value,
                                             constant_store=G.constant_store)
        else:
            # Re-index the axis over the shape with unspecified axes removed.
            axis -= sum(1 if dim is None else 0
                        for dim in pconcat_out_shape[:axis])
            params = ConcatParameters(node.name, axis=axis, axis_hint=None)

            # insert reshapes on each input to add concat axis
            for idx, inp in enumerate(inputs):
                rparams = ReshapeParameters(node.name + "_%s" % idx,
                                            old_shape=reshape_in_shape,
                                            shape=concat_in_shape)
                G.add_edge(
                    NNEdge(from_node=inp[0], to_node=rparams, from_idx=inp[1]))
                # Give each reshape its own concat input slot; without an
                # explicit to_idx every edge targeted the same slot.
                G.add_edge(
                    NNEdge(from_node=rparams, to_node=params, to_idx=idx))
                if opts.get('load_quantization'):
                    G.quantization[NodeId(rparams)] = cls.load_tf_quantization(
                        [node.input[idx]], [node.input[idx]])

        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                node.input, node.output)

        all_nodes[node.output[0]] = (params, 0,
                                     ProvisionalDim(pconcat_out_shape))
        return params
Пример #14
0
    def _common(cls, node: TFLiteNode, **kwargs):
        """Import a concatenation node, folding it into a constant when possible.

        Inputs that share a non-zero buffer index are routed through
        ``CopyParameters`` nodes so that each concat input is distinct.

        Args:
            node: node being imported.
            **kwargs: must contain ``G``, ``opts`` and ``all_nodes``; the
                whole mapping is also forwarded to ``cls.fuse_activation``.

        Returns:
            The created ``ConstantInputParameters`` or ``ConcatParameters``;
            it is also recorded in ``all_nodes`` under ``node.output[0]``.

        Raises:
            ValueError: if the concatenation axis is unspecified (``None``)
                in any input shape.
        """
        node_opts = node.get_options(ConcatenationOptions)
        G = kwargs['G']
        opts = kwargs['opts']
        all_nodes = kwargs['all_nodes']

        inputs = [all_nodes[t] for t in node.input]
        inp_shapes = [inp[2].shape for inp in inputs]

        # Buffer index 0 is excluded from the duplicate check
        # (presumably the shared empty buffer -- confirm).
        buffer_idxes = [tensor.buffer_idx for tensor in node.input]
        non_zero_idxes = [idx for idx in buffer_idxes if idx != 0]
        duplicates = [
            idx for idx, count in Counter(non_zero_idxes).items() if count > 1
        ]
        if duplicates:
            LOG.warning(
                f'concat {node.name} has duplicate inputs. Inserting copies but this is not very efficient.'
            )
            for idx in duplicates:
                dup_idxes = [
                    pos for pos, buf_idx in enumerate(buffer_idxes)
                    if buf_idx == idx
                ]
                # Keep the first occurrence; feed every later one via a copy.
                for dup_idx in dup_idxes[1:]:
                    cparams = CopyParameters(
                        G.unique_name(
                            f'{node.name}_dup_{dup_idxes[0]}_{dup_idx}'))
                    dup_inp = inputs[dup_idx]
                    G.add_edge(
                        NNEdge(from_node=dup_inp[0],
                               from_idx=dup_inp[1],
                               to_node=cparams))
                    inputs[dup_idx] = tuple([cparams, 0] + list(dup_inp[2:]))

        axis = node_opts.Axis()
        # Normalise a negative axis.  Without this the ``idx != axis`` test
        # below never matches, so the concat dimension was never summed, and
        # the ``pout_shape[:axis]`` slice counted the wrong axes.
        if axis < 0 and inp_shapes:
            axis += len(inp_shapes[0])
        if any(inp_shape[axis] is None for inp_shape in inp_shapes):
            raise ValueError("concat on undefined axis in node %s" % node.name)

        def red_func(x, y):
            # Merge two shapes: sum along the concat axis; any other axis
            # stays specified only if it is specified in both shapes.
            return y.copy() if x is None else [
                (elem if y[idx] is not None and elem is not None else None)
                if idx != axis else elem + y[axis]
                for idx, elem in enumerate(x)
            ]

        pout_shape = reduce(red_func, inp_shapes)

        if all(cls.is_constant(inp) for inp in inputs):
            # cls.remove_none_from_constants(inputs, pout_shape)
            LOG.info("reducing %s to a constant", node.name)
            value = np.concatenate([cls.get_constant(inp) for inp in inputs],
                                   axis=axis)
            params = ConstantInputParameters(node.name, value=value)
        else:
            # Re-index the axis over the shape with unspecified axes removed.
            axis -= sum(1 if dim is None else 0 for dim in pout_shape[:axis])
            params = ConcatParameters(node.name, axis=axis, axis_hint=None)

            for idx, inp in enumerate(inputs):
                inp_node, inp_idx = cls._maybe_insert_reshape(
                    G, inp, inp_shapes[idx], pout_shape)
                G.add_edge(
                    NNEdge(from_node=inp_node,
                           to_node=params,
                           from_idx=inp_idx,
                           to_idx=idx))
        if opts.get('load_quantization'):
            G.quantization[NodeId(params)] = cls.load_tf_quantization(
                node.input, node.output)
        # NOTE(review): the return value of fuse_activation is ignored here;
        # confirm whether the fused node should be what is recorded below.
        cls.fuse_activation(node_opts, node.name, params, **kwargs)
        all_nodes[node.output[0]] = (params, 0, ProvisionalDim(pout_shape))
        return params