def set_z(self, z1, z1_2, muls):
    """Load two evaluation points for the next layer and refresh the gates.

    `z1` and `z1_2` are each expanded with VerifierIOMLExt.compute_beta,
    passing self.circuit.comp_chi as the second argument and muls[0] /
    muls[1] respectively as the third. The results are stored in
    compute_z1chi and compute_z1chi_2. The round counter is reset to 0 and
    every gate's set_z() is then invoked.
    """
    assert len(
        muls) == 2, "Got muls of wrong size with non-None z1_2 in set_z()"
    self.roundNum = 0

    beta = VerifierIOMLExt.compute_beta
    self.compute_z1chi = beta(z1, self.circuit.comp_chi, muls[0])
    self.compute_z1chi_2 = beta(z1_2, self.circuit.comp_chi, muls[1])

    # each gate recomputes its z-dependent coefficients from the new tables
    for gate in self.gates:
        gate.set_z()
def set_rvals(self, rvals, r0val):
    """Store r0val and precompute beta tables for both halves of `rvals`.

    If self.nbits is still None it is taken from len(rvals), and the bits
    are split with v1bits = nbits // 2 and v2bits = the remainder.
    Otherwise len(rvals) must equal the previously recorded nbits.
    The first v1bits coordinates yield self.v1vals and the remaining
    coordinates yield self.v2vals. Both use self.com.rec_q as the recorder
    argument of VerifierIOMLExt.compute_beta.
    """
    self.r0val = r0val
    if self.nbits is None:
        self.nbits = len(rvals)
        self.v1bits = self.nbits // 2
        self.v2bits = self.nbits - self.v1bits
    else:
        assert len(rvals) == self.nbits

    split = self.v1bits
    rec = self.com.rec_q
    self.v1vals = VerifierIOMLExt.compute_beta(rvals[:split], rec)
    self.v2vals = VerifierIOMLExt.compute_beta(rvals[split:], rec)
def set_rvals_p(self, rvals, r0val, rZval):
    """Prover side: fold stored rows/blinds by the first half of `rvals`.

    When v1bits > 0, the beta table of rvals[:v1bits] weights self.tvals
    (via util.vector_times_matrix) to give self.avals and weights
    self.svals (via util.dot_product) to give self.rAval. Otherwise the
    single stored row and blind are used as-is. self.bvals is the beta
    table of the remaining coordinates with r0val as compute_beta's third
    argument. rZval is stored unchanged.
    """
    assert self.nbits == len(rvals)
    rec = self.com.rec_q
    if self.v1bits <= 0:
        self.avals = self.tvals[0]
        self.rAval = self.svals[0]
    else:
        row_weights = VerifierIOMLExt.compute_beta(rvals[:self.v1bits], rec)
        assert len(row_weights) == len(self.tvals)
        assert len(row_weights) == len(self.svals)
        self.avals = util.vector_times_matrix(self.tvals, row_weights, rec)
        self.rAval = util.dot_product(self.svals, row_weights, rec)

    self.bvals = VerifierIOMLExt.compute_beta(rvals[self.v1bits:], rec, r0val)
    self.rZval = rZval
def speed_test(num_tests):
    """Compare LayerComputeBeta.set_inputs with VerifierIOMLExt.compute_beta.

    A width in [3, 8] is drawn at random and `num_tests` random vectors of
    that width are generated. Each implementation is timed over all of
    them and the wall-clock totals are printed.
    """
    nBits = random.randint(3, 8)
    inputs = [[Defs.gen_random() for _ in xrange(0, nBits)]
              for _ in xrange(0, num_tests)]

    lcb = LayerComputeBeta(nBits)
    lcb.other_factors = []

    t_start = time.time()
    for vec in inputs:
        lcb.set_inputs(vec)
    runtime = time.time() - t_start

    t_start = time.time()
    for vec in inputs:
        VerifierIOMLExt.compute_beta(vec)
    runtime2 = time.time() - t_start

    print("nBits: %d\nLayerComputeBeta: %f\nVerifierIOMLExt: %f\n" % (nBits, runtime, runtime2))
def run_test(): # pylint: disable=global-variable-undefined,redefined-outer-name
    """Check `lcv` round by round against a straightforward reference.

    Relies on module-level names not defined here: `nOutBits` and `lcv`
    (presumably a LayerComputeV instance -- TODO confirm against the
    caller). Rebinds the module globals `scratch` and `outputs`, which hold
    the reference state that `lcv` is compared against after every round.
    """
    tinputs = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    taus = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    lcv.set_inputs(tinputs)

    # right after set_inputs, lcv.outputs must equal the beta table of tinputs
    assert lcv.outputs == VerifierIOMLExt.compute_beta(tinputs)

    # reference table: util.chi of every nOutBits-bit index against tinputs
    inputs = [ util.chi(util.numToBin(x, nOutBits), tinputs) for x in xrange(0, 2**nOutBits) ]

    global scratch
    global outputs
    scratch = list(inputs)
    outputs = list(inputs)

    def compute_next_value(tau):
        """Reference step: combine adjacent pairs of `scratch` with weights (1 - tau, tau)."""
        global scratch
        global outputs

        nscratch = []
        tauInv = (1 - tau) % Defs.prime  # the weight (1 - tau) mod p, not a multiplicative inverse
        for i in xrange(0, len(scratch) / 2):
            val = ((scratch[2 * i] * tauInv) + (scratch[2 * i + 1] * tau)) % Defs.prime
            nscratch.append(val)
            del val
        scratch = nscratch

        # earlier variant duplicated each value to keep the original length:
        #ndups = len(outputs) / len(scratch)
        #nouts = [ [val] * ndups for val in scratch ]
        outputs = scratch
        #outputs = [item for sublist in nouts for item in sublist]

    for i in xrange(0, nOutBits):
        # state before the round must match the reference
        assert lcv.inputs == inputs
        assert lcv.outputs == outputs
        assert lcv.scratch == scratch

        compute_next_value(taus[i])
        lcv.next_round(taus[i])

        # state after the round must match the reference
        assert outputs == lcv.outputs
        assert scratch == lcv.scratch

    # after all rounds the reference scratch holds the single final value
    assert lcv.prevPassValue == scratch[0]
    assert all([lcv.prevPassValue == elm[0] for elm in lcv.outputs_fact])
def set_rvals_v(self, rvals, r0val, Avals, Zval, vxeval):
    """Verifier side: split `rvals` and build the first P commitment.

    With bitdiv == 0 no bits go to the first half. Otherwise
    v1bits = int(nbits / bitdiv). The tail rvals[v1bits:] and r0val are
    kept for later checks. The row commitments `Avals` are combined with
    the beta table of the head coordinates via multiexp, or Avals[0] is
    taken directly when there is no head. The result is multiplied by
    gops.maul(Zval, vxeval) to form the sole entry of self.Pvals.
    self.cvals is reset to an empty list.
    """
    self.nbits = len(rvals)
    self.v1bits = 0 if self.bitdiv == 0 else int(self.nbits / self.bitdiv)
    self.v2bits = self.nbits - self.v1bits
    self.rvals = rvals[self.v1bits:]
    self.r0val = r0val

    if self.v1bits != 0:
        mvals = VerifierIOMLExt.compute_beta(rvals[:self.v1bits], self.com.rec_q)
        assert len(Avals) == len(mvals)
        Pval = self.gops.multiexp(Avals, mvals)
        if self.com.rec:
            self.com.rec_p.did_mexp(len(mvals))
    else:
        Pval = Avals[0]

    zterm = self.gops.maul(Zval, vxeval)
    self.Pvals = [self.gops.mul(Pval, zterm)]
    self.cvals = []
def speed_test(num_tests):
    """Benchmark the three VerifierIOMLExt evaluation strategies.

    One VerifierIOMLExt is built over random taus of a random width in
    [3, 8]. compute_nosavebits, compute_savebits and compute_sqrtbits are
    then each timed over the same `num_tests` random vectors of length
    2**width, and the totals are printed.
    """
    nbits = random.randint(3, 8)
    taus = [Defs.gen_random() for _ in xrange(0, nbits)]
    inputs = [[Defs.gen_random() for _ in xrange(0, 2**nbits)]
              for _ in xrange(0, num_tests)]
    vim = VerifierIOMLExt(taus)

    strategies = (vim.compute_nosavebits, vim.compute_savebits, vim.compute_sqrtbits)
    elapsed = []
    for evaluate in strategies:
        t_start = time.time()
        for vec in inputs:
            evaluate(vec)
        elapsed.append(time.time() - t_start)

    print("nBits: %d\nnosavebits: %f\nsavebits: %f\nsqrtbits: %f\n" % (
        nbits, elapsed[0], elapsed[1], elapsed[2]))
def run_vecpoly_helper(com):
    """Round-trip a WitnessCommit evaluation proof on random data.

    A random point of 4..10 coordinates and a witness of 2**len(point)
    scalars are drawn. The witness's multilinear extension at the point
    (VerifierIOMLExt) is committed with pow_gh under a random blind. One
    WitnessCommit instance acts as prover. A second, fresh instance
    configured with the same point must accept the resulting transcript
    via eval_check.
    """
    gops = com.gops
    prover = commit.WitnessCommit(com)

    # random evaluation point and witness vector, plus the claimed value
    nvars = random.randint(4, 10)
    point = [gops.rand_scalar() for _ in xrange(0, nvars)]
    witness = [gops.rand_scalar() for _ in xrange(0, 2**nvars)]
    zeta_val = VerifierIOMLExt(point).compute(witness)
    zeta_blind = gops.rand_scalar()
    zeta = gops.pow_gh(zeta_val, zeta_blind)

    # prover: commit to the witness and send the first message
    cvals = prover.witness_commit(witness)
    prover.set_rvals(point, 1)
    (aval, Cval) = prover.eval_init()

    # verifier challenge, prover response
    chal = gops.rand_scalar()
    (zvals, zh, zc) = prover.eval_finish(chal, zeta_blind)

    # independent verifier instance checks the transcript
    verifier = commit.WitnessCommit(com)
    verifier.set_rvals(point, 1)
    assert verifier.eval_check(cvals, aval, Cval, zvals, zh, zc, chal, zeta, 0)
def run_one_test(nbits, squawk, nbins, pattern):
    """Cross-check three multilinear-extension evaluators on one random case.

    Builds a random point `z` with `nbits` coordinates and an input vector
    of length 2**nbits. The last `nbins` entries follow `pattern`:
    0 -> all zeros, 1 -> all ones, 2 -> 0,1,0,1,..., 3 -> 1,0,1,0,...,
    any other value -> random bits.

    The vector is evaluated at `z` with LayerComputeBeta ("old"),
    VerifierIOMLExt ("new") and LayerComputeV ("nw2"), and all three
    results must agree. Then the ranged form of
    VerifierIOMLExt.compute_beta (lo/hi bounds) is checked against the
    full table.

    Args:
        nbits: number of coordinates in the random point.
        squawk: if true, print the per-evaluator operation counters.
        nbins: number of trailing entries filled according to `pattern`.
        pattern: selector for the trailing-entry pattern (see above).

    Returns:
        newrec.get_counts() for the VerifierIOMLExt evaluation.
    """
    z = [Defs.gen_random() for _ in xrange(0, nbits)]
    inv = [Defs.gen_random() for _ in xrange(0, (2**nbits) - nbins)]

    # Compare with ==, not `is`: identity of int objects is an interpreter
    # caching detail (and `is` with a literal is a SyntaxWarning on 3.8+).
    if pattern == 0:
        inv += [0 for _ in xrange(0, nbins)]
    elif pattern == 1:
        inv += [1 for _ in xrange(0, nbins)]
    elif pattern == 2:
        inv += [(i % 2) for i in xrange(0, nbins)]
    elif pattern == 3:
        inv += [((i + 1) % 2) for i in xrange(0, nbins)]
    else:
        inv += [random.randint(0, 1) for _ in xrange(0, nbins)]
    assert len(inv) == (2**nbits)

    Defs.track_fArith = True
    fa = Defs.fArith()
    oldrec = fa.new_cat("old")
    newrec = fa.new_cat("new")
    nw2rec = fa.new_cat("nw2")

    # "old": explicit beta table dotted with the inputs (ops counted by hand)
    oldbeta = LayerComputeBeta(nbits, z, oldrec)
    oldval = sum(util.mul_vecs(oldbeta.outputs, inv)) % Defs.prime
    oldrec.did_mul(len(inv))
    oldrec.did_add(len(inv) - 1)

    # "new": direct VerifierIOMLExt evaluation
    newcomp = VerifierIOMLExt(z, newrec)
    newval = newcomp.compute(inv)

    # "nw2": LayerComputeV, one next_round per coordinate of z
    nw2comp = LayerComputeV(nbits, nw2rec)
    nw2comp.other_factors = []
    nw2comp.set_inputs(inv)
    for zz in z:
        nw2comp.next_round(zz)
    nw2val = nw2comp.prevPassValue

    assert oldval == newval, "error for inputs (new) %s : %s" % (str(z), str(inv))
    assert oldval == nw2val, "error for inputs (nw2) %s : %s" % (str(z), str(inv))

    if squawk:
        print("")
        print("nbits: %d" % nbits)
        print("OLD: %s" % str(oldrec))
        print("NEW: %s" % str(newrec))
        print("NW2: %s" % str(nw2rec))

    betacomp = VerifierIOMLExt.compute_beta(z)
    beta_lo = random.randint(0, 2**nbits - 1)
    beta_hi = random.randint(beta_lo, 2**nbits - 1)
    betacomp2 = VerifierIOMLExt.compute_beta(z, None, 1, beta_lo, beta_hi)

    # the ranged table must be None outside [beta_lo, beta_hi] and match
    # the full table inside it
    assert len(betacomp) == len(betacomp2)
    assert all(b is None for b in betacomp2[:beta_lo])
    assert all(b is not None for b in betacomp2[beta_lo:beta_hi + 1])
    assert all(b is None for b in betacomp2[beta_hi + 1:])
    assert all(b2 == b if b2 is not None else True for (b, b2) in zip(betacomp, betacomp2))

    return newrec.get_counts()
def run(self, pf, _=None): # pylint: disable=arguments-differ
    """Replay and check a non-interactive proof transcript `pf`.

    The transcript is parsed with fs.FiatShamir.from_string. The
    verifier's random challenges are drawn from it with rand_scalar() in
    the same order the prover consumed them. Every check is an `assert`,
    so a bad proof surfaces as AssertionError.
    NOTE(review): asserts are stripped under `python -O`.
    """
    self.fs = fs.FiatShamir.from_string(pf)

    ####
    # 0. Get i/o
    ####
    self.muxbits = self.fs.take(True)
    self.inputs = self.fs.take(True)
    self.outputs = self.fs.take(True)

    ####
    # 1. mlext of outs
    ####
    nOutBits = util.clog2(len(self.in0vv[-1]))
    assert util.clog2(len(self.outputs)) == nOutBits + self.nCopyBits

    # z1 and z2 vals
    z1 = [ self.fs.rand_scalar() for _ in xrange(0, nOutBits) ]
    z1_2 = None
    z2 = [ self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits) ]
    if Defs.track_fArith:
        self.sc_a.did_rng(nOutBits + self.nCopyBits)

    # instructions for P
    muls = None
    project_line = len(self.in0vv) == 1
    # claimed value that the first sumcheck round must reproduce
    expectNext = VerifierIOMLExt(z1 + z2, self.out_a).compute(self.outputs)

    ####
    # 2. Simulate prover interactions
    ####
    for lay in xrange(0, len(self.in0vv)):
        nInBits = self.layInBits[lay]
        nOutBits = self.layOutBits[lay]

        # challenges collected this layer: w3 for copy bits, w1/w2 for the
        # two halves of the input bits
        w1 = []
        w2 = []
        w3 = []
        if Defs.track_fArith:
            self.sc_a.did_rng(2*nInBits + self.nCopyBits)

        ####
        # A. Sumcheck
        ####
        for rd in xrange(0, 2 * nInBits + self.nCopyBits):
            outs = self.fs.take()
            # outs are presumably the round polynomial's coefficients
            # (consistent with horner_eval below), so this is p(0) + p(1)
            gotVal = (outs[0] + sum(outs)) % Defs.prime
            if Defs.track_fArith:
                self.sc_a.did_add(len(outs))

            assert expectNext == gotVal, "Verification failed in round %d of layer %d" % (rd, lay)

            # next round's claim is the polynomial evaluated at the challenge
            nrand = self.fs.rand_scalar()
            expectNext = util.horner_eval(outs, nrand, self.sc_a)
            if rd < self.nCopyBits:
                assert len(outs) == 4
                w3.append(nrand)
            else:
                assert len(outs) == 3
                if rd < self.nCopyBits + nInBits:
                    w1.append(nrand)
                else:
                    w2.append(nrand)

        # values P claims for the next layer's mlext
        outs = self.fs.take()
        if project_line:
            # a polynomial on the line through w1 and w2: v1 is its value
            # at 0 (outs[0]) and v2 its value at 1 (sum of coefficients)
            assert len(outs) == 1 + nInBits
            v1 = outs[0] % Defs.prime
            v2 = sum(outs) % Defs.prime
            if Defs.track_fArith:
                self.tV_a.did_add(len(outs)-1)
        else:
            # two values, one for each of w1 and w2
            assert len(outs) == 2
            v1 = outs[0]
            v2 = outs[1]

        ####
        # B. mlext of wiring predicate
        ####
        tV_eval = self.eval_tV(lay, z1, z2, w1, w2, w3, v1, v2, z1_2, muls)
        assert expectNext == tV_eval, "Verification failed computing tV for layer %d" % lay

        ####
        # C. Extend to next layer
        ####
        project_next = lay == len(self.in0vv) - 2

        if project_line:
            # reduce to one point on the line w1 + (w2 - w1) * tau
            tau = self.fs.rand_scalar()
            muls = None
            expectNext = util.horner_eval(outs, tau, self.nlay_a)
            # z1 = w1 + ( w2 - w1 ) * tau
            z1 = [ (elm1 + (elm2 - elm1) * tau) % Defs.prime for (elm1, elm2) in izip(w1, w2) ]
            z1_2 = None
            if Defs.track_fArith:
                self.nlay_a.did_sub(len(w1))
                self.nlay_a.did_mul(len(w1))
                self.nlay_a.did_add(len(w1))
                self.sc_a.did_rng()
        else:
            # keep both points; the next claim is a random linear
            # combination of v1 and v2
            muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
            tau = None
            expectNext = ( muls[0] * v1 + muls[1] * v2 ) % Defs.prime
            z1 = w1
            z1_2 = w2
            if Defs.track_fArith:
                self.nlay_a.did_add()
                self.nlay_a.did_mul(2)
                self.sc_a.did_rng(2)

        project_line = project_next
        z2 = w3

    ####
    # 3. mlext of inputs
    ####
    input_mlext_eval = VerifierIOMLExt(z1 + z2, self.in_a).compute(self.inputs)

    assert input_mlext_eval == expectNext, "Verification failed checking input mlext"
class WitnessLogCommitShort(_WCBase): # pylint: disable=super-init-not-called
    """Witness commitment with a logarithmic-size halving evaluation proof.

    Each reduction round (redc_init / redc_cont_p on the prover,
    redc_cont_v on the verifier) halves the a, b and generator vectors
    using a challenge c and its inverse, and sends an (L, R) commitment
    pair. This appears to follow the Bulletproofs-style inner-product
    argument -- NOTE(review): confirm against the protocol description.
    fin_init / fin_finish / fin_check run the final single-element step.
    """
    # per-instance protocol state; all start as None until the
    # corresponding set_rvals_* / redc_* / fin_* call fills them in
    avals = None
    bvals = None
    cvals = None
    gvals = None
    rvals = None
    r0val = None
    rPval = None
    rLval = None
    rRval = None
    Pvals = None
    dval = None
    rdelta = None
    rbeta = None

    def __init__(self, com, bitdiv=0):
        """Bind to commitment context `com`. A bitdiv below 2 is stored as 0 (no head split)."""
        self.com = com
        self.gops = com.gops
        if bitdiv < 2:
            self.bitdiv = 0
        else:
            self.bitdiv = bitdiv

    def set_rvals_p(self, rvals, r0val, rZval):
        """Prover side: fold rows by the head of `rvals`.

        Sets avals, bvals and the blinding total rPval (incremented by rZval).
        """
        assert self.nbits == len(rvals)
        if self.v1bits > 0:
            mvals = VerifierIOMLExt.compute_beta(rvals[:self.v1bits], self.com.rec_q)
            assert len(mvals) == len(self.tvals)
            assert len(mvals) == len(self.svals)
            self.avals = util.vector_times_matrix(self.tvals, mvals, self.com.rec_q)
            self.rPval = util.dot_product(self.svals, mvals, self.com.rec_q)
        else:
            self.avals = self.tvals[0]
            self.rPval = self.svals[0]
        self.bvals = VerifierIOMLExt.compute_beta(rvals[self.v1bits:], self.com.rec_q, r0val)
        # not reduced mod q here; redc_cont_p reduces it after updating
        self.rPval += rZval
        if self.com.rec:
            self.com.rec_q.did_add()

    def set_rvals_v(self, rvals, r0val, Avals, Zval, vxeval):
        """Verifier side: split `rvals` and build the initial P commitment in Pvals."""
        self.nbits = len(rvals)
        if self.bitdiv == 0:
            self.v1bits = 0
        else:
            self.v1bits = int(self.nbits / self.bitdiv)
        self.v2bits = self.nbits - self.v1bits
        self.rvals = rvals[self.v1bits:]
        self.r0val = r0val

        if self.v1bits == 0:
            Pval = Avals[0]
        else:
            mvals = VerifierIOMLExt.compute_beta(rvals[:self.v1bits], self.com.rec_q)
            assert len(Avals) == len(mvals)
            Pval = self.gops.multiexp(Avals, mvals)
            if self.com.rec:
                self.com.rec_p.did_mexp(len(mvals))

        self.Pvals = [self.gops.mul(Pval, self.gops.maul(Zval, vxeval))]
        self.cvals = []

    def redc_init(self):
        """Prover: start a halving round and return the (L, R) commitments.

        Draws fresh blinds rLval / rRval. cL and cR are the cross inner
        products a_lo.b_hi and a_hi.b_lo mod q.
        """
        assert self.rLval is None
        assert self.rRval is None
        assert len(self.avals) == len(self.bvals)
        nprime = len(self.avals) // 2

        self.rLval = self.gops.rand_scalar()
        self.rRval = self.gops.rand_scalar()

        cL = sum(
            util.mul_vecs(self.avals[:nprime], self.bvals[nprime:], self.com.rec_q)) % self.gops.q
        cR = sum(
            util.mul_vecs(self.avals[nprime:], self.bvals[:nprime], self.com.rec_q)) % self.gops.q

        # g2^a1, g1^a2
        if self.gvals is None:
            # first round: generators are implicit, addressed via pow_gih
            # with an offset (nprime for L, 0 for R)
            Lval = self.gops.maul(
                self.gops.pow_gih(self.avals[:nprime], self.rLval, nprime), self.gops.q - cL)
            Rval = self.gops.maul(
                self.gops.pow_gih(self.avals[nprime:], self.rRval, 0), self.gops.q - cR)
        else:
            Lval = self.gops.multiexp(
                self.gvals[nprime:] + [self.gops.g, self.gops.h],
                self.avals[:nprime] + [cL, self.rLval])
            Rval = self.gops.multiexp(
                self.gvals[:nprime] + [self.gops.g, self.gops.h],
                self.avals[nprime:] + [cR, self.rRval])

        if self.com.rec:
            self.com.rec_q.did_rng(2)
            self.com.rec_p.did_mexps([2 + nprime] * 2)

        return (Lval, Rval)

    def _collapse_vec(self, v, c, c2):
        """Halve scalar vector `v`: element i becomes v_lo[i]*c + v_hi[i]*c2 mod q."""
        nprime = len(v) // 2
        ret = []
        for (v1, v2) in izip(v[:nprime], v[nprime:]):
            ret.append((v1 * c + v2 * c2) % self.gops.q)
        assert len(ret) == nprime
        if self.com.rec:
            self.com.rec_q.did_mul(len(v))
            self.com.rec_q.did_add(nprime)
        return ret

    def _collapse_gvec(self, v, c, c2):
        """Halve generator vector `v`: element i becomes v_lo[i]^c * v_hi[i]^c2.

        When `v` is None (first round) the generators are implicit and
        gops.pow_gij combines generator indices idx and idx + nprime, where
        nprime = 2**(v2bits - 1).
        """
        ret = []
        if v is None:
            nprime = 2**(self.v2bits - 1)
            for idx in xrange(0, nprime):
                ret.append(self.gops.pow_gij(idx, idx + nprime, c, c2))
        else:
            nprime = len(v) // 2
            for (v1, v2) in izip(v[:nprime], v[nprime:]):
                ret.append(self.gops.multiexp([v1, v2], [c, c2]))
        assert len(ret) == nprime
        if self.com.rec:
            self.com.rec_p.did_mexps([2] * nprime)
        return ret

    def redc_cont_p(self, c):
        """Prover: apply challenge `c` and return True while more than one generator remains."""
        assert self.rLval is not None
        assert self.rRval is not None
        assert len(self.avals) == len(self.bvals)
        assert self.rPval is not None
        assert self.gvals is None or len(self.gvals) == len(self.avals)

        cm1 = util.invert_modp(c, self.gops.q, self.com.rec_q)

        # compute new avals and bvals
        self.avals = self._collapse_vec(self.avals, c, cm1)
        self.bvals = self._collapse_vec(self.bvals, cm1, c)

        # compute new gvals
        self.gvals = self._collapse_gvec(self.gvals, cm1, c)

        # compute new rPval: rP + rL*c^2 + rR*c^-2 (mod q)
        self.rPval += self.rLval * c * c + self.rRval * cm1 * cm1
        self.rPval %= self.gops.q
        self.rLval = None
        self.rRval = None

        if self.com.rec:
            self.com.rec_q.did_inv()
            self.com.rec_q.did_mul(4)
            self.com.rec_q.did_add(2)

        return len(self.gvals) > 1

    def redc_cont_v(self, c, LRval):
        """Verifier: record challenge `c` and the (L, R) pair.

        Returns True while fewer than v2bits challenges have been seen.
        """
        assert self.Pvals is not None
        assert self.bvals is None
        assert self.gvals is None

        # record c, Aval, and Zval
        self.cvals.append(c)
        self.Pvals.extend(LRval)

        return self.v2bits != len(self.cvals)

    def fin_init(self):
        """Prover: once all vectors have length 1, commit to fresh d and return (delta, beta)."""
        assert len(self.gvals) == 1
        assert len(self.bvals) == 1
        assert len(self.avals) == 1

        self.dval = self.gops.rand_scalar()
        self.rdelta = self.gops.rand_scalar()
        self.rbeta = self.gops.rand_scalar()

        # delta and beta are g'^d and g^d, respectively
        delta = self.gops.multiexp([self.gvals[0], self.gops.h], [self.dval, self.rdelta])
        beta = self.gops.pow_gh(self.dval, self.rbeta)

        if self.com.rec:
            self.com.rec_q.did_rng(3)
            self.com.rec_p.did_mexps([2, 2])

        return (delta, beta)

    def fin_finish(self, c):
        """Prover: answer final challenge `c` with (z1val, z2val) mod q."""
        z1val = (c * self.avals[0] * self.bvals[0] + self.dval) % self.gops.q
        z2val = ((c * self.rPval + self.rbeta) * self.bvals[0] + self.rdelta) % self.gops.q

        if self.com.rec:
            self.com.rec_q.did_add(3)
            self.com.rec_q.did_mul(4)

        return (z1val, z2val)

    def fin_check(self, c, (delta, beta), (z1val, z2val)):
        """Verifier: final check. Returns True iff the two multiexps agree.

        NOTE(review): uses Python 2 tuple-parameter unpacking (removed in
        Python 3 by PEP 3113). reduce() below has no initial value, so an
        empty self.cvals would raise TypeError.
        """
        # compute inverses: cinvs[idx] = cprodinv * prod(other c's) = c_idx^-1
        cprod = reduce(lambda x, y: (x * y) % self.gops.q, self.cvals)
        cprodinv = util.invert_modp(cprod, self.gops.q, self.com.rec_q)
        cinvs = [0] * len(self.cvals)
        for idx in xrange(0, len(self.cvals)):
            cvs = chain(self.cvals[:idx], self.cvals[idx + 1:])
            cinvs[idx] = reduce(lambda x, y: (x * y) % self.gops.q, cvs, cprodinv)
        csqs = [(cval * cval) % self.gops.q for cval in self.cvals]
        cinvsqs = [(cval * cval) % self.gops.q for cval in cinvs]

        # compute powers for multiexps: start from cprodinv and, for each
        # squared challenge, interleave (gpow, gpow * c^2), doubling the list
        gpows = [cprodinv]
        for cval in csqs:
            new = [0] * 2 * len(gpows)
            for (idx, gpow) in enumerate(gpows):
                new[2 * idx] = gpow
                new[2 * idx + 1] = (gpow * cval) % self.gops.q
            gpows = new

        # compute powers for P commitments
        bval = (VerifierIOMLExt(self.rvals, self.com.rec_q).compute(gpows) * self.r0val) % self.gops.q
        bc = (bval * c) % self.gops.q
        # P0 gets bc; each following (L, R) pair gets bc*c^2 and bc*c^-2
        azpows = [bc] + [(bc * cval) % self.gops.q for cval in chain.from_iterable(izip(csqs, cinvsqs))]

        # now compute the check values themselves
        gval = self.gops.pow_gi(gpows, 0)
        lhs = self.gops.multiexp(self.Pvals + [beta, delta], azpows + [bval, 1])
        rhs = self.gops.multiexp([gval, self.gops.g, self.gops.h], [z1val, (z1val * bval) % self.gops.q, z2val])

        if self.com.rec:
            clen = len(self.cvals)
            self.com.rec_p.did_mexps([3, 2 + len(self.Pvals), len(gpows)])
            self.com.rec_q.did_mul(
                len(gpows) + (clen + 1) * (clen - 1) + 4 * clen + 2)

        return lhs == rhs
# compute powers for multiexps azpows = [c] + [ (c * cval) % self.gops.q for cval in chain.from_iterable(izip(cinvs, self.cvals)) ] gpows = [(cprodinv * cprodinv) % self.gops.q] for cval in self.cvals: new = [] for gpow in gpows: new.extend([(gpow * cval) % self.gops.q, gpow]) gpows = new # compute bvals stopbits = util.clog2(self.stoplen) bvinit = (VerifierIOMLExt(self.rvals[stopbits:], self.com.rec_q).compute(gpows) * self.r0val) % self.gops.q bvals = VerifierIOMLExt.compute_beta(self.rvals[:stopbits], self.com.rec_q, bvinit) # now compute the check values themselves gvals = [ self.gops.pow_gi(gpows, idx, self.stoplen) for idx in xrange(0, self.stoplen) ] lhs1 = self.gops.multiexp(self.Avals + [delta], azpows + [1]) rhs1 = self.gops.multiexp(gvals + [self.gops.h], zvals + [zdelta]) prod_bz = sum(util.mul_vecs(bvals, zvals, self.com.rec_q)) % self.gops.q lhs2 = self.gops.multiexp(self.Zvals + [beta], azpows + [1])
def run(self, pf, _=None): # pylint: disable=arguments-differ
    """Check a zero-knowledge proof transcript `pf`.

    Values are handled as commitments. Each sumcheck round checks a proof
    of knowledge (check_pok) and an equality/value proof. The last
    layer's claim is tied to the input multilinear extension, including
    commitments to nondeterministic inputs when fs.ndb is set.

    Raises:
        ValueError: when any proof in the transcript fails to verify.
    """
    assert Defs.prime == self.com.gops.q
    self.fs = fs.FiatShamir.from_string(pf)
    assert Defs.prime == self.fs.q

    ####
    # 0. Get i/o
    ####
    self.muxbits = self.fs.take(True)
    self.inputs = self.fs.take(True)

    # get witness commitments (one PoK per copy, or one in total with an RDL)
    nd_cvals = []
    if self.fs.ndb is not None:
        num_vals = 2**(self.nInBits - self.fs.ndb)
        nCopies = 1
        if self.rdl is None:
            nCopies = self.nCopies
        for copy in xrange(0, nCopies):
            (cvals, is_ok) = self.check_pok(num_vals)
            if not is_ok:
                raise ValueError("Failed getting commitments to input for copy %d" % copy)
            if self.rdl is None:
                nd_cvals.append(cvals)
            else:
                nd_cvals.extend(cvals)

    # now generate rvals and overwrite the corresponding input slots
    if self.fs.rvstart is not None and self.fs.rvend is not None:
        r_values = [self.fs.rand_scalar() for _ in xrange(self.fs.rvstart, self.fs.rvend + 1)]
        nCopies = 1
        if self.rdl is None:
            nCopies = self.nCopies
        for idx in xrange(0, nCopies):
            first = idx * (2**self.nInBits) + self.fs.rvstart
            last = first + self.fs.rvend - self.fs.rvstart + 1
            self.inputs[first:last] = r_values

    # finally, get outputs
    self.outputs = self.fs.take(True)

    ####
    # 1. mlext of outs
    ####
    nOutBits = util.clog2(len(self.in0vv[-1]))
    assert util.clog2(len(self.outputs)) == nOutBits + self.nCopyBits

    # z1 and z2 vals
    z1 = [self.fs.rand_scalar() for _ in xrange(0, nOutBits)]
    z1_2 = None
    z2 = [self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits)]
    if Defs.track_fArith:
        self.sc_a.did_rng(nOutBits + self.nCopyBits)

    # instructions for P
    muls = None
    project_line = len(self.in0vv) == 1
    expectNext = VerifierIOMLExt(z1 + z2, self.out_a).compute(self.outputs)
    # commitment to the current claim; None until the first sumcheck round,
    # which is checked against the plain value expectNext instead
    prev_cval = None

    ####
    # 2. Simulate prover interactions
    ####
    for lay in xrange(0, len(self.in0vv)):
        nInBits = self.layInBits[lay]
        nOutBits = self.layOutBits[lay]

        w1 = []
        w2 = []
        w3 = []
        if Defs.track_fArith:
            self.sc_a.did_rng(2 * nInBits + self.nCopyBits)

        ####
        # A. Sumcheck
        ####
        for rd in xrange(0, 2 * nInBits + self.nCopyBits):
            if rd < self.nCopyBits:
                nelms = 4
            else:
                nelms = 3

            (cvals, is_ok) = self.check_pok(nelms)
            if not is_ok:
                raise ValueError("PoK failed for commits in round %d of layer %d" % (rd, lay))

            ncom = self.com.zero_plus_one_eval(cvals)
            if prev_cval is None:
                is_ok = self.check_val_proof(ncom, expectNext)
            else:
                is_ok = self.check_eq_proof(prev_cval, ncom)
            if not is_ok:
                raise ValueError("Verification failed in round %d of layer %d" % (rd, lay))

            nrand = self.fs.rand_scalar()
            prev_cval = self.com.horner_eval(cvals, nrand)

            if rd < self.nCopyBits:
                w3.append(nrand)
            elif rd < self.nCopyBits + nInBits:
                w1.append(nrand)
            else:
                w2.append(nrand)

        ####
        # B. Extend to next layer
        ####
        if project_line:
            assert lay == len(self.in0vv) - 1
            (cvals, c2val, c3val, is_ok) = self.check_final_prod_pok(nInBits)
            if not is_ok:
                raise ValueError("Verification of final product PoK failed")
            pr_cvals = (cvals[0], c2val, c3val)
        else:
            (pr_cvals, is_ok) = self.check_prod_pok()
            if not is_ok:
                raise ValueError("Verification of product PoK failed in layer %d" % lay)

        # check final val with mlext eval
        (mlext_evals, mlx_z2) = self.eval_mlext(lay, z1, z2, w1, w2, w3, z1_2, muls)
        tV_cval = self.com.tV_eval(pr_cvals, mlext_evals, mlx_z2)
        is_ok = self.check_eq_proof(prev_cval, tV_cval)
        if not is_ok:
            raise ValueError("Verification of mlext eq proof failed in layer %d" % lay)

        project_next = lay == len(self.in0vv) - 2
        if project_line:
            tau = self.fs.rand_scalar()
            muls = None
            # `cvals` here are the line commitments from check_final_prod_pok
            prev_cval = self.com.horner_eval(cvals, tau)
            z1 = [(elm1 + (elm2 - elm1) * tau) % Defs.prime for (elm1, elm2) in izip(w1, w2)]
            z1_2 = None
            if Defs.track_fArith:
                self.nlay_a.did_sub(len(w1))
                self.nlay_a.did_mul(len(w1))
                self.nlay_a.did_add(len(w1))
                self.sc_a.did_rng()
        else:
            muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
            tau = None
            prev_cval = self.com.muls_eval(pr_cvals, muls)
            z1 = w1
            z1_2 = w2
            if Defs.track_fArith:
                self.sc_a.did_rng(2)

        project_line = project_next
        z2 = w3

    ####
    # 3. mlext of inputs
    ####
    if self.rdl is None:
        fin_inputs = self.inputs
    else:
        # Lay the inputs out with one 2**nCktBits slot per RDL copy,
        # zero-padded. This is the layout the beta indexing below
        # (cidx * perCkt + gidx) assumes. The padding is a no-op when every
        # r_ents entry already has 2**nCktBits elements.
        perCkt = 2**self.nCktBits
        fin_inputs = []
        for r_ents in self.rdl:
            fin_inputs.extend(self.inputs[r_ent] for r_ent in r_ents)
            fin_inputs.extend([0] * (perCkt - len(r_ents)))

    input_mlext_eval = VerifierIOMLExt(z1 + z2, self.in_a).compute(fin_inputs)

    if not nd_cvals or self.fs.ndb is None:
        # no nondeterministic inputs: the claim must open to the public mlext
        is_ok = self.check_val_proof(prev_cval, input_mlext_eval)
    elif self.rdl is None:
        # combine public mlext with committed nondet inputs, weighted by
        # copy beta (z2) times gate beta (z1) over the nondet index range
        copy_vals = VerifierIOMLExt.compute_beta(z2, self.in_a)
        loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
        hiIdx = (2**self.nInBits) - 1
        gate_vals = VerifierIOMLExt.compute_beta(z1, self.in_a, 1, loIdx, hiIdx)
        num_nd = 2**(self.nInBits - self.fs.ndb)

        cval_acc = self.com.accum_init(input_mlext_eval)
        for (cidx, vals) in enumerate(nd_cvals):
            copy_mul = copy_vals[cidx]
            assert len(vals) == num_nd
            for (gidx, val) in enumerate(vals, start=loIdx):
                exp = (copy_mul * gate_vals[gidx]) % Defs.prime
                cval_acc = self.com.accum_add(cval_acc, val, exp)
            if Defs.track_fArith:
                self.com_q_a.did_mul(len(vals))
        fin_cval = self.com.accum_finish(cval_acc)
        is_ok = self.check_eq_proof(prev_cval, fin_cval)
    else:
        # RDL: each nondet commitment's exponent accumulates the beta weight
        # of every RDL slot that references it; g carries the public mlext
        beta_vals = VerifierIOMLExt.compute_beta(z1 + z2, self.in_a)
        loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
        perCkt = 2**self.nCktBits
        nd_cvals.append(self.com.gops.g)
        exps = [0] * len(nd_cvals)
        exps[-1] = input_mlext_eval

        for (cidx, r_ents) in enumerate(self.rdl):
            for (gidx, r_ent) in enumerate(r_ents):
                if r_ent >= loIdx:
                    exps[r_ent - loIdx] += beta_vals[cidx * perCkt + gidx]
                    exps[r_ent - loIdx] %= Defs.prime

        fin_cval = self.com.gops.multiexp(nd_cvals, exps)
        is_ok = self.check_eq_proof(prev_cval, fin_cval)

    if not is_ok:
        raise ValueError("Verification failed checking input mlext")
def run(self, inputs, muxbits=None):
    """Run the circuit on `inputs` and produce a zero-knowledge proof transcript.

    Sumcheck messages are sent as commitments with proofs of knowledge.
    The matching blinding values (the `*_rval` quantities) are tracked so
    that equality proofs can be produced at each step.

    Returns:
        The Fiat-Shamir transcript as a string (self.fs.to_string()).
    """
    self.build_prover()
    self.prover_fresh = False

    assert Defs.prime == self.com.gops.q

    ######################
    # 0. Run computation #
    ######################
    assert self.prover is not None

    # generate any nondet inputs
    inputs = self.nondet_gen(inputs, muxbits)

    # set muxbits and dump into transcript (V always reads this entry)
    if muxbits is not None:
        self.prover.set_muxbits(muxbits)
    self.fs.put(muxbits, True)

    # run AC, then dump inputs and outputs into the transcript
    invals = []
    invals_nd = []
    for ins in inputs:
        ins = list(ins) + [0] * (2**self.nInBits - len(ins))
        if self.fs.ndb is not None:
            loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
            if self.fs.rvend is not None and self.fs.rvstart is not None:
                ins[self.fs.rvstart:self.fs.rvend + 1] = [0] * (self.fs.rvend - self.fs.rvstart + 1)
            # nondet inputs are committed separately, so they are zeroed in
            # the public transcript copy
            ins_nd = ins[loIdx:]
            ins[loIdx:] = [0] * (2**(self.nInBits - self.fs.ndb))
            invals_nd.append(ins_nd)
        invals.extend(ins)

    # need to pad up to nCopies if we're not using an RDL
    if self.rdl is None:
        assert util.clog2(len(invals)) == self.nInBits + self.nCopyBits
        invals += [0] * (2**(self.nInBits + self.nCopyBits) - len(invals))
    self.fs.put(invals, True)

    # commit to nondet inputs from prover; nd_rvals holds the blinding value
    # for each input slot (0 outside the nondet range)
    nd_rvals = []
    if self.fs.ndb is not None:
        loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
        prefill = [0] * loIdx
        for nd in invals_nd:
            nd_rvals.extend(prefill + self.create_pok(nd))
        if self.rdl is None:
            assert len(nd_rvals) == self.nCopies * (2**self.nInBits)
            nd_rvals += [0] * (2**(self.nInBits + self.nCopyBits) - len(nd_rvals))
        else:
            assert len(nd_rvals) == 2**self.nInBits

    # now V sets r_values if necessary
    if self.fs.rvstart is not None and self.fs.rvend is not None:
        r_values = [self.fs.rand_scalar() for _ in xrange(self.fs.rvstart, self.fs.rvend + 1)]
        if self.rdl is None:
            assert len(inputs) == self.nCopies
            for inp in inputs:
                inp[self.fs.rvstart:self.fs.rvend + 1] = r_values
        else:
            assert len(inputs) == 1
            inputs[0][self.fs.rvstart:self.fs.rvend + 1] = r_values

    if self.rdl is None:
        self.prover.set_inputs(inputs)
    else:
        assert len(inputs) == 1
        rdl_inputs = []
        nd_rvals_new = []
        for r_ents in self.rdl:
            rdl_inputs.append([inputs[0][r_ent] for r_ent in r_ents])
            # Only remap blinds when there are nondet inputs. With
            # fs.ndb None, nd_rvals is empty and indexing it would raise
            # IndexError.
            if nd_rvals:
                nd_rvals_new.extend(nd_rvals[r_ent] for r_ent in r_ents)
                nd_rvals_new.extend(0 for _ in xrange((2**self.nCktBits) - len(r_ents)))
        self.prover.set_inputs(rdl_inputs)
        if nd_rvals:
            nd_rvals = nd_rvals_new
            assert len(nd_rvals) == len(self.rdl) * 2**self.nCktBits

    # evaluate the AC and put the outputs in the transcript
    outvals = util.flatten(self.prover.ckt_outputs)
    nOutBits = util.clog2(len(self.in0vv[-1]))
    assert util.clog2(len(outvals)) == nOutBits + self.nCopyBits
    outvals += [0] * (2**(nOutBits + self.nCopyBits) - len(outvals))
    self.fs.put(outvals, True)

    # generate random point in (z1, z2) \in F^{nOutBits + nCopyBits}
    z1 = [self.fs.rand_scalar() for _ in xrange(0, nOutBits)]
    z1_2 = None
    z2 = [self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits)]
    if Defs.track_fArith:
        self.sc_a.did_rng(nOutBits + self.nCopyBits)

    # to start, we reconcile with mlext of input
    prev_rval = None
    muls = None
    # if the AC has only one layer, tell P to give us H(.)
    project_line = len(self.in0vv) == 1
    self.prover.set_z(z1, z2, None, None, project_line)

    ##########################################
    # 1. Interact with prover for each layer #
    ##########################################
    for lay in xrange(0, len(self.in0vv)):
        nInBits = self.layInBits[lay]
        nOutBits = self.layOutBits[lay]

        w1 = []
        w2 = []
        w3 = []
        if Defs.track_fArith:
            self.sc_a.did_rng(2 * nInBits + self.nCopyBits)

        ###################
        ### A. Sumcheck ###
        ###################
        for rd in xrange(0, 2 * nInBits + self.nCopyBits):
            # get output from prv and check against expected value
            outs = self.prover.get_outputs()

            # 1. commitments to each val in the transcript
            outs_rvals = self.create_pok(outs)

            # 2. prove equality of poly(0) + poly(1) to prev comm value (or out mlext)
            zp1_rval = (sum(outs_rvals) + outs_rvals[0]) % Defs.prime
            self.create_eq_proof(prev_rval, zp1_rval)
            if Defs.track_fArith:
                self.sc_a.did_add(len(outs_rvals))

            # 3. compute new prev_rval and go to next round
            nrand = self.fs.rand_scalar()
            self.prover.next_round(nrand)
            # compute comm to eval of poly(.) that V will use
            prev_rval = util.horner_eval(outs_rvals, nrand, self.sc_a)

            if rd < self.nCopyBits:
                assert len(outs) == 4
                w3.append(nrand)
            else:
                assert len(outs) == 3
                if rd < self.nCopyBits + nInBits:
                    w1.append(nrand)
                else:
                    w2.append(nrand)

        ###############################
        ### B. Extend to next layer ###
        ###############################
        outs = self.prover.get_outputs()

        if project_line:
            assert len(outs) == 1 + nInBits
            assert lay == len(self.in0vv) - 1
            # (1) commit to all values plus their sum
            # (2) figure out c2val, r2val from above and outs[0] com
            # (3) create prod com
            # (4) send PoK of product for outs[0], c2val, prod
            (outs_rvals, pr_rvals) = self.create_final_prod_pok(outs)
        else:
            # just need to do product PoK since we're sending tV(r1) and tV(r2)
            assert len(outs) == 2
            pr_rvals = self.create_prod_pok(outs)

        # prove final value in mlext eval
        # need mlext evals to do PoK
        (mlext_evals, mlx_z2) = self.eval_mlext(lay, z1, z2, w1, w2, w3, z1_2, muls)
        # mul gate is special, rest are OK
        tV_rval = 0
        for (idx, elm) in enumerate(mlext_evals):
            tV_rval += elm * GateFunctionsPC[idx](pr_rvals[0], pr_rvals[1], pr_rvals[2], self.tV_a)
            tV_rval %= Defs.prime
        tV_rval *= mlx_z2
        tV_rval %= Defs.prime
        self.create_eq_proof(prev_rval, tV_rval)
        if Defs.track_fArith:
            self.tV_a.did_add(len(mlext_evals) - 1)
            self.tV_a.did_mul(len(mlext_evals) + 1)

        project_next = lay == len(self.in0vv) - 2
        if project_line:
            tau = self.fs.rand_scalar()
            muls = None
            prev_rval = util.horner_eval(outs_rvals, tau)
            z1 = [(elm1 + (elm2 - elm1) * tau) % Defs.prime for (elm1, elm2) in izip(w1, w2)]
            z1_2 = None
            if Defs.track_fArith:
                self.nlay_a.did_sub(len(w1))
                self.nlay_a.did_mul(len(w1))
                self.nlay_a.did_add(len(w1))
                self.sc_a.did_rng()
        else:
            muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
            self.prover.next_layer(muls, project_next)
            tau = None
            prev_rval = (muls[0] * pr_rvals[0] + muls[1] * pr_rvals[1]) % Defs.prime
            z1 = w1
            z1_2 = w2
            if Defs.track_fArith:
                self.nlay_a.did_add()
                self.nlay_a.did_mul(2)
                self.sc_a.did_rng(2)

        project_line = project_next
        z2 = w3

    #############################
    # 2. Proof of eq with input #
    #############################
    if nd_rvals:
        rval_mlext_eval = VerifierIOMLExt(z1 + z2, self.in_a).compute(nd_rvals)
        self.create_eq_proof(prev_rval, rval_mlext_eval)
        # sanity check: blinds are nonzero only where the public inputs were
        # zeroed. NOTE(review): with an RDL, nd_rvals has been remapped to the
        # RDL layout while invals has not -- confirm this check is meaningful
        # in that case.
        assert sum(val1 * val2 for (val1, val2) in izip(nd_rvals, invals)) == 0
    else:
        self.create_eq_proof(None, prev_rval)

    ########################
    # 3. Return transcript #
    ########################
    return self.fs.to_string()