def test_unsqueeze(self):
    """Check ``infer.unsqueeze`` against a table of expected output shapes.

    Each case is ``(expected_shape, input_shape, axes)``. The table covers
    positive axes, negative axes, several axes at once and an empty
    (rank-0) input shape. Cases are asserted in order; the first mismatch
    fails the test.
    """
    cases = (
        ([1, 2, 3], [2, 3], [0]),
        ([2, 3, 1], [2, 3], [-1]),
        ([2, 1, 3], [2, 3], [-2]),
        ([1, 2, 3], [2, 3], [-3]),
        ([1, 2, 1], [1, 2], [2]),
        ([2, 1, 3], [2, 3], [1]),
        ([1, 1, 1, 2, 1, 1], [1, 2], [0, 2, 4, 5]),
        ([1], [], [0]),
    )
    for expected_shape, input_shape, axes in cases:
        self.assertEqual(expected_shape, infer.unsqueeze(input_shape, axes))
示例#2
0
def nnef_unsqueeze(input, axes):
    """Return ``input`` reshaped to the shape computed by ``shape_inference.unsqueeze``.

    :param input: array-like object exposing ``.shape`` and ``.reshape``
        (presumably a NumPy array — confirm against callers).
    :param axes: axes forwarded unchanged to ``shape_inference.unsqueeze``.
    :return: the result of ``input.reshape(<unsqueezed shape>)``.
    """
    target_shape = shape_inference.unsqueeze(input.shape, axes)
    return input.reshape(target_shape)
def propagate_expand_dims(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate output shape and dtype through an expand-dims operation.

    :param op: operation whose ``inputs`` are exactly ``(data, axis)``.
    :param const_value_by_tensor: mapping from tensor to its constant value;
        the entry for the axis tensor must support ``.item()``.
    :return: ``([output_shape], [dtype])``, where ``dtype`` is read from
        ``op.attribs['T']``.
    """
    data_tensor, axis_tensor = op.inputs
    # The axis is supplied as a constant tensor; extract its scalar value.
    axis_value = const_value_by_tensor[axis_tensor].item()
    output_shape = infer.unsqueeze(input=data_tensor.shape, axes=[axis_value])
    output_dtype = op.attribs['T']
    return [output_shape], [output_dtype]
示例#4
0
def propagate_unsqueeze(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate output shape and dtype through an Unsqueeze operation.

    The axes come from the ``'axes'`` attribute. The output dtype equals
    the input dtype.

    :return: ``([output_shape], [input_dtype])``.
    """
    output_shape = infer.unsqueeze(input=op.input.shape, axes=op.attribs['axes'])
    output_dtype = op.input.dtype
    return [output_shape], [output_dtype]
示例#5
0
def expand_dims_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Compute the output shape and dtype of an ExpandDims operation.

    Size-1 dimensions are placed according to ``op.attribs['dims']`` by way
    of ``infer.unsqueeze``. The dtype is taken from the first input.

    :return: ``(output_shape, first_input_dtype)``.
    """
    first_input = op.inputs[0]
    dims = op.attribs['dims']
    output_shape = infer.unsqueeze(first_input.shape, axes=dims)
    return output_shape, first_input.dtype