コード例 #1
0
ファイル: matmul.py プロジェクト: bergtholdt/ngraph-onnx
def reshape_for_matmul(onnx_node, input_a, input_b):
    # type: (NodeWrapper, NgraphNode, NgraphNode) -> Tuple[NgraphNode, NgraphNode]
    """Adjust input tensor shapes for matrix multiplication.

    This is based on an idea from onnx-tensorflow
    https://github.com/onnx/onnx-tensorflow/blob/17075f44c9071600beccfc62c92b22d1cd957bfd/onnx_tf/backend.py#L711
    They hard-code flattening of input `A` before transposition.

    If `has_matmul_compatible_shapes` rejects the original shapes, both inputs
    are flattened to 2D at axis 1 and the check is repeated.

    :param onnx_node: ONNX node for the matrix multiplication operation
    :param input_a: left side input node
    :param input_b: right side input node
    :return: tuple with input_a and input_b reshaped if needed
    :raises ValueError: if the shapes are still incompatible after flattening
    """
    # Inputs whose shapes already fit are returned untouched.
    if has_matmul_compatible_shapes(input_a.shape, input_b.shape):
        return input_a, input_b

    # Flatten ND tensors to 2D matrices at axis 1 and try again.
    input_a = flatten(input_a, 1)
    input_b = flatten(input_b, 1)
    if not has_matmul_compatible_shapes(input_a.shape, input_b.shape):
        # Format eagerly: ValueError does not interpolate %-style arguments,
        # it would just keep the template and the values as a tuple in `args`.
        raise ValueError('%s node (%s): input "A" and "B" data shapes are incompatible to '
                         'multiply with each other.' % (onnx_node.op_type, onnx_node.name))
    return input_a, input_b
コード例 #2
0
def Flatten(onnx_node,
            ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Flatten the input tensor into a 2D matrix.

    Flattening happens at the axis specified by the 'axis' attribute (default 1).
    The first dimension of the output tensor is the product of the
    [d_0, ..., d_{axis-1}] dimensions of the input tensor; the last dimension is
    the product of the remaining input dimensions [d_{axis}, ..., d_n].

    :param onnx_node: ONNX Flatten node; its 'axis' attribute is read
    :param ng_inputs: list of input nodes; only the first one is used
    :return: the flattened node produced by `flatten`
    :raises ValueError: if 'axis' lies outside the range [0, rank of the input]
    """
    input_node = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)
    rank = len(input_node.shape)

    # Both ends of [0, rank] are accepted, as in the original check.
    if not 0 <= axis <= rank:
        # Format eagerly: ValueError does not interpolate %-style arguments.
        raise ValueError('Flatten node (%s): %d is not a valid value for `axis`.'
                         % (onnx_node.name, axis))

    return flatten(input_node, axis)
コード例 #3
0
def Size(onnx_node,
         ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Return the total number of elements of the input tensor as a constant.

    The element count is the product of all input dimensions, taken from the
    static shape. An empty shape (scalar) gives 1. No intermediate Flatten
    node is added to the graph.

    :param onnx_node: ONNX Size node (not used)
    :param ng_inputs: list of input nodes; only the first one is used
    :return: int64 constant node holding the element count
    """
    element_count = int(np.prod(ng_inputs[0].shape, dtype=np.int64))
    # Dtype int64 is required for ONNX unit tests.
    return ng.constant(element_count, dtype=np.int64)