Exemplo n.º 1
0
def propagate_broadcast(op, const_value_by_tensor, dtype=''):
    # type: (TFOperation, _ConstValueByTensorT, str)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the single output shape and dtype of a broadcasting elementwise op.

    All input shapes are combined by infer.elementwise with right-aligned
    broadcasting (infer.Broadcast.FROM_RIGHT).

    Args:
        op: the operation whose inputs are broadcast together.
        const_value_by_tensor: not read by this function (presumably kept so all
            propagators share one signature — confirm with callers).
        dtype: output dtype override; when empty/falsy, get_op_t(op) is used.

    Returns:
        A pair ([output_shape], [output_dtype]).
    """
    operand_shapes = [tensor.shape for tensor in op.inputs]
    out_shape = infer.elementwise(inputs=operand_shapes,
                                  broadcast=infer.Broadcast.FROM_RIGHT)
    out_dtype = get_op_t(op) if not dtype else dtype
    return [out_shape], [out_dtype]
Exemplo n.º 2
0
def propagate_broadcast(op, dtype_from_index=0):
    # type: (ONNXOperation, int)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer the single output shape and dtype of a broadcasting elementwise op.

    Input shapes are combined by infer.elementwise with right-aligned
    broadcasting (infer.Broadcast.FROM_RIGHT).

    Args:
        op: the operation whose inputs are broadcast together.
        dtype_from_index: index of the input whose dtype becomes the output dtype.

    Returns:
        A pair ([output_shape], [output_dtype]).
    """
    shapes = [operand.shape for operand in op.inputs]
    result_shape = infer.elementwise(inputs=shapes,
                                     broadcast=infer.Broadcast.FROM_RIGHT)
    result_dtype = op.inputs[dtype_from_index].dtype
    return [result_shape], [result_dtype]
Exemplo n.º 3
0
def propagate_broadcast_with_axis(op, dtype=None):
    # type: (ONNXOperation, typing.Optional[str])->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shape and dtype of a binary op that may carry an ``axis`` attribute.

    With an 'axis' attribute, y's shape is aligned to x starting at dimension
    ``axis`` (negative values are offset by x.rank), and the two shapes are
    broadcast left-aligned (infer.Broadcast.FROM_LEFT). Without it, the shapes
    are broadcast right-aligned (infer.Broadcast.FROM_RIGHT).

    Args:
        op: operation with exactly two inputs (x, y).
        dtype: output dtype override; when None, x's dtype is used.

    Returns:
        A pair ([output_shape], [output_dtype]).
    """
    x, y = op.inputs
    # The override must apply no matter which broadcasting rule is used;
    # previously it was silently ignored whenever 'axis' was present.
    out_dtype = x.dtype if dtype is None else dtype

    if 'axis' in op.attribs:
        axis = op.attribs['axis']
        if axis < 0:
            axis += x.rank
        # Prepend `axis` unit dims so y lines up with x starting at dimension `axis`.
        y_shape = [1] * axis + list(y.shape)
        out_shape = infer.elementwise(inputs=[x.shape, y_shape],
                                      broadcast=infer.Broadcast.FROM_LEFT)
    else:
        out_shape = infer.elementwise(inputs=[x.shape, y.shape],
                                      broadcast=infer.Broadcast.FROM_RIGHT)
    return [out_shape], [out_dtype]
Exemplo n.º 4
0
def propagate_expand(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shape and dtype of an Expand op.

    The requested shape is read from the second input via
    evaluate_shape_tensor_simple. It is broadcast right-aligned
    (infer.Broadcast.FROM_RIGHT) against the first input's shape. The output
    dtype is the first input's dtype.

    Returns:
        A pair ([output_shape], [output_dtype]).
    """
    source, shape_tensor = op.inputs
    source_shape = source.shape
    requested_shape = evaluate_shape_tensor_simple(shape_tensor)
    out_shape = infer.elementwise(inputs=[source_shape, requested_shape],
                                  broadcast=infer.Broadcast.FROM_RIGHT)
    return [out_shape], [source.dtype]
Exemplo n.º 5
0
def propagate_expand(op):
    # type: (ONNXOperation)->typing.Tuple[typing.List[typing.List[int]], typing.List[str]]
    """Infer output shape and dtype of an Expand op, folding the shape input into attribs.

    The input shape and the requested shape (from evaluate_shape_tensor_simple)
    are left-padded with 1s to a common rank. Any -1 in the requested shape is
    replaced by the input's extent at that position.

    Side effects: stores the resolved shape in op.attribs['shape'] and replaces
    op.inputs with a 1-tuple holding only the data input.

    Returns:
        A pair ([output_shape], [output_dtype]); the shapes are combined with
        infer.Broadcast.SAME_RANK.
    """
    source, shape_tensor = op.inputs

    padded_input = list(source.shape)
    target = evaluate_shape_tensor_simple(shape_tensor)

    # Left-pad whichever shape is shorter so both have the same rank.
    rank_gap = len(target) - len(padded_input)
    if rank_gap > 0:
        padded_input = [1] * rank_gap + padded_input
    elif rank_gap < 0:
        target = [1] * (-rank_gap) + target

    # -1 keeps the input's extent: undocumented in ONNX, probably inherited from pytorch.
    resolved = []
    for dim_in, dim_out in zip(padded_input, target):
        resolved.append(dim_in if dim_out == -1 else dim_out)

    op.attribs['shape'] = resolved
    op.inputs = (source,)

    out_shape = infer.elementwise(inputs=[padded_input, resolved],
                                  broadcast=infer.Broadcast.SAME_RANK)
    return [out_shape], [source.dtype]
Exemplo n.º 6
0
    def test_elementwise(self):
        """Exercise infer.elementwise for every broadcast mode, including rejected inputs."""
        # NONE: identical shapes; the result is a fresh list, not an alias of either input.
        a = [1, 2, 3]
        b = [1, 2, 3]
        c = infer.elementwise([a, b], infer.Broadcast.NONE)
        self.assertEqual(c, a)
        self.assertIsNot(c, a)
        self.assertIsNot(c, b)

        # FROM_RIGHT must not modify its arguments: neither the shapes nor the list holding them.
        a = [1, 2, 3]
        b = [3]
        ab = [a, b]
        c = infer.elementwise(ab, infer.Broadcast.FROM_RIGHT)
        self.assertEqual([1, 2, 3], c)
        self.assertEqual([1, 2, 3], a)
        self.assertEqual([3], b)
        # Compare with a literal snapshot: comparing with [a, b] would be vacuously
        # true because ab holds those very same list objects.
        self.assertEqual([[1, 2, 3], [3]], ab)
        self.assertIs(a, ab[0])
        self.assertIs(b, ab[1])

        with self.assertRaises(AssertionError):
            infer.elementwise([[1, 2], [1]], infer.Broadcast.NONE)
        with self.assertRaises(AssertionError):
            infer.elementwise([[1, 2], [1, 3]], infer.Broadcast.NONE)

        self.assertEqual([2, 4, 3], infer.elementwise([[1, 4, 1], [2, 1, 3]], infer.Broadcast.SAME_RANK))
        with self.assertRaises(AssertionError):
            infer.elementwise([[1, 2], [1]], infer.Broadcast.SAME_RANK)

        self.assertEqual([4, 2, 3], infer.elementwise([[1, 2, 3], [4, 2]], infer.Broadcast.FROM_LEFT))
        with self.assertRaises(AssertionError):
            infer.elementwise([[2, 3], [3]], infer.Broadcast.FROM_LEFT)

        self.assertEqual([1, 2, 3], infer.elementwise([[1, 1, 3], [2, 3]], infer.Broadcast.FROM_RIGHT))
        with self.assertRaises(AssertionError):
            infer.elementwise([[2, 3], [2]], infer.Broadcast.FROM_RIGHT)