示例#1
0
def reshape_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shapes/dtypes of a Caffe2 Reshape op and normalize it.

    The target shape is taken from the ``shape`` attribute when the op has a
    single input, or from the constant data of its second input when it has
    two. As a side effect the op is rewritten in place:

    * consumers of ``op.outputs[1]`` are redirected to a new constant INT64
      tensor holding the data input's shape, and
    * the target shape is stored in ``op.attribs['shape']`` and the op keeps
      only its data input.

    Returns:
        ``((reshaped_shape, [input_rank]), (input_dtype, DTYPE_INT64))`` for
        the two outputs.

    Raises:
        utils.NNEFToolsException: if the shape input has no constant data, or
            the op does not have exactly 1 or 2 inputs.
    """
    if len(op.inputs) == 1:
        shape = op.attribs['shape']
    elif len(op.inputs) == 2:
        if op.inputs[1].data is None:
            raise utils.NNEFToolsException(
                'Reshape is not supported with calculated shape.')
        shape = op.inputs[1].data.tolist()
    else:
        # Explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would fall through with `shape` unbound.
        raise utils.NNEFToolsException(
            'Reshape expects 1 or 2 inputs, got {}.'.format(len(op.inputs)))

    data = op.inputs[0]

    # The second output is treated as the input's shape; that is statically
    # known here, so swap it for a constant INT64 tensor in all consumers.
    graph_utils.replace_tensor_in_consumers(
        op.graph,
        op.outputs[1],
        Caffe2Tensor(graph=op.graph,
                     shape=[data.rank],
                     data=np.array(data.shape, dtype=np.int64),
                     dtype=DTYPE_INT64),
        remove=False)

    op.attribs['shape'] = shape
    op.inputs = (data, )

    return (infer.reshape(data.shape, shape=shape, zero_means_same=True), [data.rank]), \
           (data.dtype, DTYPE_INT64)
示例#2
0
def propagate_reshape(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate shape and dtype through an ONNX Reshape op.

    The target shape comes from the ``shape`` attribute when present;
    otherwise the op is expected to have two inputs, and the target shape is
    obtained by evaluating the second one with
    ``evaluate_shape_tensor_simple``. A ``0`` in the target shape keeps the
    corresponding input dimension (``zero_means_same=True``).

    Returns:
        A one-element list holding the output shape and a one-element list
        holding the output dtype (same as the data input's).
    """
    if 'shape' not in op.attribs:
        data, shape_tensor = op.inputs
        target = evaluate_shape_tensor_simple(shape_tensor)
    else:
        data = op.input
        target = op.attribs['shape']

    out_shape = infer.reshape(input=data.shape,
                              shape=target,
                              zero_means_same=True)
    return [out_shape], [data.dtype]
示例#3
0
def nnef_reshape(input, shape, axis_start=0, axis_count=-1):
    # type: (torch.Tensor, List[int], int, int)->torch.Tensor
    """Reshape ``input`` according to NNEF reshape semantics.

    The full target shape is computed by ``shape_inference.reshape`` from the
    current dimensions, ``shape``, ``axis_start`` (passed as ``offset``) and
    ``axis_count`` (passed as ``count``), with zeros in ``shape`` keeping the
    matching input dimension. The result is applied via ``Tensor.reshape``.

    NOTE(review): offset/count presumably select the span of input dimensions
    that ``shape`` replaces (``-1`` = through the last axis) — confirm against
    shape_inference.
    """
    current_dims = list(input.shape)
    new_dims = shape_inference.reshape(input=current_dims,
                                       shape=shape,
                                       offset=axis_start,
                                       count=axis_count,
                                       zero_means_same=True)
    return input.reshape(new_dims)
示例#4
0
def expand_softmax(tf_graph, tf_op):
    """Rewrite a softmax op in place so it runs on the last axis of a rank-2 tensor.

    If the softmax axis is not the last one, the input is transposed so that
    the axis becomes last, and the result is transposed back afterwards. If
    the (possibly transposed) input is not rank 2, it is reshaped to
    ``[-1, last_dim]``, and the result is reshaped back to the shape of the
    original output. After the call ``tf_op.attribs['axis']`` is always ``-1``.

    Args:
        tf_graph: graph in which the new transpose/reshape ops and tensors are
            created.
        tf_op: the softmax operation; its inputs/outputs are rewired in place.
    """
    assert tf_op.input.rank != 0

    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1
    if axis < 0:
        axis += tf_op.input.rank

    tf_op.attribs['axis'] = -1

    # Already in the target form: rank-2 input with softmax over the last axis.
    if tf_op.input.rank == 2 and axis == 1:
        return

    if axis != tf_op.input.rank - 1:
        # Move the softmax axis to the end; perm_inv restores the original order.
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(
                                                     input=tf_op.input.shape,
                                                     axes=perm),
                                                 dtype=tf_op.input.dtype))
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        # Copy the shape list so the new output does not alias the transposed
        # input's shape (same as the reshape branch below).
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)

    if tf_op.input.rank != 2:
        # Collapse all leading dimensions so the softmax sees a 2-D tensor.
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(
                                                   input=tf_op.input.shape,
                                                   shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        # Reshape the 2-D result back to the original output shape; copy the
        # list so the op attribute does not alias the tensor's shape.
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=list(old_output.shape)),
                    outputs=old_output)
def propagate_reshape(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Propagate shape and dtype through a TensorFlow Reshape op.

    The target shape is the constant value recorded for the op's second
    input in ``const_value_by_tensor`` (converted with ``tolist()``). The
    output dtype is taken from the op's ``T`` attribute.

    Returns:
        A one-element list holding the output shape and a one-element list
        holding the output dtype.
    """
    data_tensor, shape_tensor = op.inputs
    target_shape = const_value_by_tensor[shape_tensor].tolist()  # type: typing.List[int]
    output_shape = infer.reshape(input=data_tensor.shape, shape=target_shape)
    return [output_shape], [op.attribs['T']]
 def test_reshape(self):
     # Plain reshapes: explicit dims and a single inferred (-1) dim.
     self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, 6]))
     self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, -1]))
     self.assertEqual([24], infer.reshape([4, 2, 3], [24]))
     self.assertEqual([24], infer.reshape([4, 2, 3], [-1]))
     self.assertEqual([4, 2, 1, 3], infer.reshape([4, 2, 3], [4, 2, -1, 3]))
     self.assertEqual([4, 6, 1], infer.reshape([4, 2, 3], [4, -1, 1]))
     # Previously a verbatim duplicate of the line above; cover -1 in last position.
     self.assertEqual([4, 1, 6], infer.reshape([4, 2, 3], [4, 1, -1]))
     # Scalars (rank 0) and all-ones shapes.
     self.assertEqual([1], infer.reshape([], [1]))
     self.assertEqual([1], infer.reshape([], [-1]))
     self.assertEqual([], infer.reshape([1], []))
     self.assertEqual([], infer.reshape([1, 1, 1], []))
     self.assertEqual([0, 1], infer.reshape([0], [0, 1]))
     # offset/count: only input dims [offset, offset + count) are reshaped;
     # count=-1 means through the last dim.
     self.assertEqual([1, 2, 3, 4], infer.reshape(input=[1, 3, 2, 4], shape=[2, 3], offset=1, count=2))
     self.assertEqual([1, 2, 3, 4], infer.reshape(input=[1, 24], shape=[2, 3, 4], offset=1, count=-1))
     # -1 cannot be inferred when the element count is zero.
     with self.assertRaises(AssertionError):
         infer.reshape([0], [0, -1])
     # zero_means_same: a 0 copies the input dim at the same position.
     self.assertEqual([1, 2, 1, 3], infer.reshape([1, 2, 3], [0, 0, 1, -1], zero_means_same=True))
     self.assertEqual([1], infer.reshape([1], [0], zero_means_same=True))
     # ...and fails when the input has no dim at that position.
     with self.assertRaises(AssertionError):
         infer.reshape([], [0], zero_means_same=True)
     with self.assertRaises(AssertionError):
         infer.reshape([1], [1, 0], zero_means_same=True)
 def test_reshape(self):
     # Plain reshapes: explicit dims and a single inferred (-1) dim.
     self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, 6]))
     self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, -1]))
     self.assertEqual([24], infer.reshape([4, 2, 3], [24]))
     self.assertEqual([24], infer.reshape([4, 2, 3], [-1]))
     self.assertEqual([4, 2, 1, 3], infer.reshape([4, 2, 3], [4, 2, -1, 3]))
     self.assertEqual([4, 6, 1], infer.reshape([4, 2, 3], [4, -1, 1]))
     # Previously a verbatim duplicate of the line above; cover -1 in last position.
     self.assertEqual([4, 1, 6], infer.reshape([4, 2, 3], [4, 1, -1]))
     # Scalars (rank 0) and all-ones shapes.
     self.assertEqual([1], infer.reshape([], [1]))
     self.assertEqual([1], infer.reshape([], [-1]))
     self.assertEqual([], infer.reshape([1], []))
     self.assertEqual([], infer.reshape([1, 1, 1], []))
     self.assertEqual([0, 1], infer.reshape([0], [0, 1]))
     # -1 cannot be inferred when the element count is zero.
     with self.assertRaises(AssertionError):
         infer.reshape([0], [0, -1])
     # zero_means_same: a 0 copies the input dim at the same position.
     self.assertEqual([1, 2, 1, 3], infer.reshape([1, 2, 3], [0, 0, 1, -1], zero_means_same=True))
     self.assertEqual([1], infer.reshape([1], [0], zero_means_same=True))
     # ...and fails when the input has no dim at that position.
     with self.assertRaises(AssertionError):
         infer.reshape([], [0], zero_means_same=True)
     with self.assertRaises(AssertionError):
         infer.reshape([1], [1, 0], zero_means_same=True)
示例#8
0
def reshape_shape(op):
    """Infer the output shape of a reshape op that has ``axis``/``num_axes`` attributes.

    ``op.attribs['shape']`` is the target shape; ``op.attribs['axis']`` and
    ``op.attribs['num_axes']`` are forwarded to ``shapes.reshape`` as
    ``offset`` and ``count``. Zeros in the target shape keep the matching
    input dimension (``zero_means_same=True``).
    """
    attribs = op.attribs
    input_shape = op.inputs[0].shape
    return shapes.reshape(input_shape,
                          shape=attribs['shape'],
                          offset=attribs['axis'],
                          count=attribs['num_axes'],
                          zero_means_same=True)