def bitonic_woven_matrices(n):
    """
    Combine the l,r and l_inv, r_inv matrices into single n x n multiplies, for
    use with the weave-based sorters (e.g. diff_sort_weave), fusing together
    consecutive stages: each stage's unweave is multiplied into the next
    stage's weave, so only one matrix multiply is needed between compare steps.

    One matrix is emitted per stage yielded by bitonic_layer_loop(n), plus a
    trailing unweave. For the standard bitonic network (as built by
    bitonic_woven_matrices_alt) that is k(k+1)/2 + 1 matrices, where
    k = int(np.log2(n)).

    Parameters
    ----------
    n : int
        Length of the sequences the network will be applied to.

    Returns
    -------
    list of ndarray
        The fused matrices, in the order they should be applied.
    """
    matrices = []
    last_unweave = np.eye(n)
    # bitonic_layer_loop yields the stage size alongside (m, layer); bind it to
    # ``size`` so the ``n`` argument is not rebound inside the loop.
    for size, m, layer in bitonic_layer_loop(n):
        half = size // 2
        weave, unweave = np.zeros((size, size)), np.zeros((size, size))
        for a, b, out, swap in bitonic_swap_loop(size, m, layer):
            # gather elements a and b into slot ``out`` of the two halves
            weave[out, a] = 1
            weave[out + half, b] = 1
            # flip comparison order as needed
            if swap:
                a, b = b, a
            # scatter the two half-results back to positions a and b
            unweave[a, out] = 1
            unweave[b, out + half] = 1
        # fuse this stage's weave with the previous stage's unweave
        matrices.append(weave @ last_unweave)
        last_unweave = unweave
    # the final unweave has no following weave to fuse with
    matrices.append(last_unweave)
    return matrices
def bitonic_woven_matrices_alt(n):
    """
    Alternative direct implementation of bitonic_woven_matrices.

    Builds the same fused weave/unweave matrices, with the bitonic stage loop
    and swap loop written out inline so that it depends only on numpy.

    Parameters
    ----------
    n : int
        Sequence length. The network has k = int(np.log2(n)) layers, so n is
        expected to be a power of two.

    Returns
    -------
    list of (n, n) ndarray
        k(k+1)/2 + 1 matrices: the first stage's weave, one fused
        (previous unweave, then next weave) matrix between consecutive
        stages, and the final unweave.
    """
    layers = int(np.log2(n))
    half = n // 2
    matrices = []
    last_unweave = np.eye(n)
    for layer in range(layers):
        for s in range(layer + 1):
            # compare distance for this stage
            m = 1 << (layer - s)
            weave, unweave = np.zeros((n, n)), np.zeros((n, n))
            out = 0
            for i in range(0, n, m << 1):
                for ix in range(i, i + m):
                    a, b = ix, ix + m
                    # gather the pair (a, b) into slot ``out`` of the two halves
                    weave[out, a] = 1
                    weave[out + half, b] = 1
                    # when bit (layer + 1) of ix is set, route the halves back
                    # to swapped positions, reversing the order for this block
                    if (ix >> (layer + 1)) & 1:
                        a, b = b, a
                    unweave[a, out] = 1
                    unweave[b, out + half] = 1
                    out += 1
            # fuse this stage's weave with the previous stage's unweave
            matrices.append(weave @ last_unweave)
            last_unweave = unweave
    # the final unweave has no following weave to fuse with
    matrices.append(last_unweave)
    return matrices
def diff_sort_weave(fused, x, softmax=softmax, beta=0.0):
    """
    Run the fused bitonic network ``fused`` (as produced by
    bitonic_woven_matrices(n)) over a sequence ``x`` of length n.

    ``fused[0]`` is applied first. Then, for every remaining matrix, ``x`` is
    split into halves ``p`` and ``q``. ``softmax(p, q)`` is taken as the
    pairwise maximum, and the pairwise minimum is recovered as
    ``p + q - max``. The matrix is then applied to ``[min, max]``.

    ``beta`` blends every matrix with the identity:
    ``beta * I + (1 - beta) * M``. With beta=0.0 the matrices are used
    exactly. With beta=1.0 every matrix becomes the identity. The max/min
    step is still applied at each stage, so x is not returned unchanged in
    general.
    """
    size = len(x)
    half = size // 2
    identity = np.eye(size)
    keep = 1 - beta

    x = (beta * identity + keep * fused[0]) @ x
    for stage in fused[1:]:
        first_half, second_half = x[:half], x[half:]
        biggest = softmax(first_half, second_half)
        smallest = first_half + second_half - biggest
        blended = beta * identity + keep * stage
        x = blended @ np.concatenate([smallest, biggest])
    return x