def test_split(self):
    """Check infer.split on shape [1, 2, 6] along the last axis for every
    supported way of describing the split (num, ratios, sizes, split_points)."""
    # Each entry: (keyword arguments for infer.split, expected output shapes).
    cases = [
        (dict(num=3), [[1, 2, 2], [1, 2, 2], [1, 2, 2]]),
        (dict(ratios=[1, 2]), [[1, 2, 2], [1, 2, 4]]),
        (dict(sizes=[2, 4]), [[1, 2, 2], [1, 2, 4]]),
        (dict(split_points=[2]), [[1, 2, 2], [1, 2, 4]]),
        (dict(split_points=[2, 3, 5]), [[1, 2, 2], [1, 2, 1], [1, 2, 2], [1, 2, 1]]),
    ]
    for kwargs, expected in cases:
        self.assertEqual(expected, infer.split([1, 2, 6], -1, **kwargs))
def propagate_split(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shapes and dtypes of an ONNX Split operation.

    If the 'split' attribute is present it lists the extent of each output
    along the split axis; otherwise the input is divided into as many parts
    as the operation has outputs. The axis defaults to 0 when absent.

    Returns a (shapes, dtypes) pair; every output inherits the input dtype.
    """
    axis = op.attribs.get('axis', 0)
    output_count = len(op.outputs)
    if 'split' in op.attribs:
        output_shapes = infer.split(input=op.input.shape, axis=axis, sizes=op.attribs['split'])
    else:
        output_shapes = infer.split(input=op.input.shape, axis=axis, num=output_count)
    return output_shapes, [op.input.dtype] * output_count
def propagate_splitv(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shapes and dtypes of a TensorFlow SplitV operation.

    SplitV has three inputs: the tensor to split, the per-output sizes
    (size_splits) and the split axis. The axis is read from
    const_value_by_tensor. When size_splits is also available there, its
    values are used as the output sizes; a single -1 entry is resolved to
    whatever extent remains along the axis. If size_splits is not a known
    constant, the input is divided into attribs['num_split'] equal parts.

    Returns a (shapes, dtypes) pair; every output gets dtype attribs['T'].
    """
    input, size_splits, axis = op.inputs
    axis = const_value_by_tensor[axis].item()
    num_split = op.attribs['num_split']
    dtypes = [op.attribs['T']] * num_split

    if size_splits not in const_value_by_tensor:
        # Sizes are not statically known: fall back to an even split.
        return infer.split(input=input.shape, axis=axis, num=num_split), dtypes

    # Previously size_splits was ignored, forcing an even split even though
    # SplitV exists precisely to produce outputs of unequal size.
    sizes = [int(s) for s in const_value_by_tensor[size_splits].tolist()]
    if -1 in sizes:
        # One entry may be -1, meaning "the rest of the axis".
        known = sum(s for s in sizes if s != -1)
        sizes[sizes.index(-1)] = input.shape[axis] - known
    return infer.split(input=input.shape, axis=axis, sizes=sizes), dtypes
def split_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer output shapes and dtypes of a Caffe2 Split operation.

    The split sizes come either from the 'split' attribute (one input) or
    from the constant data of a second input. In the two-input case the
    sizes are stored into op.attribs['split']. In both cases op.inputs is
    then reduced to just the first input.

    Returns a (shapes, dtypes) pair of tuples; every output inherits the
    first input's dtype.

    Raises utils.NNEFToolsException if the sizes input has no constant data,
    or if the operation has neither 1 nor 2 inputs.
    """
    if len(op.inputs) == 1:
        sizes = op.attribs['split']
    elif len(op.inputs) == 2:
        if op.inputs[1].data is None:
            raise utils.NNEFToolsException(
                'Split is not supported with calculated sizes.')
        sizes = op.inputs[1].data.tolist()
        op.attribs['split'] = sizes
    else:
        # Explicit exception rather than `assert False`: asserts are removed
        # under `python -O`, which would leave `sizes` unbound below.
        raise utils.NNEFToolsException(
            'Split expects 1 or 2 inputs, got {}.'.format(len(op.inputs)))
    op.inputs = (op.inputs[0], )
    output_shapes = tuple(
        infer.split(input=op.inputs[0].shape, axis=op.attribs['axis'], sizes=sizes))
    return output_shapes, (op.inputs[0].dtype, ) * len(output_shapes)
def slice_shape(op):
    """Infer output shapes of a Slice operation (presumably Caffe's Slice layer).

    The first input's shape is cut along attribs['axis'] at the indices given
    in attribs['slice_point'], using shapes.split.
    """
    attribs = op.attribs
    input_shape = op.inputs[0].shape
    return shapes.split(input_shape,
                        axis=attribs['axis'],
                        split_points=attribs['slice_point'])