def rand_span(A):
    """Return a random matrix whose rowspan equals the rowspan of A.

    Draws random square matrices ``v`` (via ``rand2``) and forms
    ``dot2(v, A)`` until the product has the same rank as ``A``. The final
    assertion checks that the rowspans intersect in full rank, i.e. the
    result spans the same space as ``A``.
    """
    rows, _cols = A.shape
    while True:
        mixer = rand2(rows, rows)
        A1 = dot2(mixer, A)
        assert A1.shape == A.shape
        # Rows of A1 lie in rowspan(A); equal rank means the span is preserved.
        if rank(A) == rank(A1):
            break
    assert rank(intersect(A, A1)) == rank(A)
    return A1
def is_correctable(idxs, n, LxiHx, LziHz, Hx, Hz, Lx, Lz, **kw):
    """Test whether the qubit region ``idxs`` (out of ``n`` qubits) is correctable.

    Finds the operators in the rowspans of ``LxiHx`` and ``LziHz`` that are
    supported inside ``idxs``. The region is reported correctable when none of
    those X-type operators has a nonzero ``dot2`` overlap with ``Lz`` and none
    of the Z-type operators has one with ``Lx``.

    Asserts that the restricted operators have zero ``dot2`` overlap with
    ``Hz`` / ``Hx`` respectively. Extra keyword arguments are accepted and
    ignored.

    Returns a Python ``bool``.
    """
    # Rows of the identity picked out by idxs span all vectors supported on idxs.
    region = identity2(n)[idxs]
    x_ops = intersect(LxiHx, region)
    z_ops = intersect(LziHz, region)

    assert dot2(x_ops, Hz.transpose()).sum() == 0
    assert dot2(z_ops, Hx.transpose()).sum() == 0

    # `and` short-circuits exactly like the nested check it replaces.
    return bool(
        dot2(x_ops, Lz.transpose()).sum() == 0
        and dot2(z_ops, Lx.transpose()).sum() == 0
    )
def in_support(H, keep_idxs, check=False):
    """Return the part of rowspan(H) supported on the columns ``keep_idxs``.

    Copied from classical.py. Builds the rows of the identity selected by
    ``keep_idxs`` and returns ``intersect`` of that span with ``H``.

    Parameters:
        H: 2-D matrix whose rowspan is restricted; ``H.shape[1]`` is the
            number of coordinates.
        keep_idxs: column indices that the returned vectors may be supported on.
        check: when true, brute-force verify via ``span`` that the result spans
            exactly the intersection of the two spans. This enumerates the
            spans, so it is only practical for small inputs.

    Returns:
        The matrix produced by ``intersect(A, H)``.
    """
    n = H.shape[1]
    # Rows of the identity indexed by keep_idxs span every vector whose
    # support lies inside keep_idxs.
    A = identity2(n)[keep_idxs]
    H1 = intersect(A, H)
    if check:
        lhs = set(str(x) for x in span(A))
        rhs = set(str(x) for x in span(H))
        meet = lhs.intersection(rhs)
        assert meet == set(str(x) for x in span(H1))
    return H1
def get_puncture(M, k):
    """k-puncture the rowspace of M.

    Repeatedly shuffles a 0/1 mask containing ``k`` ones. The mask selects
    ``k`` of the ``n = M.shape[1]`` coordinates. Stops as soon as the identity
    rows on those coordinates have an empty ``intersect`` with ``M``.

    Returns the selected coordinate indices in increasing order.

    Note: there is no iteration bound; this loops forever if no such
    coordinate set exists.
    """
    _rows, n = M.shape
    assert 0 <= k <= n
    mask = [1] * k + [0] * (n - k)
    eye = identity2(n)
    while True:
        shuffle(mask)
        chosen = [pos for pos, bit in enumerate(mask) if bit]
        punct = eye[chosen]
        assert punct.shape == (k, n)
        if len(intersect(punct, M)) == 0:
            return chosen
def test_code(Hxi, Hzi, Hx, Lx, Lz, Lx0, Lx1, LxiHx, **kw):
    """Build a CSSCode from (Hxi, Hzi) and optionally run a decoding simulation.

    Steps:
    1. Build the code and assert that its logical operators share a
       rank-``code.k`` intersection with the supplied ``Lx`` / ``Lz``.
    2. Get a decoder from ``argv`` via ``get_decoder``. Return early if there is none.
    3. Run ``N = argv.N`` trials (default 0). Each trial draws a random error
       with per-bit probability ``p = argv.p`` (default 0.01), decodes it, and
       classifies the residual ``op + err_op``:
       - '.' means the residual has zero ``dot2`` overlap with ``code.Lz``;
       - 'x' means it does not;
       - 'F' means the decoder returned None.
       Residuals that are nonzero and unsuccessful are recorded as logical
       operators and bound the distance from above.
    4. Print summary statistics. Then, for each recorded logical op, solve for
       its coefficients in the basis formed by stacking Lx0, Lx1 and Hx.

    ``LxiHx`` and ``**kw`` are accepted but not used in this body.
    """
    code = CSSCode(Hx=Hxi, Hz=Hzi)
    print(code)
    assert rank(intersect(Lx, code.Lx)) == code.k
    assert rank(intersect(Lz, code.Lz)) == code.k

    verbose = argv.verbose
    decoder = get_decoder(argv, argv.decode, code)
    if decoder is None:
        return

    p = argv.get("p", 0.01)
    N = argv.get("N", 0)

    # Upper bound on distance; tightened whenever a logical error is observed.
    distance = code.n
    count = 0
    failcount = 0
    nonuniq = 0

    logops = []

    for i in range(N):
        # Presumably ra is numpy.random — TODO confirm against the module imports.
        err_op = ra.binomial(1, p, (code.n, ))
        err_op = err_op.astype(numpy.int32)
        op = decoder.decode(p, err_op, verbose=verbose, argv=argv)

        c = 'F'
        success = False
        if op is not None:
            op = (op + err_op) % 2
            # Should be a codeword of Hz (kernel of Hz)
            assert dot2(code.Hz, op).sum() == 0
            write("%d:" % op.sum())

            # Are we in the image of Hx ? If so, then success.
            success = dot2(code.Lz, op).sum() == 0

            # Successful, but the decoder's correction differs from the error
            # by a nonzero (stabilizer-like) operator.
            if success and op.sum():
                nonuniq += 1

            c = '.' if success else 'x'

            if op.sum() and not success:
                distance = min(distance, op.sum())
                write("L")
                logops.append(op.copy())

        else:
            failcount += 1
        write(c + ' ')
        count += success

    if N:
        # i is only bound when the loop ran, hence the `if N` guard.
        print()
        print(argv)
        print("error rate = %.8f" % (1. - 1. * count / (i + 1)))
        print("fail rate = %.8f" % (1. * failcount / (i + 1)))
        print("nonuniq = %d" % nonuniq)
        print("distance <= %d" % distance)

    # Express each observed logical error in the basis [Lx0; Lx1; Hx] and
    # print the three coefficient segments side by side.
    mx0, mx1 = len(Lx0), len(Lx1)
    LxHx = numpy.concatenate((Lx0, Lx1, Hx))
    for op in logops:
        print(op.sum())
        #print(shortstr(op))
        #print(op.shape)
        #print(op)
        K = solve(LxHx.transpose(), op)
        K.shape = (1, len(K))
        print(shortstrx(K[:, :mx0], K[:, mx0:mx0 + mx1], K[:, mx0 + mx1:]))
def test_puncture(A, B, ma, na, mb, nb, Ina, Ima, Inb, Imb, ka, kb, kat, kbt, k,
        KerA, KerB, CokerA, CokerB,
        Lzi, Lxi, Hzi, Hxi,
        **kw):
    """Experiment: check whether a punctured Z-logical escapes the Hz span.

    Uses a block layout built from kron products of A, B, their kernels and
    identities. The function:
    - checks the dimension bookkeeping (ka - na + ma - kat == 0, etc.);
    - checks that the explicit vectors x, y, z lie in the kernel of the full
      block matrix M, and that (x, z) spans the same space as kernel of the
      reduced matrix (rowspan_eq);
    - builds a candidate Z-logical ``lzt = [KerA ⊗ R ; 0]`` for a column
      selector R chosen by argv (puncture / identity2 / first column of B);
    - reports whether ``lzt`` lies in the rowspan of ``Hzt``.

    With argv.puncture set, it asserts that ``lzt`` is not contained in the Hz
    span and that the two spans have trivial intersection.
    """

    I = identity2

    assert ka - na + ma -kat == 0
    assert kb - nb + mb -kbt == 0

    #print("ka=%s, kat=%s, kb=%s, kbt=%s"%(ka, kat, kb, kbt))
    assert k == ka*kbt + kat*kb == len(Lzi) == len(Lxi)

    # Presumably find_kernel returns kernel vectors as rows; transposing puts
    # them in columns ("convention in paper") — TODO confirm.
    kernel = lambda X : find_kernel(X).transpose() # use convention in paper
    KerA = KerA.transpose() # use convention in paper
    KerB = KerB.transpose() # use convention in paper
    #CokerA = CokerA.transpose() # use convention in paper
    #CokerB = CokerB.transpose() # use convention in paper

    assert CokerA.shape == (kat, ma)
    assert CokerB.shape == (kbt, mb)

    # Row 0: vertical qubits (na*mb), row 1: horizontal qubits (ma*nb).
    blocks = [
        [kron(KerA, Imb), zeros2(na*mb, ma*kb), kron(Ina, B)],
        [zeros2(ma*nb, ka*mb), kron(Ima, KerB), kron(A,Inb)],
    ]
    print("blocks:", [[X.shape for X in row] for row in blocks])

    #print(shortstrx(*blocks[0]))
    #print()
    #print(shortstrx(*blocks[1]))

    # Kernel of the block matrix without the middle (KerB) column.
    Mv = cat((blocks[0][0], blocks[0][2]), axis=1)
    Mh = cat((blocks[1][0], blocks[1][2]), axis=1)
    M = cat((Mv, Mh), axis=0)
    KM = kernel(M)

    # Full block matrix with all three columns.
    Mv = cat(blocks[0], axis=1)
    Mh = cat(blocks[1], axis=1)
    M = cat((Mv, Mh), axis=0)

    # NOTE: the dot2(...) results on the next three statement pairs are
    # discarded; they only serve as shape-compatibility checks.
    x = kron(I(ka), B)
    dot2(blocks[0][0], x)

    y = zeros2(blocks[0][1].shape[1], x.shape[1])
    dot2(blocks[0][1], y)

    z = kron(KerA, I(nb))
    dot2(blocks[0][2], z)

    #print(shortstr(x)+'\n')
    #print(shortstr(y)+'\n')
    #print(shortstr(z)+'\n')

    xz = cat((x, z), axis=0)
    xyz = cat((x, y, z), axis=0)
    assert dot2(M, xyz).sum() == 0
    #print(shortstr(xyz))
    print("xyz:", xyz.shape)
    assert len(find_kernel(xyz))==0

    assert rowspan_eq(KM.transpose(), xz.transpose())

    print("kernel(M):", kernel(M).shape)

    Hzt = cat((blocks[0][2], blocks[1][2]), axis=0)
    #print("kernel(Hzt):", kernel(Hzt).shape)
    KHzt = kernel(Hzt)
    #assert KHzt.shape[1] == 0, (KHzt.shape,)
    print("kernel(Hzt):", KHzt.shape)

    Hx = cat((kron(A, I(mb)), kron(I(ma), B)), axis=1)

    #print("CokerB")
    #print(shortstr(CokerB))

    #R = CokerB
    #R = rand2(CokerB.shape[0], CokerB.shape[1])
    #R = rand2(mb, 1)
    #R = CokerB[:, 0:1]

    # Choose the column selector R. `argv.puncture and 1` is truthy exactly
    # when argv.puncture is, so the `elif argv.puncture` branch below is
    # unreachable (kept as an alternative, manually toggled via the `and 1`).
    if argv.puncture and 1:
        idxs = get_puncture(B.transpose(), kbt)
        print("get_puncture:", idxs)
        # One unit column per punctured coordinate.
        R = zeros2(mb, len(idxs))
        for i, idx in enumerate(idxs):
            R[idx, i] = 1
    elif argv.puncture:
        idxs = get_puncture(B.transpose(), kbt)
        R = zeros2(mb, 1)
        R[idxs] = 1
    elif argv.identity2:
        R = I(mb)
    else:
        R = B[:, :1]
    #R = rand2(mb, 100)

    print("R:")
    print(shortstrx(R))

    # Candidate Z-logicals: KerA ⊗ R on vertical qubits, zero on horizontal ones.
    lzt = cat((kron(KerA, R), zeros2(ma*nb, KerA.shape[1]*R.shape[1])), axis=0)
    print("lzt:", lzt.shape)
    print(shortstrx(lzt))
    print("Hzt:", Hzt.shape)
    print(shortstrx(Hzt))
    assert dot2(Hx, lzt).sum()==0

    lz = lzt.transpose()
    Hz = Hzt.transpose()
    print(rank(lz), rank(Hz), rank(intersect(lz, Hz)))

    result = rowspan_le(lzt.transpose(), Hzt.transpose())
    print("lzt <= Hzt:", result)
    if argv.puncture:
        assert not result
        assert not rank(intersect(lz, Hz))

    print("OK")
def test(A, B, ma, na, mb, nb, Ina, Ima, Inb, Imb, ka, kb, kat, kbt, k,
        KerA, KerB, CokerA, CokerB,
        Lzi, Lxi, Hzi, Hxi,
        **kw):
    """Experiment: relate Z-logicals restricted to vertical qubits to the block layout.

    Checks that:
    - the kernel of ``Hzt`` has size ka*kb;
    - every element of the kernel of ``Hxi`` that is supported on the vertical
      qubits lies in the span ``Lzv`` of the first block column;
    - a randomly built operator ``_lz`` (combining blocks[0][0] and
      blocks[0][2]) has zero dot2 overlap with Hxi. Its rows are supported
      only on vertical qubits, and they lie in both ``lz`` and ``Lzv``.

    Finally prints the shape of the intersection of the kernel of the full
    block matrix with its first two column blocks.
    """

    #print("ka=%s, kat=%s, kb=%s, kbt=%s"%(ka, kat, kb, kbt))
    assert k == ka*kbt + kat*kb == len(Lzi) == len(Lxi)

    KerA = KerA.transpose() # use convention in paper
    KerB = KerB.transpose() # use convention in paper
    CokerA = CokerA.transpose() # use convention in paper
    CokerB = CokerB.transpose() # use convention in paper

    # Row 0: vertical qubits (na*mb), row 1: horizontal qubits (ma*nb).
    blocks = [
        [kron(KerA, Imb), zeros2(na*mb, ma*kb), kron(Ina, B)],
        [zeros2(ma*nb, ka*mb), kron(Ima, KerB), kron(A,Inb)],
    ]
    print("blocks:", [[X.shape for X in row] for row in blocks])

    #print(shortstrx(*blocks[0]))
    #print()
    #print(shortstrx(*blocks[1]))

    Hzt = cat((blocks[0][2], blocks[1][2]), axis=0)
    K = find_kernel(Hzt)
    assert len(K) == ka*kb # see proof of Lemma 3

    Lzv = cat((blocks[0][0], blocks[1][0])).transpose()
    Lzh = cat((blocks[0][1], blocks[1][1])).transpose()
    assert dot2(Hxi, Lzv.transpose()).sum() == 0

    # Hz = Hzt.transpose()
    # Hzi = linear_independent(Hz)
    # Lzhi = independent_logops(Lzh, Hzi, verbose=True)
    # print("Lzhi:", Lzhi.shape)

    # --------------------------------------------------------
    # basis for all logops, including stabilizers
    lz = find_kernel(Hxi) # returns transpose of kernel
    #lz = rand_rowspan(lz)
    #print("lz:", lz.shape)
    assert len(lz) == k+len(Hzi)

    # vertical qubits
    Iv = cat((identity2(na*mb), zeros2(ma*nb, na*mb)), axis=0).transpose()
    # horizontal qubits
    Ih = cat((zeros2(na*mb, ma*nb), identity2(ma*nb)), axis=0).transpose()
    assert len(intersect(Iv, Ih))==0 # sanity check

    # now restrict these logops to vertical qubits
    #print("Iv:", Iv.shape)
    lzv = intersect(Iv, lz)
    #print("lzv:", lzv.shape)

    # Every vertical-only logop must already lie in span(Lzv).
    J = intersect(lzv, Lzv)
    assert len(J) == len(lzv)

    # --------------------------------------------------------
    # now we manually build _lz supported on vertical qubits
    x = rand2(ka*mb, ka*nb)
    y = kron(KerA, Inb)
    assert eq2(dot2(blocks[0][2], y), kron(KerA, B))
    v = (dot2(blocks[0][0], x) + dot2(blocks[0][2], y)) % 2
    h = zeros2(ma*nb, v.shape[1])
    _lzt = cat((v, h))
    assert dot2(Hxi, _lzt).sum() == 0
    #print(shortstr(_lzt))
    _lz = _lzt.transpose()
    _lz = linear_independent(_lz)

    #print("*"*(na*mb))
    #print(shortstr(_lz))
    assert len(intersect(_lz, Ih)) == 0
    assert len(intersect(_lz, Iv)) == len(_lz)

    J = intersect(_lz, lz)
    assert len(J) == len(_lz)

    J = intersect(_lz, Lzv)
    #print(J.shape, _lz.shape, Lzv.shape)
    assert len(J) == len(_lz)

    # Disabled variant: kernel of the block matrix without the third column.
    if 0:
        V = cat(blocks[0][:2], axis=1)
        H = cat(blocks[1][:2], axis=1)
        X = cat((V, H), axis=0)
        K = find_kernel(X)
        print(K.shape)

    # Kernel of the full block matrix, then its part supported on the first
    # two column blocks (the last na*nb coordinates zeroed).
    V = cat(blocks[0], axis=1)
    H = cat(blocks[1], axis=1)
    X = cat((V, H), axis=0)
    K = find_kernel(X)
    print(K.shape)

    #print("-"*(ka*mb+ma*kb))
    I = cat((identity2(ka*mb+ma*kb), zeros2(ka*mb+ma*kb, na*nb)), axis=1)
    J = intersect(K, I)
    print("J:", J.shape)
def hypergraph_product(C, D, check=False):
    """Build the hypergraph product of classical check matrices C and D and check it.

    Constructs
        Hx = [I(c1) ⊗ D^T | C^T ⊗ I(d1)]
        Hz = [C ⊗ I(d0)  | I(c0) ⊗ D]
    and logical bases Lx (from kernels of C and D) and Lz (from cokernels of
    D and C). It then:
    - splits the logicals into two halves (Lx0/Lx1, Lz0/Lz1) by qubit
      position;
    - takes the union of supports of the elementwise products lx * lz within
      each half as a candidate region;
    - asserts that ``is_correctable`` accepts that region.

    With argv.code set, it also builds a CSSCode. When ``check`` is true it
    then verifies that code.Lx is expressible in terms of Lx.

    Prints diagnostics throughout; relies on module-level ``argv``.
    """
    print("hypergraph_product:", C.shape, D.shape)
    print("distance:", classical_distance(C))

    c0, c1 = C.shape
    d0, d1 = D.shape

    E1 = identity2(c0)
    E2 = identity2(d0)
    M1 = identity2(c1)
    M2 = identity2(d1)

    Hx0 = kron(M1, D.transpose()), kron(C.transpose(), M2)
    Hx = numpy.concatenate(Hx0, axis=1) # horizontal concatenate

    Hz0 = kron(C, E2), kron(E1, D)
    #print("Hz0:", Hz0[0].shape, Hz0[1].shape)
    Hz = numpy.concatenate(Hz0, axis=1) # horizontal concatenate

    assert dot2(Hz, Hx.transpose()).sum() == 0

    n = Hx.shape[1]
    assert Hz.shape[1] == n

    # ---------------------------------------------------
    # Build Lx

    KerC = find_kernel(C)
    #KerC = min_span(KerC) # does not seem to matter... ??
    assert KerC.shape[1] == c1
    K = KerC.transpose()
    E = identity2(d0)

    print(shortstr(KerC))
    print()

    Lxt0 = kron(K, E), zeros2(c0 * d1, K.shape[1] * d0)
    Lxt0 = numpy.concatenate(Lxt0, axis=0)
    assert dot2(Hz, Lxt0).sum() == 0

    K = find_kernel(D).transpose()
    assert K.shape[0] == d1
    E = identity2(c0)

    Lxt1 = zeros2(c1 * d0, K.shape[1] * c0), kron(E, K)
    Lxt1 = numpy.concatenate(Lxt1, axis=0)
    assert dot2(Hz, Lxt1).sum() == 0

    Lxt = numpy.concatenate((Lxt0, Lxt1), axis=1) # horizontal concatenate
    Lx = Lxt.transpose()

    assert dot2(Hz, Lxt).sum() == 0

    # The rows of Lx are linearly independent as plain vectors (asserted
    # here). Redundancy only appears once the stabilizers Hx are taken into
    # account; it is removed further down with independent_logops.
    assert rank(Lx) == len(Lx)

    if 0:
        # ---------------------------------------------------
        print(shortstr(Lx))
        k = get_k(Lx, Hx)
        print("k =", k)

        left, right = [], []

        draw = Draw(c0, c1, d0, d1)
        cols = []
        for j in range(d0): # col
            col = []
            for i in range(len(KerC)): # logical
                op = Lx[i * d0 + j]
                col.append(op)
                if j == i:
                    draw.mark_xop(op)
                if j < d0 / 2:
                    left.append(op)
                else:
                    right.append(op)
            cols.append(col)

        draw.save("output.2")

        left = array2(left)
        right = array2(right)

        print(get_k(left, Hx))
        print(get_k(right, Hx))

        return

    # ---------------------------------------------------
    # Build Lz

    counit = lambda n: unit2(n).transpose()

    K = find_cokernel(D) # matrix of row vectors
    #E = counit(c1)
    E = identity2(c1)
    Lz0 = kron(E, K), zeros2(K.shape[0] * c1, c0 * d1)
    Lz0 = numpy.concatenate(Lz0, axis=1) # horizontal concatenate

    assert dot2(Lz0, Hx.transpose()).sum() == 0

    K = find_cokernel(C)
    #E = counit(d1)
    E = identity2(d1)
    Lz1 = zeros2(K.shape[0] * d1, c1 * d0), kron(K, E)
    Lz1 = numpy.concatenate(Lz1, axis=1) # horizontal concatenate

    Lz = numpy.concatenate((Lz0, Lz1), axis=0)

    assert dot2(Lz, Hx.transpose()).sum() == 0

    # Every Lx/Lz pair overlaps on at most one qubit.
    overlap = 0
    for lx in Lx:
        for lz in Lz:
            w = (lx * lz).sum()
            overlap = max(overlap, w)
    assert overlap <= 1, overlap
    #print("max overlap:", overlap)

    assert rank(Hx) == len(Hx)
    assert rank(Hz) == len(Hz)
    mx = len(Hx)
    mz = len(Hz)

    # ---------------------------------------------------
    # Span of each logical op plus every stabilizer row, row-reduced.

    Lxs = []
    for op in Lx:
        op = (op + Hx) % 2
        Lxs.append(op)
    LxHx = numpy.concatenate(Lxs)
    LxHx = row_reduce(LxHx)
    print("LxHx:", len(LxHx))
    assert LxHx.shape[1] == n
    print(len(intersect(LxHx, Hx)), mx)
    assert len(intersect(LxHx, Hx)) == mx

    Lzs = []
    for op in Lz:
        op = (op + Hz) % 2
        Lzs.append(op)
    LzHz = numpy.concatenate(Lzs)
    LzHz = row_reduce(LzHz)
    print("LzHz:", len(LzHz))
    assert LzHz.shape[1] == n

    # ---------------------------------------------------
    # Remove excess logops.

    # print("remove_dependent")
    #
    # # -------- Lx
    #
    # Lx, Lx1 = mk_disjoint_logops(Lx, Hx)
    #
    # # -------- Lz
    #
    # Lz, Lz1 = mk_disjoint_logops(Lz, Hz)

    # --------------------------------
    # independent_logops for Lx: split the first c1*d0 qubits by column j.

    k = get_k(Lx, Hx)

    idxs0, idxs1 = [], []

    for j in range(d0): # col
        for i in range(c1):
            idx = j + i * d0
            if j < d0 // 2:
                idxs0.append(idx)
            else:
                idxs1.append(idx)

    Lx0 = in_support(LxHx, idxs0)
    Lx0 = independent_logops(Lx0, Hx)
    k0 = (len(Lx0))

    Lx1 = in_support(LxHx, idxs1)
    Lx1 = independent_logops(Lx1, Hx)
    k1 = (len(Lx1))
    assert k0 == k1 == k, (k0, k1, k)

    # --------------------------------
    # independent_logops for Lz: split the first c1*d0 qubits by row i.

    idxs0, idxs1 = [], []

    for j in range(d0): # col
        for i in range(c1):
            idx = j + i * d0
            if i < c1 // 2:
                idxs0.append(idx)
            else:
                idxs1.append(idx)

    Lz0 = in_support(LzHz, idxs0)
    Lz0 = independent_logops(Lz0, Hz)
    k0 = (len(Lz0))

    Lz1 = in_support(LzHz, idxs1)
    Lz1 = independent_logops(Lz1, Hz)
    k1 = (len(Lz1))
    assert k0 == k1 == k, (k0, k1, k)

    # ---------------------------------------------------

    #assert eq2(dot2(Lz, Lxt), identity2(k))
    assert mx + mz + k == n

    print("mx = %d, mz = %d, k = %d : n = %d" % (mx, mz, k, n))

    # ---------------------------------------------------
    # Candidate region: union of supports of lx * lz within each half.

    op = zeros2(n)
    for lx in Lx0:
        for lz in Lz0:
            lxz = lx * lz
            op += lxz

    for lx in Lx1:
        for lz in Lz1:
            lxz = lx * lz
            op += lxz

    idxs = numpy.where(op)[0]
    print("correctable region size = %d" % len(idxs))

    Lx, Lz = Lx0, Lz0

    Lxs = []
    for op in Lx:
        op = (op + Hx) % 2
        Lxs.append(op)
    LxHx = numpy.concatenate(Lxs)
    LxHx = row_reduce(LxHx)
    assert LxHx.shape[1] == n

    Lzs = []
    for op in Lz:
        op = (op + Hz) % 2
        Lzs.append(op)
    LzHz = numpy.concatenate(Lzs)
    LzHz = row_reduce(LzHz)
    assert LzHz.shape[1] == n

    if argv.draw:
        do_draw(**locals())

    # Arguments follow is_correctable(idxs, n, LxiHx, LziHz, Hx, Hz, Lx, Lz).
    # The region indices come first, and all four operator matrices are required.
    good = is_correctable(idxs, n, LxHx, LzHz, Hx, Hz, Lx, Lz)
    assert good

    print("good")

    # ---------------------------------------------------

    if argv.code:
        print("code = CSSCode()")
        code = CSSCode(Hx=Hx, Hz=Hz, Lx=Lx, Lz=Lz, check=True, verbose=False, build=True)
        print(code)
        #print(code.weightstr())

        if check:
            U = solve(Lx.transpose(), code.Lx.transpose())
            assert U is not None
            #print(U.shape)
            assert eq2(dot2(U.transpose(), Lx), code.Lx)
            #print(shortstr(U))

        if 0:
            Lx, Lz = code.Lx, code.Lz
            print("Lx:", Lx.shape)
            print(shortstr(Lx))
            print("Lz:", Lz.shape)
            print(shortstr(Lz))