def bitonic_woven_matrices(n):
    """Combine the l, r and l_inv, r_inv matrices into single n x n multiplies.

    Intended for the weave-based sorters (e.g. diff_sort_weave), fusing
    together consecutive stages: each stage's unweave matrix is folded into
    the following stage's weave matrix.

    This reduces the work to one matrix multiply per stage yielded by
    bitonic_layer_loop, plus one for the final unweave. For a bitonic network
    that is k(k+1)/2 + 1 multiplies, where k = log2(n).

    Parameters
    ----------
    n : int
        Sequence length. Must be a positive power of two.

    Returns
    -------
    list of ndarray
        The fused ``weave @ previous_unweave`` matrices, one per stage,
        followed by the final unweave matrix.

    Raises
    ------
    ValueError
        If ``n`` is not a positive power of two.
    """
    if n < 1 or n & (n - 1):
        raise ValueError(f"n must be a positive power of two, got {n}")
    matrices = []
    last_unweave = np.eye(n)
    # `size` is the matrix size reported by the layer loop (kept distinct from
    # the parameter `n` to avoid shadowing it).
    for size, m, layer in bitonic_layer_loop(n):
        half = size // 2
        weave, unweave = np.zeros((size, size)), np.zeros((size, size))
        for a, b, out, swap in bitonic_swap_loop(size, m, layer):
            # gather the compared pair: first element into the top half,
            # second element into the bottom half
            weave[out, a] = 1
            weave[out + half, b] = 1
            # flip comparison order as needed
            if swap:
                a, b = b, a
            # scatter the top-half result back to a, bottom-half result to b
            unweave[a, out] = 1
            unweave[b, out + half] = 1
        # fuse the previous unweave with this stage's weave
        matrices.append(weave @ last_unweave)
        last_unweave = unweave
    # make sure the last unweave is preserved
    matrices.append(last_unweave)
    return matrices
def bitonic_woven_matrices_alt(n):
    """Alternative direct implementation of bitonic_woven_matrices.

    Builds the fused weave/unweave permutation matrices of a bitonic sorting
    network on ``n`` inputs directly, without the separate loop helpers.

    Parameters
    ----------
    n : int
        Sequence length. Must be a positive power of two.

    Returns
    -------
    list of ndarray
        ``k * (k + 1) // 2 + 1`` arrays of shape ``(n, n)``, where
        ``k = log2(n)``. There is one fused ``weave @ previous_unweave``
        matrix per bitonic stage, followed by the final unweave matrix.

    Raises
    ------
    ValueError
        If ``n`` is not a positive power of two.
    """
    # Without this check a non-power-of-two n fails mid-build with IndexError
    # (or OverflowError for n == 0).
    if n < 1 or n & (n - 1):
        raise ValueError(f"n must be a positive power of two, got {n}")
    layers = int(np.log2(n))
    half = n // 2
    matrices = []
    last_unweave = np.eye(n)
    for layer in range(layers):
        for s in range(layer + 1):
            # compare-exchange distance for this stage
            m = 1 << (layer - s)
            weave, unweave = np.zeros((n, n)), np.zeros((n, n))
            out = 0
            for i in range(0, n, m << 1):
                for j in range(m):
                    ix = i + j
                    a, b = ix, ix + m
                    # gather the pair: a into the top half, b into the bottom half
                    weave[out, a] = 1
                    weave[out + half, b] = 1
                    # when bit (layer + 1) of ix is set this block sorts in the
                    # opposite direction, so route the halves the other way
                    if (ix >> (layer + 1)) & 1:
                        a, b = b, a
                    # scatter the top-half result to a, bottom-half result to b
                    unweave[a, out] = 1
                    unweave[b, out + half] = 1
                    out += 1
            # fuse the previous unweave with this stage's weave
            matrices.append(weave @ last_unweave)
            last_unweave = unweave
    # keep the final unweave as its own matrix
    matrices.append(last_unweave)
    return matrices
def diff_sort_weave(fused, x, softmax=softmax, beta=0.0):
    """Sort ``x`` using the fused bitonic matrices from bitonic_woven_matrices(n).

    Parameters
    ----------
    fused : sequence of ndarray
        Matrices as returned by ``bitonic_woven_matrices(len(x))``. The first
        one is applied to ``x`` directly. Each later one is applied after a
        pairwise soft min/max between the two halves of the current vector.
    x : array-like
        The sequence to sort, of length n.
    softmax : callable
        Called as ``softmax(a, b)`` on the two halves. It is presumably a
        smooth elementwise maximum. The matching minimum is recovered as
        ``a + b - softmax(a, b)``.
    beta : float
        Interpolation between applying true permutations (0.0) and leaving
        the values unchanged (1.0). Each matrix is replaced by
        ``beta * I + (1 - beta) * matrix``.

    Returns
    -------
    ndarray
        The (softly) sorted sequence.
    """
    size = len(x)
    half = size // 2
    identity = np.eye(size)

    def blend(perm):
        # mix the permutation with the identity according to beta
        return beta * identity + (1 - beta) * perm

    values = blend(fused[0]) @ x
    for perm in fused[1:]:
        top, bottom = values[:half], values[half:]
        hi = softmax(top, bottom)
        # the soft minimum is whatever remains of the pair's sum
        lo = top + bottom - hi
        values = blend(perm) @ np.concatenate([lo, hi])
    return values