def reshape_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shapes/dtypes of a Caffe2 Reshape op and normalize the op.

    The target shape is taken from the 'shape' attribute when the op has one
    input, or from the constant data of the second input when it has two.
    As a side effect the op is rewritten into its single-input form:

    * consumers of op.outputs[1] are switched (via
      graph_utils.replace_tensor_in_consumers, remove=False) to a new constant
      int64 tensor whose data is the shape of the first input;
    * the resolved target shape is stored in op.attribs['shape'];
    * op.inputs is reduced to just the first input.

    Returns:
        ((reshaped_shape, [input_rank]), (input_dtype, DTYPE_INT64)) - the
        shapes and dtypes of the op's two outputs.

    Raises:
        utils.NNEFToolsException: if the shape input has no constant data.
        AssertionError: if the op has neither one nor two inputs.
    """
    if len(op.inputs) == 1:
        shape = op.attribs['shape']
    elif len(op.inputs) == 2:
        if op.inputs[1].data is None:
            raise utils.NNEFToolsException(
                'Reshape is not supported with calculated shape.')
        shape = op.inputs[1].data.tolist()
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would leave `shape` unbound further down.
        raise AssertionError(
            'Reshape expects 1 or 2 inputs, got {}.'.format(len(op.inputs)))

    data = op.inputs[0]

    # The second output carries the original input shape; replace its uses by
    # a constant so the op can be reduced to a single-output reshape.
    graph_utils.replace_tensor_in_consumers(
        op.graph,
        op.outputs[1],
        Caffe2Tensor(graph=op.graph,
                     shape=[data.rank],
                     data=np.array(data.shape, dtype=np.int64),
                     dtype=DTYPE_INT64),
        remove=False)

    op.attribs['shape'] = shape
    op.inputs = (data, )

    return (infer.reshape(data.shape, shape=shape, zero_means_same=True), [data.rank]), \
           (data.dtype, DTYPE_INT64)
def propagate_reshape(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX Reshape operation.

    The target shape comes from op.attribs['shape'] when that attribute is
    present; otherwise from the second input, evaluated with
    evaluate_shape_tensor_simple (presumably it must resolve to a constant
    list of ints - see that helper).

    A 0 in the target shape copies the corresponding input dimension, unless
    the op carries allowzero=1 (ONNX Reshape, opset 14+), in which case the 0
    is kept as a literal zero-sized dimension.

    Returns:
        ([output_shape], [output_dtype]); the output dtype equals the input's.
    """
    # ONNX 'allowzero': absent or 0 -> a 0 in 'shape' means "same as input";
    # 1 -> a 0 in 'shape' is a real zero-length dimension.
    zero_means_same = not op.attribs.get('allowzero', 0)

    if 'shape' in op.attribs:
        data = op.input
        target_shape = op.attribs['shape']
    else:
        data, shape_tensor = op.inputs
        target_shape = evaluate_shape_tensor_simple(shape_tensor)

    return [
        infer.reshape(input=data.shape, shape=target_shape, zero_means_same=zero_means_same)
    ], [data.dtype]
def nnef_reshape(input, shape, axis_start=0, axis_count=-1):
    # type: (torch.Tensor, List[int], int, int)->torch.Tensor
    """Reshape a tensor following NNEF reshape semantics.

    The concrete target shape is computed by shape_inference.reshape from the
    tensor's current shape, `shape`, and the axis range given by
    axis_start/axis_count (passed as offset/count; count=-1 presumably means
    "through the last axis" - confirm in shape_inference). A 0 in `shape`
    is treated as "keep the corresponding input dimension"
    (zero_means_same=True).

    Returns:
        input.reshape(<inferred shape>).
    """
    current_dims = list(input.shape)
    resolved_shape = shape_inference.reshape(input=current_dims,
                                             shape=shape,
                                             offset=axis_start,
                                             count=axis_count,
                                             zero_means_same=True)
    return input.reshape(resolved_shape)
def expand_softmax(tf_graph, tf_op):
    """Rewrite tf_op (a softmax, per the function name) to work on a rank-2 input along axis -1.

    If the softmax axis is not the last one, the input is routed through a
    tf.transpose that moves that axis to the end, and the op's output through
    the inverse tf.transpose. If the (possibly transposed) input is not rank 2,
    it is additionally routed through a tf.reshape to [-1, last_dim], and the
    output reshaped back. In every case tf_op.attribs['axis'] is set to -1.

    Args:
        tf_graph: graph in which the helper ops and intermediate tensors are
            created.
        tf_op: the op to rewrite; its attribs, inputs and outputs are
            modified in place.

    Returns:
        None. The original output tensor object is kept as the final output
        (the last helper op writes into it), so its consumers are untouched.
    """
    # Rank-0 inputs are not supported.
    assert tf_op.input.rank != 0

    # A missing/None axis defaults to -1; normalize to a non-negative index.
    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1
    if axis < 0:
        axis += tf_op.input.rank

    # After the rewrite the op always runs on the last axis.
    # Note: this is set even on the early-return path below.
    tf_op.attribs['axis'] = -1

    # Already rank 2 and along the last axis: nothing else to do.
    if tf_op.input.rank == 2 and axis == 1:
        return

    if axis != tf_op.input.rank - 1:
        # perm moves `axis` to the end (presumably utils.without returns the
        # range as a list with `axis` removed); perm_inv undoes it.
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(
                                                     input=tf_op.input.shape, axes=perm),
                                                 dtype=tf_op.input.dtype))
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        # NOTE(review): `shape=tf_op.input.shape` shares the list object with
        # the transposed input tensor, whereas the reshape branch below copies
        # it with list(...). Confirm no later pass mutates shapes in place.
        tf_op.outputs = TFTensor(graph=tf_graph, name=None, shape=tf_op.input.shape, dtype=tf_op.input.dtype)
        # Inverse transpose from the new intermediate output back into the
        # original output tensor. The result is not stored; presumably the
        # TFOperation constructor registers the op in tf_graph.
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)

    if tf_op.input.rank != 2:
        # Collapse all leading dimensions into one: [-1, last_dim].
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(
                                                   input=tf_op.input.shape, shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph, name=None, shape=list(tf_op.input.shape), dtype=tf_op.input.dtype)
        # Reshape the rank-2 result back into the shape of the tensor that was
        # previously the op's output.
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=old_output.shape),
                    outputs=old_output)
def propagate_reshape(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of a TF reshape operation.

    The target shape is the constant value recorded for the op's second input
    in const_value_by_tensor (converted to a Python list). Zeros in it are
    passed to infer.reshape without zero_means_same. The output dtype is
    taken from the op's 'T' attribute.

    Returns:
        ([output_shape], [op.attribs['T']]).
    """
    data_tensor, shape_tensor = op.inputs
    target_shape = const_value_by_tensor[shape_tensor].tolist()  # type: typing.List[int]
    output_shape = infer.reshape(input=data_tensor.shape, shape=target_shape)
    output_dtype = op.attribs['T']
    return [output_shape], [output_dtype]
def test_reshape(self):
    """Check infer.reshape.

    Covers explicit and -1 target shapes, scalars, zero-sized dims, partial
    reshapes via offset/count, and zero_means_same handling.
    """
    self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, 6]))
    self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, -1]))
    self.assertEqual([24], infer.reshape([4, 2, 3], [24]))
    self.assertEqual([24], infer.reshape([4, 2, 3], [-1]))
    self.assertEqual([4, 2, 1, 3], infer.reshape([4, 2, 3], [4, 2, -1, 3]))
    self.assertEqual([4, 6, 1], infer.reshape([4, 2, 3], [4, -1, 1]))
    # -1 in the leading position is inferred as well.
    self.assertEqual([12, 2], infer.reshape([4, 2, 3], [-1, 2]))
    # Scalars (shape []) reshape to and from single-element shapes.
    self.assertEqual([1], infer.reshape([], [1]))
    self.assertEqual([1], infer.reshape([], [-1]))
    self.assertEqual([], infer.reshape([1], []))
    self.assertEqual([], infer.reshape([1, 1, 1], []))
    # Without zero_means_same, 0 is a literal zero-sized dimension.
    self.assertEqual([0, 1], infer.reshape([0], [0, 1]))
    # offset/count restrict the reshape to a sub-range of the input dims.
    self.assertEqual([1, 2, 3, 4], infer.reshape(input=[1, 3, 2, 4], shape=[2, 3], offset=1, count=2))
    self.assertEqual([1, 2, 3, 4], infer.reshape(input=[1, 24], shape=[2, 3, 4], offset=1, count=-1))
    # -1 cannot be inferred when the element count is zero.
    with self.assertRaises(AssertionError):
        infer.reshape([0], [0, -1])
    # With zero_means_same, 0 copies the corresponding input dimension.
    self.assertEqual([1, 2, 1, 3], infer.reshape([1, 2, 3], [0, 0, 1, -1], zero_means_same=True))
    self.assertEqual([1], infer.reshape([1], [0], zero_means_same=True))
    with self.assertRaises(AssertionError):
        infer.reshape([], [0], zero_means_same=True)
    with self.assertRaises(AssertionError):
        infer.reshape([1], [1, 0], zero_means_same=True)
def test_reshape(self):
    """Check infer.reshape.

    Covers explicit and -1 target shapes, scalars, zero-sized dims, and
    zero_means_same handling.
    """
    self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, 6]))
    self.assertEqual([4, 6], infer.reshape([4, 2, 3], [4, -1]))
    self.assertEqual([24], infer.reshape([4, 2, 3], [24]))
    self.assertEqual([24], infer.reshape([4, 2, 3], [-1]))
    self.assertEqual([4, 2, 1, 3], infer.reshape([4, 2, 3], [4, 2, -1, 3]))
    self.assertEqual([4, 6, 1], infer.reshape([4, 2, 3], [4, -1, 1]))
    # -1 in the leading position is inferred as well.
    self.assertEqual([12, 2], infer.reshape([4, 2, 3], [-1, 2]))
    # Scalars (shape []) reshape to and from single-element shapes.
    self.assertEqual([1], infer.reshape([], [1]))
    self.assertEqual([1], infer.reshape([], [-1]))
    self.assertEqual([], infer.reshape([1], []))
    self.assertEqual([], infer.reshape([1, 1, 1], []))
    # Without zero_means_same, 0 is a literal zero-sized dimension.
    self.assertEqual([0, 1], infer.reshape([0], [0, 1]))
    # -1 cannot be inferred when the element count is zero.
    with self.assertRaises(AssertionError):
        infer.reshape([0], [0, -1])
    # With zero_means_same, 0 copies the corresponding input dimension.
    self.assertEqual([1, 2, 1, 3], infer.reshape([1, 2, 3], [0, 0, 1, -1], zero_means_same=True))
    self.assertEqual([1], infer.reshape([1], [0], zero_means_same=True))
    with self.assertRaises(AssertionError):
        infer.reshape([], [0], zero_means_same=True)
    with self.assertRaises(AssertionError):
        infer.reshape([1], [1, 0], zero_means_same=True)
def reshape_shape(op):
    """Return the output shape of a reshape op as computed by shapes.reshape.

    The op's 'axis' and 'num_axes' attributes are forwarded as offset/count
    (presumably the Caffe Reshape parameters selecting which input dims are
    replaced - confirm against the caller). A 0 in the 'shape' attribute
    keeps the corresponding input dimension (zero_means_same=True).
    """
    input_shape = op.inputs[0].shape
    attrs = op.attribs
    return shapes.reshape(input_shape,
                          shape=attrs['shape'],
                          offset=attrs['axis'],
                          count=attrs['num_axes'],
                          zero_means_same=True)