def slice_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shape and dtype of a Caffe2 ``Slice`` operation.

    The slice bounds come from the ``starts``/``ends`` attributes when the
    operation has one input, or from the constant ``data`` of the second and
    third inputs when it has three.  The operation is rewritten in place:
    ``op.attribs`` is replaced by a dict holding only ``starts`` and ``ends``,
    and ``op.inputs`` is reduced to the sliced tensor.

    Returns:
        A ``(shape, dtype)`` pair: the result of ``infer.slice`` and the dtype
        of the sliced tensor.

    Raises:
        utils.NNEFToolsException: if the ``starts`` or ``ends`` input carries
            no constant data.
        AssertionError: if the operation has neither one nor three inputs.
    """
    # Currently, only slicing in a single dimension is supported in Caffe2
    input_count = len(op.inputs)
    if input_count == 1:
        starts = op.attribs['starts']
        ends = op.attribs['ends']
    elif input_count == 3:
        for bound in (op.inputs[1], op.inputs[2]):
            if bound.data is None:
                raise utils.NNEFToolsException(
                    'Slice is not supported with calculated sizes.')
        starts = op.inputs[1].data.tolist()
        ends = op.inputs[2].data.tolist()
    else:
        # Raised explicitly (rather than ``assert False``) so the guard is
        # still active when Python runs with assertions disabled (-O).
        raise AssertionError(
            'Slice expects 1 or 3 inputs, got {}'.format(input_count))

    op.attribs = {
        'starts': starts,
        'ends': ends,
    }
    op.inputs = (op.inputs[0], )

    # Negative ends are shifted by one, so -1 becomes 0; zero_means_all=True
    # is passed so that infer.slice can give that zero its special meaning.
    shifted_ends = [e + 1 if e < 0 else e for e in ends]
    return infer.slice(op.inputs[0].shape,
                       begin=starts,
                       end=shifted_ends,
                       zero_means_all=True), op.input.dtype
def propagate_slice(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate shape and dtype through an ONNX slice whose bounds may be inputs.

    Inputs 1 to 4, when present, are evaluated with
    ``evaluate_shape_tensor_simple`` and stored as the ``starts``, ``ends``,
    ``axes`` and ``steps`` attributes (in that order).  Afterwards the
    operation keeps only its first input.

    Returns:
        ``([output_shape], [dtype])`` where the dtype is that of the first input.
    """
    tensor = op.inputs[0]

    # Position of each optional input -> attribute it overrides.
    attrib_by_position = ((1, 'starts'), (2, 'ends'), (3, 'axes'), (4, 'steps'))
    for position, key in attrib_by_position:
        if len(op.inputs) > position:
            op.attribs[key] = evaluate_shape_tensor_simple(op.inputs[position])

    begin = op.attribs['starts']
    end = op.attribs['ends']
    # Without explicit axes, the bounds apply to the leading dimensions.
    leading_axes = list(range(len(begin)))
    axes = op.attribs.get('axes', leading_axes)
    stride = op.attribs.get('steps')

    op.inputs = (tensor, )

    out_shape = infer.slice(input=tensor.shape,
                            axes=axes,
                            begin=begin,
                            end=end,
                            stride=stride)
    return [out_shape], [tensor.dtype]
def propagate_dynamic_slice(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate shape and dtype through a slice whose bounds are all inputs.

    The operation must have exactly four inputs: data, starts, ends and axes.
    The last three are evaluated with ``evaluate_shape_tensor_simple``; the
    operation itself is not modified.

    Returns:
        ``([output_shape], [dtype])`` where the dtype is that of the data input.
    """
    tensor, starts_tensor, ends_tensor, axes_tensor = op.inputs

    in_shape = tensor.shape
    axes = evaluate_shape_tensor_simple(axes_tensor)
    begin = evaluate_shape_tensor_simple(starts_tensor)
    end = evaluate_shape_tensor_simple(ends_tensor)

    out_shape = infer.slice(input=in_shape, axes=axes, begin=begin, end=end)
    return [out_shape], [tensor.dtype]
def propagate_slice(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate shape and dtype through a slice described purely by attributes.

    Reads ``starts`` and ``ends`` (required) and ``axes`` (optional) from
    ``op.attribs`` and applies them to the shape of ``op.input``.

    Returns:
        ``([output_shape], [dtype])`` where the dtype is that of ``op.input``.
    """
    begin = op.attribs['starts']
    end = op.attribs['ends']
    # Without explicit axes, the bounds apply to the leading dimensions.
    leading_axes = list(range(len(begin)))
    axes = op.attribs.get('axes', leading_axes)

    out_shape = infer.slice(input=op.input.shape, axes=axes, begin=begin, end=end)
    return [out_shape], [op.input.dtype]
def propagate_slice(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate shape and dtype through a slice given by ``begin`` and ``size``.

    The operation must have exactly three inputs: data, begin and size.  The
    begin and size tensors are looked up in ``const_value_by_tensor`` and
    converted with ``tolist()``.

    Returns:
        ``([output_shape], [dtype])`` where the dtype is ``op.attribs['T']``.
    """
    data, begin_tensor, size_tensor = op.inputs

    begin = const_value_by_tensor[begin_tensor].tolist()  # type: typing.List[int]
    size = const_value_by_tensor[size_tensor].tolist()  # type: typing.List[int]

    out_shape = infer.slice(input=data.shape,
                            begin=begin,
                            size=size,
                            zero_means_all=True)
    return [out_shape], [op.attribs['T']]
def propagate_dynamic_slice(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate shape and dtype through a dynamic slice, folding bounds into attributes.

    The operation must have exactly four inputs: data, starts, ends and axes.
    The last three are evaluated with ``evaluate_shape_tensor_simple`` and
    stored as the ``starts``, ``ends`` and ``axes`` attributes; afterwards the
    operation keeps only its data input.

    Returns:
        ``([output_shape], [dtype])`` where the dtype is that of the data input.
    """
    tensor, starts_tensor, ends_tensor, axes_tensor = op.inputs

    bound_inputs = (
        ('starts', starts_tensor),
        ('ends', ends_tensor),
        ('axes', axes_tensor),
    )
    for key, bound in bound_inputs:
        op.attribs[key] = evaluate_shape_tensor_simple(bound)

    op.inputs = (tensor, )

    out_shape = infer.slice(input=tensor.shape,
                            axes=op.attribs['axes'],
                            begin=op.attribs['starts'],
                            end=op.attribs['ends'])
    return [out_shape], [tensor.dtype]
def convert_slice(converter, nnef_op, caffe2_graph):
    # type: (Converter, NNEFOperation, Caffe2Graph)->None
    """Convert an NNEF slice into a chain of Caffe2 ``Slice`` operations.

    The ``axes``/``begin``/``end`` attributes are processed in sorted axis
    order.  Negative begins and non-positive ends are offset by the dimension
    size, and an end equal to the dimension size is written as -1.  Axes whose
    range ends up as ``(0, -1)`` are dropped.  If nothing remains, a single
    ``Copy`` operation is emitted; otherwise one ``Slice`` is emitted per
    remaining axis, chained through intermediate tensors, with the last one
    writing to the converted output tensor.
    """
    current = converter.converted_tensor(nnef_op.input)
    result = converter.converted_tensor(nnef_op.output)

    # Collect (axis, start, end) for every axis that is actually cut.
    effective = []
    triples = zip(nnef_op.attribs['axes'],
                  nnef_op.attribs['begin'],
                  nnef_op.attribs['end'])
    for axis, begin, end in sorted(triples):
        assert axis >= 0
        extent = current.shape[axis]

        if begin < 0:
            begin += extent
        if end <= 0:
            end += extent
        if end == extent:
            end = -1

        if begin != 0 or end != -1:
            effective.append((axis, begin, end))

    if not effective:
        Caffe2Operation(graph=caffe2_graph, name='Copy', inputs=current, outputs=result)
        return

    final_index = len(effective) - 1
    for index, (axis, start, end) in enumerate(effective):
        # Each emitted Slice is neutral (0 .. -1) everywhere except on `axis`.
        starts_attr = [0] * current.rank
        starts_attr[axis] = start
        ends_attr = [-1] * current.rank
        ends_attr[axis] = end

        if index == final_index:
            sliced = result
        else:
            sliced = Caffe2Tensor(graph=caffe2_graph,
                                  shape=infer.slice(input=current.shape,
                                                    begin=[start],
                                                    end=[0 if end == -1 else end],
                                                    axes=[axis],
                                                    zero_means_all=True),
                                  dtype=current.dtype)

        Caffe2Operation(graph=caffe2_graph,
                        name='Slice',
                        inputs=current,
                        outputs=sliced,
                        attribs=dict(starts=starts_attr, ends=ends_attr))
        current = sliced
def test_slice(self):
    """Check ``infer.slice`` for size/end bounds, explicit axes, strides and clamping."""
    expect = self.assertEqual
    sliced = infer.slice

    # All axes, extent given as `size`.
    expect([1, 1, 1, 2],
           sliced(input=[1, 2, 3, 4], begin=[0, 1, 2, 2], size=[1, 1, 1, 2]))
    expect([1, 1, 1, 2],
           sliced(input=[1, 2, 3, 4], begin=[0, 1, 2, 2], size=[-1, -1, -1, -1]))
    expect([1, 1, 1, 2],
           sliced(input=[1, 2, 3, 4], begin=[0, 1, 2, 2], size=[0, 0, 0, 0], zero_means_all=True))
    expect([0, 0, 0, 0],
           sliced([1, 2, 3, 4], begin=[0, 1, 2, 2], size=[0, 0, 0, 0]))
    expect([2, 4, 6, 36],
           sliced(input=[10, 20, 30, 40], begin=[1, 2, 3, 4], size=[2, 4, 6, -1]))

    # All axes, extent given as `end`.
    expect([1, 1, 1, 2],
           sliced(input=[1, 2, 3, 4], begin=[0, 1, 2, 2], end=[1, 2, 3, 4]))
    expect([1, 1, 1, 2],
           sliced(input=[1, 2, 3, 4], begin=[0, 1, 2, 2], end=[0, 0, 0, 0], zero_means_all=True))
    expect([0, 0, 0, 0],
           sliced([1, 2, 3, 4], begin=[0, 1, 2, 2], end=[0, 1, 2, 2]))
    expect([2, 4, 6, 36],
           sliced(input=[10, 20, 30, 40], begin=[1, 2, 3, 4], end=[3, 6, 9, 0], zero_means_all=True))

    # Single explicit axis, `end` form.
    expect([10, 32, 32, 1],
           sliced(input=[10, 32, 32, 3], axes=[3], begin=[1], end=[2]))
    expect([10, 32, 32, 0],
           sliced(input=[10, 32, 32, 3], axes=[3], begin=[1], end=[1]))
    expect([10, 32, 32, 2],
           sliced(input=[10, 32, 32, 3], axes=[3], begin=[1], end=[0], zero_means_all=True))

    # Single explicit axis, `size` form (including a negative axis index).
    expect([10, 32, 32, 1],
           sliced(input=[10, 32, 32, 3], axes=[3], begin=[1], size=[1]))
    expect([10, 32, 32, 2],
           sliced(input=[10, 32, 32, 3], axes=[3], begin=[1], size=[-1]))
    expect([10, 32, 32, 0],
           sliced(input=[10, 32, 32, 3], axes=[-1], begin=[1], size=[0]))
    expect([10, 32, 32, 2],
           sliced(input=[10, 32, 32, 3], axes=[-1], begin=[1], size=[0], zero_means_all=True))

    # Several axes given out of order, mixing negative and positive indices.
    expect([1, 2, 1, 2],
           sliced(input=[10, 32, 32, 3], axes=[-1, 2, -3, 0], begin=[1, 2, 3, 0], end=[3, 3, 5, 1]))
    expect([1, 2, 1, 2],
           sliced(input=[10, 32, 32, 3], axes=[-1, 2, -3, 0], begin=[1, 2, 3, 0], size=[-1, 1, 2, 1]))

    # Strided slicing.
    expect([10, 5, 3, 2],
           sliced(input=[10, 20, 30, 40], begin=[0, 0, 0, 0], size=[10, 10, 10, 10], stride=[1, 2, 3, 4]))

    # Negative and out-of-range ends.
    expect([1, 14, 25, 35],
           sliced(input=[10, 20, 30, 40], begin=[5, 5, 5, 5], end=[6, -1, 30, 999]))