def reshape_for_matmul(onnx_node, input_a, input_b):
    # type: (NodeWrapper, NgraphNode, NgraphNode) -> Tuple[NgraphNode, NgraphNode]
    """Adjust input tensor shapes for matrix multiplication.

    This is based on an idea from onnx-tensorflow
    https://github.com/onnx/onnx-tensorflow/blob/17075f44c9071600beccfc62c92b22d1cd957bfd/onnx_tf/backend.py#L711
    which hard-codes flattening of input `A` before transposition.

    If the two inputs are not already compatible for multiplication, both are
    flattened at axis 1 (to 2D matrices) and compatibility is checked again.

    :param onnx_node: ONNX node for the matrix multiplication operation
    :param input_a: left side input node
    :param input_b: right side input node
    :return: tuple with input_a and input_b, reshaped if needed
    :raises ValueError: if the shapes are still incompatible after flattening
    """
    # Only touch the inputs when their current shapes cannot be multiplied.
    if not has_matmul_compatible_shapes(input_a.shape, input_b.shape):
        # Flatten ND tensors to 2D matrices and retry.
        input_a = flatten(input_a, 1)
        input_b = flatten(input_b, 1)
        if not has_matmul_compatible_shapes(input_a.shape, input_b.shape):
            # Interpolate the message here; passing the values as extra
            # ValueError args would leave the '%s' placeholders unformatted.
            raise ValueError('%s node (%s): input "A" and "B" data shapes are incompatible to '
                             'multiply with each other.' % (onnx_node.op_type, onnx_node.name))
    return input_a, input_b
def Flatten(onnx_node, ng_inputs):
    # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Flatten the input tensor into a 2D matrix.

    Flattening happens at the axis specified by the 'axis' attribute (default 1).
    The first dimension of the output tensor is the product of the
    [d_0, ..., d_{axis-1}] dimensions of the input tensor. The last dimension is
    the product of the remaining input dimensions: [d_{axis}, ..., d_{r-1}],
    where r is the input rank.

    :param onnx_node: ONNX Flatten node (provides the 'axis' attribute and name)
    :param ng_inputs: list of input nodes; only the first one is used
    :return: node produced by flattening the first input at `axis`
    :raises ValueError: if `axis` is outside the range [0, rank]
    """
    input_node = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)
    rank = len(input_node.shape)

    # axis == rank is accepted: every dimension goes into the first output dim.
    if axis < 0 or axis > rank:
        # Format eagerly so the message contains the values, not '%s'/'%d'.
        raise ValueError(
            'Flatten node (%s): %d is not a valid value for `axis`.' % (onnx_node.name, axis))

    return flatten(input_node, axis)
def Size(onnx_node, ng_inputs):
    # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Build a constant node holding the number of elements in the input tensor.

    :param onnx_node: ONNX Size node (unused here)
    :param ng_inputs: list of input nodes; only the first one is inspected
    :return: int64 constant node with the element count of the first input
    """
    data = ng_inputs[0]
    # Flatten at axis 0 and read dimension 1 of the result. Presumably this
    # gives shape [1, N] with N the total element count (ONNX Flatten
    # semantics) — verify against the `flatten` helper.
    element_count = flatten(data, 0).shape[1]
    # Dtype int64 is required for ONNX unit tests.
    return ng.constant(element_count, dtype=np.int64)