def test_squeeze(self):
        """Check that ``infer.squeeze`` removes the requested size-1 axes.

        Each case is ``(positional args to infer.squeeze, expected shape)``.
        Negative axes count from the end. The last case omits ``axes``
        entirely and expects every size-1 dimension to be dropped.
        """
        cases = [
            (([3, 2, 1], [-1]), [3, 2]),
            (([1, 2, 3], [0]), [2, 3]),
            (([2, 1, 2, 1], [1, -1]), [2, 2]),
            (([1, 1], [0, -1]), []),
            (([1, 3, 1, 1, 2, 1, 1],), [3, 2]),
        ]
        for call_args, expected_shape in cases:
            self.assertEqual(expected_shape, infer.squeeze(*call_args))
Example #2
0
def nnef_squeeze(input, axes):
    """Return *input* reshaped to its shape with *axes* squeezed out.

    The target shape comes from ``shape_inference.squeeze(input.shape, axes)``.
    The data is only reshaped, never copied or reordered here.

    NOTE(review): *input* is presumably a NumPy-like array (it must expose
    ``.shape`` and ``.reshape``), and *axes* is presumably a list of axis
    indices. Confirm both against the callers.
    """
    squeezed_shape = shape_inference.squeeze(input.shape, axes)
    return input.reshape(squeezed_shape)
def propagate_squeeze(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of a TF Squeeze operation.

    Returns a pair ``([output_shape], [output_dtype])``:

    * ``output_shape`` is ``infer.squeeze`` applied to the input shape.
    * ``output_dtype`` is the op's ``'T'`` attribute.

    A falsy ``squeeze_dims`` attribute (e.g. an empty list) is forwarded as
    ``axes=None``. That presumably asks ``infer.squeeze`` to drop every
    size-1 dimension; confirm against ``infer.squeeze``.

    ``const_value_by_tensor`` is not used by this function.
    """
    squeeze_dims = op.attribs['squeeze_dims']
    output_shape = infer.squeeze(input=op.input.shape, axes=squeeze_dims or None)
    output_dtype = op.attribs['T']
    return [output_shape], [output_dtype]
Example #4
0
def propagate_squeeze(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX Squeeze operation.

    Returns ``([output_shape], [dtype])``:

    * ``output_shape`` is ``infer.squeeze`` applied to the input shape, using
      the optional ``'axes'`` attribute. When it is absent, ``None`` is
      passed.
    * ``dtype`` is copied unchanged from the input.
    """
    # NOTE(review): in newer ONNX opsets (13+) Squeeze takes ``axes`` as an
    # optional second input rather than an attribute; this only reads the
    # attribute. Confirm which opset this converter targets.
    output_shape = infer.squeeze(input=op.input.shape, axes=op.attribs.get('axes'))
    return [output_shape], [op.input.dtype]
Example #5
0
def squeeze_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the (shape, dtype) result of a Caffe2 Squeeze operation.

    The shape is ``infer.squeeze`` applied to the first input's shape, with
    the axes taken from the required ``'dims'`` attribute. A missing
    attribute raises ``KeyError``. The dtype is the first input's dtype,
    unchanged.
    """
    first_input = op.inputs[0]
    squeezed = infer.squeeze(first_input.shape, axes=op.attribs['dims'])
    return squeezed, first_input.dtype