def propagate_transpose(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX Transpose operation.

    The permutation is read from the optional 'perm' attribute. When the
    attribute is absent, None is passed as `axes` to infer.transpose
    (presumably meaning "reverse all dimensions", the ONNX default -- confirm
    against infer.transpose).

    :param op: the ONNX transpose operation
    :return: ([output_shape], [output_dtype]); the dtype is copied from the input
    """
    perm = op.attribs.get('perm')
    output_shape = infer.transpose(input=op.input.shape, axes=perm)
    return [output_shape], [op.input.dtype]
def transpose_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shape and dtype of a Caffe2 Transpose operation.

    A missing or empty 'axes' attribute is replaced by the reversed dimension
    order [rank-1, ..., 1, 0].

    :param op: the Caffe2 transpose operation
    :return: (output_shape, output_dtype); the dtype is copied from the input
    """
    axes = op.attribs.get('axes') or list(reversed(range(op.input.rank)))
    return infer.transpose(op.input.shape, axes), op.input.dtype
def expand_softmax(tf_graph, tf_op):
    """Rewrite a softmax op in place so it runs over the last axis of a rank-2 tensor.

    The op's 'axis' attribute is always set to -1. Then:
      * if the softmax axis is not the last one, a tf.transpose op is created
        in front of the softmax to move that axis to the end, and another
        tf.transpose (with the inverse permutation) after it to restore the
        original layout;
      * if the (possibly transposed) input is not rank 2, a tf.reshape op is
        created in front of the softmax to flatten it to [-1, last_dim], and
        another tf.reshape after it to restore the original output shape.
    A rank-2 input with softmax over axis 1 needs neither step.

    :param tf_graph: the graph in which the new tensors and operations are created
    :param tf_op: the softmax operation to rewrite; its input must not be rank 0
    """
    assert tf_op.input.rank != 0

    # Normalize the softmax axis to a non-negative index; a missing axis means the last one.
    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1
    if axis < 0:
        axis += tf_op.input.rank

    # After the rewrite below, the softmax always operates on the last axis.
    tf_op.attribs['axis'] = -1

    # Already rank 2 with softmax over the last axis: no extra ops are needed.
    if tf_op.input.rank == 2 and axis == 1:
        return

    if axis != tf_op.input.rank - 1:
        # Move `axis` to the end, run softmax there, then transpose back with the inverse permutation.
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(input=tf_op.input.shape, axes=perm),
                                                 dtype=tf_op.input.dtype))
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        # Copy the shape so the new output tensor does not share its shape list with the
        # transposed input tensor (consistent with the reshape branch below).
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)

    if tf_op.input.rank != 2:
        # Collapse all leading dimensions into one so the softmax sees a rank-2 tensor,
        # then reshape the result back to the shape of the original output tensor.
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(input=tf_op.input.shape, shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=old_output.shape),
                    outputs=old_output)
def propagate_transpose(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of a TF transpose operation.

    The op has two inputs: the data tensor and the permutation tensor. The
    permutation's value is looked up in `const_value_by_tensor` and converted
    with `.tolist()` (presumably a numpy array -- confirm), so it must be a
    known constant.

    :param op: the TF transpose operation
    :param const_value_by_tensor: mapping from tensors to their constant values
    :return: ([output_shape], [output_dtype]); the dtype is the op's 'T' attribute
    """
    # Distinct names: avoid shadowing the builtin `input` and reusing one name for a tensor and a list.
    data_tensor, perm_tensor = op.inputs
    perm = const_value_by_tensor[perm_tensor].tolist()  # type: typing.List[int]
    return [infer.transpose(input=data_tensor.shape, axes=perm)], [op.attribs['T']]
def test_transpose(self):
    """Check infer.transpose: no axes reverses the shape; explicit axes select input dims in order."""
    # Each case: (positional arguments for infer.transpose, expected output shape).
    cases = [
        (([],), []),
        (([3, 2, 1],), [1, 2, 3]),
        (([10, 32, 16, 3], [0, 3, 1, 2]), [10, 3, 32, 16]),
    ]
    for args, expected in cases:
        self.assertEqual(expected, infer.transpose(*args))