def test_split(self):
    """Exercise infer.split on shape [1, 2, 6] along the last axis.

    Each case describes the same kind of split through a different keyword
    (num, ratios, sizes, split_points) and lists the output shapes expected.
    """
    cases = (
        ([[1, 2, 2], [1, 2, 2], [1, 2, 2]], {'num': 3}),
        ([[1, 2, 2], [1, 2, 4]], {'ratios': [1, 2]}),
        ([[1, 2, 2], [1, 2, 4]], {'sizes': [2, 4]}),
        ([[1, 2, 2], [1, 2, 4]], {'split_points': [2]}),
        ([[1, 2, 2], [1, 2, 1], [1, 2, 2], [1, 2, 1]], {'split_points': [2, 3, 5]}),
    )
    for expected_shapes, how in cases:
        # The input shape is a fresh literal on every call so no case can be
        # affected by a previous one.
        self.assertEqual(expected_shapes, infer.split([1, 2, 6], -1, **how))
Example #2
0
def propagate_split(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shapes and dtypes for a split operation.

    When the operation carries a 'split' attribute, its value is passed to
    infer.split as the explicit sizes; otherwise the input is split into as
    many parts as the operation has outputs.  The axis comes from the 'axis'
    attribute and defaults to 0.

    Returns a pair (output_shapes, dtypes) where dtypes repeats the input's
    dtype once per output.
    """
    attribs = op.attribs

    # Only the way the partition is described differs between the two cases.
    if 'split' in attribs:
        partition = {'sizes': attribs['split']}
    else:
        partition = {'num': len(op.outputs)}

    output_shapes = infer.split(input=op.input.shape,
                                axis=attribs.get('axis', 0),
                                **partition)
    output_dtypes = [op.input.dtype] * len(op.outputs)
    return output_shapes, output_dtypes
def propagate_splitv(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shapes and dtypes for a split-by-sizes operation.

    The operation's inputs are (data, size_splits, axis).  The axis must be
    present in const_value_by_tensor.  When size_splits is also a known
    constant, its values determine the size of each output along the axis;
    otherwise the input is divided evenly into op.attribs['num_split'] parts,
    which is only an estimate.

    Returns a pair (output_shapes, dtypes) where dtypes repeats
    op.attribs['T'] num_split times.

    Raises KeyError if the axis tensor has no constant value.
    """
    # `data` instead of `input` so the builtin is not shadowed.
    data, size_splits, axis = op.inputs
    axis = const_value_by_tensor[axis].item()
    num_split = op.attribs['num_split']

    if size_splits in const_value_by_tensor:
        # NOTE(review): assumes the stored constant is a 1-D numpy-like array,
        # so tolist() yields a flat list of ints - confirm against the reader.
        sizes = const_value_by_tensor[size_splits].tolist()
        if sizes.count(-1) == 1:
            # TensorFlow allows a single -1 entry meaning "the rest of the
            # axis"; replace it with the concrete remainder.
            known_total = sum(size for size in sizes if size != -1)
            sizes[sizes.index(-1)] = data.shape[axis] - known_total
        output_shapes = infer.split(input=data.shape, axis=axis, sizes=sizes)
    else:
        # Sizes are not statically known; keep the even-split behaviour.
        output_shapes = infer.split(input=data.shape, axis=axis, num=num_split)

    return output_shapes, [op.attribs['T']] * num_split
Example #4
0
def split_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shapes and dtypes of a split operation.

    The split sizes are read from op.attribs['split'] when the operation has
    one input, or from the constant data of the second input when it has two.

    Side effects on `op`:
      * with two inputs, the sizes are stored into op.attribs['split'];
      * op.inputs is always reduced to a 1-tuple holding the first input.

    Returns a pair (output_shapes, dtypes): a tuple of shapes from
    infer.split, and a tuple repeating the first input's dtype once per shape.

    Raises:
        utils.NNEFToolsException: the second input carries no constant data.
        AssertionError: the operation has neither one nor two inputs.
    """
    input_count = len(op.inputs)
    if input_count == 1:
        sizes = op.attribs['split']
    elif input_count == 2:
        if op.inputs[1].data is None:
            raise utils.NNEFToolsException(
                'Split is not supported with calculated sizes.')
        sizes = op.inputs[1].data.tolist()
        op.attribs['split'] = sizes
    else:
        # Raised explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let execution continue,
        # mutate op.inputs and then fail on the unbound `sizes`.
        raise AssertionError(
            'Split expects 1 or 2 inputs, got {}'.format(input_count))

    # From here on the operation only keeps its data input.
    op.inputs = (op.inputs[0], )

    output_shapes = tuple(
        infer.split(input=op.inputs[0].shape,
                    axis=op.attribs['axis'],
                    sizes=sizes))
    return output_shapes, (op.inputs[0].dtype, ) * len(output_shapes)
Example #5
0
def slice_shape(op):
    """Compute output shapes for a slice operation via shapes.split.

    The first input's shape, op.attribs['axis'] and op.attribs['slice_point']
    (passed as split_points) are forwarded to shapes.split, and its result is
    returned unchanged.
    """
    source_shape = op.inputs[0].shape
    axis = op.attribs['axis']
    slice_points = op.attribs['slice_point']
    return shapes.split(source_shape, axis=axis, split_points=slice_points)