Esempio n. 1
0
def propagate_transpose(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of an ONNX ``Transpose`` operation.

    The permutation comes from the optional ``perm`` attribute. When that
    attribute is absent, ``None`` is forwarded to ``infer.transpose``, which
    presumably reverses the dimensions in that case.

    Returns:
        A pair of single-element lists: ``[output_shape]`` and
        ``[output_dtype]``, where the dtype is the input dtype unchanged.
    """
    in_shape = op.input.shape
    permutation = op.attribs.get('perm', None)
    out_shape = infer.transpose(input=in_shape, axes=permutation)
    return [out_shape], [op.input.dtype]
Esempio n. 2
0
def transpose_shape(op):
    # type: (Caffe2Operation)->ShapeResult
    """Infer the output shape and dtype of a Caffe2 ``Transpose`` operation.

    The ``axes`` attribute is used as the permutation. When it is missing or
    empty, the dimensions are reversed (``[rank - 1, ..., 1, 0]``).

    Returns:
        A ``(shape, dtype)`` tuple; the dtype is the input dtype unchanged.
    """
    # Any falsy value (None or an empty list) selects the reversed order.
    perm = op.attribs.get('axes') or list(reversed(range(op.input.rank)))
    return infer.transpose(op.input.shape, perm), op.input.dtype
Esempio n. 3
0
def expand_softmax(tf_graph, tf_op):
    """Rewrite ``tf_op`` in place into a last-axis softmax on a rank-2 input.

    Helper operations are inserted into ``tf_graph`` around the op. Each step
    runs only when needed:

      1. If the softmax axis is not the last axis, a ``tf.transpose`` is
         inserted before the op to move that axis to the end. A second
         ``tf.transpose`` with the inverse permutation is inserted after the
         op; it now produces the op's previous output tensor.
      2. If the (possibly transposed) input is not rank 2, a ``tf.reshape`` to
         ``[-1, last_dim]`` is inserted before the op. A second ``tf.reshape``
         back to the shape of the op's previous output is inserted after it.

    The op's ``axis`` attribute is always set to -1. In both steps, the tensor
    the op produced before that step becomes the output of the newly added
    trailing op. A fresh tensor becomes the softmax's own output.

    Args:
        tf_graph: Graph that receives the new operations and tensors.
        tf_op: The softmax operation to rewrite. Its input must not be a
            scalar.

    Returns:
        None. The graph is modified in place.
    """
    # NOTE(review): input validation via ``assert`` is skipped under ``python -O``.
    assert tf_op.input.rank != 0

    # Normalise the softmax axis to a non-negative index; default to the last axis.
    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1
    if axis < 0:
        axis += tf_op.input.rank

    # After the rewrite, the softmax always operates on the last axis.
    tf_op.attribs['axis'] = -1

    # A rank-2 input reduced over axis 1 already has the target form.
    if tf_op.input.rank == 2 and axis == 1:
        return

    if axis != tf_op.input.rank - 1:
        # Move `axis` to the end. The other axes presumably keep their relative
        # order, depending on what utils.without returns.
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(
                                                     input=tf_op.input.shape,
                                                     axes=perm),
                                                 dtype=tf_op.input.dtype))
        # Feed the softmax from the transposed tensor.
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        # Softmax preserves shape, so the new output matches the transposed input.
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=tf_op.input.shape,
                                 dtype=tf_op.input.dtype)
        # Transpose back into the tensor the op previously produced.
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)

    if tf_op.input.rank != 2:
        # Collapse all leading dimensions into one, keeping the last (softmax) axis.
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(
                                                   input=tf_op.input.shape,
                                                   shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        # The op's current output, which may be the intermediate tensor from the transpose step.
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        # Reshape the rank-2 result back to the shape of the op's previous output.
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=old_output.shape),
                    outputs=old_output)
def propagate_transpose(op, const_value_by_tensor):
    # type: (TFOperation, _ConstValueByTensorT)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the output shape and dtype of a TensorFlow ``Transpose`` operation.

    The op has exactly two inputs: the data tensor and the permutation tensor.
    The permutation's value is looked up directly in
    ``const_value_by_tensor``, so a missing entry raises ``KeyError``. The
    looked-up value is converted with ``.tolist()``. The output dtype is
    taken from the op's ``T`` attribute.

    Returns:
        A pair of single-element lists: ``[output_shape]`` and
        ``[output_dtype]``.
    """
    data, perm_tensor = op.inputs
    axes = const_value_by_tensor[perm_tensor].tolist()  # type: typing.List[int]
    out_shape = infer.transpose(input=data.shape, axes=axes)
    return [out_shape], [op.attribs['T']]
Esempio n. 5
0
 def test_transpose(self):
     """Check ``infer.transpose`` with and without an explicit permutation.

     Without a permutation, the dimensions are expected to come back reversed.
     """
     cases = [
         # (expected shape, positional arguments to infer.transpose)
         ([], ([],)),
         ([1, 2, 3], ([3, 2, 1],)),
         ([10, 3, 32, 16], ([10, 32, 16, 3], [0, 3, 1, 2])),
     ]
     for expected, args in cases:
         self.assertEqual(expected, infer.transpose(*args))