Example #1
0
def Gemm(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Calculate general matrix multiplication Y = alpha * (A @ B) + beta * C.

    :param onnx_node: ONNX Gemm node; its `alpha`, `beta`, `broadcast`, `transA`
                      and `transB` attributes are read (defaults 1, 1, 1, False, False).
    :param ng_inputs: the three input nodes ``[A, B, C]``.
    :return: node computing ``alpha * dot(A, B) + beta * C``.
    :raises ValueError: if A and B have shapes that cannot be multiplied, even after
                        flattening; or if `broadcast` is disabled and C's shape differs
                        from the shape of ``dot(A, B)``.
    """
    input_a, input_b, input_c = ng_inputs
    alpha = onnx_node.get_attribute_value('alpha', 1)  # Scalar multiplier for A @ B
    beta = onnx_node.get_attribute_value('beta', 1)  # Scalar multiplier for input tensor C
    broadcast = onnx_node.get_attribute_value('broadcast', 1)  # Should C be broadcast?
    trans_a = onnx_node.get_attribute_value('transA', False)  # Should A be transposed?
    trans_b = onnx_node.get_attribute_value('transB', False)  # Should B be transposed?

    if trans_a:
        input_a = transpose(input_a)
    if trans_b:
        input_b = transpose(input_b)

    # onnx-tensorflow: https://github.com/onnx/onnx-tensorflow/
    #  blob/17075f44c9071600beccfc62c92b22d1cd957bfd/onnx_tf/backend.py#L711
    # They hardcode flattening of input `A` before transposition.
    #
    # Here we first check whether the inputs have incompatible shapes, and only
    # then try flattening them.
    if not has_matmul_compatible_shapes(input_a.shape, input_b.shape):
        input_a = flatten_innermost_empty_dims(input_a)
        input_b = flatten_innermost_empty_dims(input_b)
        if not has_matmul_compatible_shapes(input_a.shape, input_b.shape):
            # Format eagerly: ValueError does not apply %-style arguments itself,
            # it would just store them as a tuple in `args`.
            raise ValueError(('Gemm node (%s): input "A" and "B" data shapes are incompatible to '
                              'multiply with each other.') % onnx_node.name)

    a_dot_b = ng.dot(input_a, input_b)

    if not broadcast and input_c.shape != a_dot_b.shape:
        raise ValueError(('Gemm node (%s): input data shapes are incompatible and broadcast '
                          'was not requested!') % onnx_node.name)

    return alpha * a_dot_b + beta * input_c
Example #2
0
def Gemm(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Compute Y = alpha * (A @ B) + beta * C for an ONNX Gemm node.

    Only 2D matrices are supported for now; higher dimensional tensors are
    flattened to 2D before the product is taken.

    :param onnx_node: Gemm node supplying the `alpha`, `beta` (default 1) and
                      `transA`, `transB` (default False) attributes.
    :param ng_inputs: the three input nodes ``[A, B, C]``.
    :return: node computing ``alpha * (A @ B) + beta * C``, where C is broadcast
             via numpy_style_broadcast_for_binary_operation.
    """
    input_a, input_b, input_c = ng_inputs
    attr = onnx_node.get_attribute_value
    alpha = attr('alpha', 1)  # factor applied to the product A @ B
    beta = attr('beta', 1)  # factor applied to C
    trans_a = attr('transA', False)  # transpose A before multiplying?
    trans_b = attr('transB', False)  # transpose B before multiplying?

    input_a = transpose(input_a) if trans_a else input_a
    input_b = transpose(input_b) if trans_b else input_b
    input_a, input_b = reshape_for_matmul(onnx_node, input_a, input_b)

    # A scaling factor of exactly 1 is skipped rather than multiplied in.
    product = ng.dot(input_a, input_b)
    scaled_product = alpha * product if alpha != 1 else product
    scaled_c = beta * input_c if beta != 1 else input_c

    # Only the broadcast form of C is used; the product is added as-is.
    _, broadcast_c = numpy_style_broadcast_for_binary_operation(
        onnx_node, [scaled_product, scaled_c])
    return scaled_product + broadcast_c
Example #3
0
def Transpose(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Permute the axes of the input tensor, in the manner of numpy.transpose.

    When the node has no `perm` attribute, the dimension order is reversed;
    otherwise the axes are reordered according to the values in `perm`.

    :param onnx_node: Transpose node, optionally carrying a `perm` attribute.
    :param ng_inputs: input nodes; only the first one is used.
    :return: the transposed / axis-reordered node.
    """
    data = ng_inputs[0]
    perm = onnx_node.get_attribute_value('perm')
    if perm is not None:
        return reorder_axes(data, perm)
    return transpose(data)