def tf_matmul_infer(node): assert (len(node.in_nodes()) == 2) shapes = [node.in_node(i).shape.copy() for i in range(2)] log.debug('matmul shapes: {}'.format(shapes)) if any(s is None or len(s) < 2 for s in shapes): log.error("MatMul wasn't able to infer shape") return if node.transpose_a: perm = np.array(range(len(node.in_node(0).shape)), dtype=np.int64) perm[-1], perm[-2] = perm[-2], perm[-1] inv = PermuteAttrs.get_inverse_permutation(perm) permutation = PermuteAttrs.Permutation(perm=perm, inv=int64_array(inv)) PermuteAttrs.set_permutation(node.in_node(0), node, permutation) shapes[0] = shapes[0][perm] if node.transpose_b: perm = np.array(range(len(node.in_node(1).shape)), dtype=np.int64) perm[-1], perm[-2] = perm[-2], perm[-1] inv = PermuteAttrs.get_inverse_permutation(perm) permutation = PermuteAttrs.Permutation(perm=perm, inv=int64_array(inv)) PermuteAttrs.set_permutation(node.in_node(1), node, permutation) shapes[1] = shapes[1][perm] if any(shapes[0][:-2] != shapes[1][:-2]) or shapes[0][-1] != shapes[1][-2]: log.error( "MatMul wasn't able to infer shape because input dimensions are not compatible" ) return shape_tuple = (np.array(shapes[0][:-1], dtype=np.int64), np.array([shapes[1][-1]], dtype=np.int64)) if len(shapes[0]) > 2: # TODO Investigate case when MatMul have inputs with not matching output dimensions # It looks to be a practical case and if we add outer dimensions of the first argument # it will lead to incorrect model sometimes. TF documentation is unclear. log.warning( 'Ignored outer dimensions of input tensor for MatMul node: {}'. format(node.name)) # shape_tuple = (shapes[0][:-2], *shape_tuple) log.debug('shape_tuple: {}'.format(shape_tuple)) node.out_node().shape = np.concatenate(shape_tuple) node['channel_dims'] = node.out_node().shape.size - 1 log.debug('matmul shape: {}'.format(node.out_node().shape))
def reverse_permute(output_shape: np.array, order: np.array): """ Calculates Transpose op input shape based on output shape and permute order. :param output_shape: Transpose output shape :param order: permute order :return: Transpose input shape corresponding to the specified output shape """ return int64_array(output_shape[PermuteAttrs.get_inverse_permutation(order)])