예제 #1
0
    def set_z(self, z1, z1_2, muls):
        """Reset the round counter and rebuild both beta tables.

        ``z1`` is paired with ``muls[0]`` and ``z1_2`` with ``muls[1]``; each
        pair goes through ``VerifierIOMLExt.compute_beta`` together with
        ``self.circuit.comp_chi`` (presumably an arithmetic recorder -- TODO
        confirm). Afterwards every gate in ``self.gates`` is asked to refresh
        itself via ``set_z()``.

        :param z1: first evaluation point.
        :param z1_2: second evaluation point.
        :param muls: exactly two scalars, one per point.
        """
        assert len(muls) == 2, \
            "Got muls of wrong size with non-None z1_2 in set_z()"
        self.roundNum = 0

        chi_rec = self.circuit.comp_chi
        self.compute_z1chi = VerifierIOMLExt.compute_beta(z1, chi_rec, muls[0])
        self.compute_z1chi_2 = VerifierIOMLExt.compute_beta(z1_2, chi_rec,
                                                            muls[1])

        # gates recompute their z-dependent coefficients from the new tables
        for gate in self.gates:
            gate.set_z()
예제 #2
0
파일: commit.py 프로젝트: matteocam/fennel
    def set_rvals(self, rvals, r0val):
        """Record the evaluation point and build beta tables for its halves.

        On the first call (``self.nbits is None``) the sizes are derived from
        ``len(rvals)``: ``v1bits`` is the floor half and ``v2bits`` the rest.
        On later calls the length is only asserted to match ``self.nbits``.
        ``self.v1vals`` / ``self.v2vals`` receive the beta tables of the first
        ``v1bits`` coordinates and of the remaining ones.
        """
        self.r0val = r0val

        if self.nbits is None:
            self.nbits = len(rvals)
            self.v1bits = self.nbits // 2
            self.v2bits = self.nbits - self.v1bits
        else:
            assert len(rvals) == self.nbits

        split = self.v1bits
        rec = self.com.rec_q
        self.v1vals = VerifierIOMLExt.compute_beta(rvals[:split], rec)
        self.v2vals = VerifierIOMLExt.compute_beta(rvals[split:], rec)
예제 #3
0
파일: commit.py 프로젝트: matteocam/fennel
 def set_rvals_p(self, rvals, r0val, rZval):
     """Prover side: derive ``avals``, ``rAval`` and ``bvals`` for ``rvals``.

     The first ``v1bits`` coordinates weight the rows of ``tvals``/``svals``
     through their beta table; with no such coordinates the first entries of
     ``tvals``/``svals`` are used directly. The remaining coordinates give
     ``bvals`` (with ``r0val`` passed as ``compute_beta``'s scaling argument).
     ``rZval`` is stored unchanged.
     """
     assert self.nbits == len(rvals)
     split = self.v1bits
     rec = self.com.rec_q

     if split <= 0:
         self.avals = self.tvals[0]
         self.rAval = self.svals[0]
     else:
         weights = VerifierIOMLExt.compute_beta(rvals[:split], rec)
         assert len(weights) == len(self.tvals)
         assert len(weights) == len(self.svals)
         self.avals = util.vector_times_matrix(self.tvals, weights, rec)
         self.rAval = util.dot_product(self.svals, weights, rec)

     self.bvals = VerifierIOMLExt.compute_beta(rvals[split:], rec, r0val)
     self.rZval = rZval
예제 #4
0
def speed_test(num_tests):
    nBits = random.randint(3, 8)
    inputs = [ [ Defs.gen_random() for _ in xrange(0, nBits)  ] for _ in xrange(0, num_tests) ]

    lcb = LayerComputeBeta(nBits)
    lcb.other_factors = []
    runtime = time.time()
    for idx in xrange(0, num_tests):
        lcb.set_inputs(inputs[idx])
    runtime = time.time() - runtime

    runtime2 = time.time()
    for idx in xrange(0, num_tests):
        VerifierIOMLExt.compute_beta(inputs[idx])
    runtime2 = time.time() - runtime2

    print "nBits: %d\nLayerComputeBeta: %f\nVerifierIOMLExt: %f\n" % (nBits, runtime, runtime2)
예제 #5
0
def run_test():
    """Cross-check ``lcv`` (a LayerComputeV) against a direct reference fold.

    Relies on module-level ``nOutBits`` and ``lcv`` defined elsewhere in the
    file. The chi-weighted inputs are folded one bit per round with weights
    ``(1 - tau, tau)``; after each round the reference ``scratch`` and
    ``outputs`` must equal those held by ``lcv``. As a side effect the
    module-level ``scratch`` and ``outputs`` are (re)bound.
    """
    # pylint: disable=global-variable-undefined,redefined-outer-name
    tinputs = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    taus = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    lcv.set_inputs(tinputs)
    assert lcv.outputs == VerifierIOMLExt.compute_beta(tinputs)

    inputs = [
        util.chi(util.numToBin(x, nOutBits), tinputs)
        for x in xrange(0, 2**nOutBits)
    ]

    global scratch
    global outputs

    scratch = list(inputs)
    outputs = list(inputs)

    def compute_next_value(tau):
        """Fold adjacent pairs of ``scratch``: s[2i]*(1-tau) + s[2i+1]*tau."""
        global scratch
        global outputs

        tauInv = (1 - tau) % Defs.prime
        # `//` keeps the bound an int on Python 3 too; unlike the old
        # `del val`, an input shorter than 2 simply yields an empty list
        scratch = [
            (scratch[2 * i] * tauInv + scratch[2 * i + 1] * tau) % Defs.prime
            for i in xrange(0, len(scratch) // 2)
        ]
        outputs = scratch

    for tau in taus:
        assert lcv.inputs == inputs
        assert lcv.outputs == outputs
        assert lcv.scratch == scratch

        compute_next_value(tau)
        lcv.next_round(tau)

        assert outputs == lcv.outputs
        assert scratch == lcv.scratch

    assert lcv.prevPassValue == scratch[0]
    assert all(lcv.prevPassValue == elm[0] for elm in lcv.outputs_fact)
예제 #6
0
파일: commit.py 프로젝트: matteocam/fennel
    def set_rvals_v(self, rvals, r0val, Avals, Zval, vxeval):
        """Verifier side: fold the row commitments ``Avals`` for ``rvals``.

        ``v1bits`` is ``int(nbits / bitdiv)``, or 0 when ``bitdiv`` is 0. With
        no v1 bits ``Avals[0]`` is used as-is; otherwise ``Avals`` is combined
        by a multiexp with the beta table of the first ``v1bits`` coordinates.
        The result, multiplied by ``gops.maul(Zval, vxeval)``, becomes the sole
        entry of ``self.Pvals``; ``self.cvals`` is reset to an empty list.
        """
        self.nbits = len(rvals)
        self.v1bits = 0 if self.bitdiv == 0 else int(self.nbits / self.bitdiv)
        self.v2bits = self.nbits - self.v1bits

        head = self.v1bits
        self.rvals = rvals[head:]
        self.r0val = r0val

        if head != 0:
            weights = VerifierIOMLExt.compute_beta(rvals[:head], self.com.rec_q)
            assert len(Avals) == len(weights)
            Pval = self.gops.multiexp(Avals, weights)
            if self.com.rec:
                self.com.rec_p.did_mexp(len(weights))
        else:
            Pval = Avals[0]

        self.Pvals = [self.gops.mul(Pval, self.gops.maul(Zval, vxeval))]
        self.cvals = []
예제 #7
0
def speed_test(num_tests):
    nbits = random.randint(3, 8)
    taus = [Defs.gen_random() for _ in xrange(0, nbits)]
    inputs = [[Defs.gen_random() for _ in xrange(0, 2**nbits)]
              for _ in xrange(0, num_tests)]

    vim = VerifierIOMLExt(taus)
    runtime = time.time()
    for idx in xrange(0, num_tests):
        vim.compute_nosavebits(inputs[idx])
    runtime = time.time() - runtime

    runtime2 = time.time()
    for idx in xrange(0, num_tests):
        vim.compute_savebits(inputs[idx])
    runtime2 = time.time() - runtime2

    runtime3 = time.time()
    for idx in xrange(0, num_tests):
        vim.compute_sqrtbits(inputs[idx])
    runtime3 = time.time() - runtime3

    print "nBits: %d\nnosavebits: %f\nsavebits: %f\nsqrtbits: %f\n" % (
        nbits, runtime, runtime2, runtime3)
예제 #8
0
def run_vecpoly_helper(com):
    """Round-trip a WitnessCommit evaluation proof on random data.

    A prover-side ``commit.WitnessCommit`` commits to a random vector of
    length ``2**rlen`` (``rlen`` in [4, 10]), then answers a random challenge
    about the vector's multilinear extension at a random point. That value is
    hidden in ``zeta = gops.pow_gh(value, szeta)``. A fresh verifier-side
    instance must accept the result.
    """
    gops = com.gops
    prover = commit.WitnessCommit(com)

    # random point, random witness, and zeta built from the true mlext value
    npts = random.randint(4, 10)
    point = [gops.rand_scalar() for _ in xrange(0, npts)]
    witness = [gops.rand_scalar() for _ in xrange(0, 2**npts)]
    zeta_val = VerifierIOMLExt(point).compute(witness)
    szeta = gops.rand_scalar()
    zeta = gops.pow_gh(zeta_val, szeta)

    # prover: commit to the witness and start the opening
    cvals = prover.witness_commit(witness)
    prover.set_rvals(point, 1)
    (aval, Cval) = prover.eval_init()

    # verifier challenge, prover response
    chal = gops.rand_scalar()
    (zvals, zh, zc) = prover.eval_finish(chal, szeta)

    # an independent verifier instance checks the transcript
    verifier = commit.WitnessCommit(com)
    verifier.set_rvals(point, 1)
    assert verifier.eval_check(cvals, aval, Cval, zvals, zh, zc, chal, zeta, 0)
예제 #9
0
def run_one_test(nbits, squawk, nbins, pattern):
    z = [Defs.gen_random() for _ in xrange(0, nbits)]

    inv = [Defs.gen_random() for _ in xrange(0, (2**nbits) - nbins)]
    if pattern is 0:
        inv += [0 for _ in xrange(0, nbins)]
    elif pattern is 1:
        inv += [1 for _ in xrange(0, nbins)]
    elif pattern == 2:
        inv += [(i % 2) for i in xrange(0, nbins)]
    elif pattern == 3:
        inv += [((i + 1) % 2) for i in xrange(0, nbins)]
    else:
        inv += [random.randint(0, 1) for _ in xrange(0, nbins)]

    assert len(inv) == (2**nbits)

    Defs.track_fArith = True
    fa = Defs.fArith()
    oldrec = fa.new_cat("old")
    newrec = fa.new_cat("new")
    nw2rec = fa.new_cat("nw2")

    oldbeta = LayerComputeBeta(nbits, z, oldrec)
    oldval = sum(util.mul_vecs(oldbeta.outputs, inv)) % Defs.prime
    oldrec.did_mul(len(inv))
    oldrec.did_add(len(inv) - 1)

    newcomp = VerifierIOMLExt(z, newrec)
    newval = newcomp.compute(inv)

    nw2comp = LayerComputeV(nbits, nw2rec)
    nw2comp.other_factors = []
    nw2comp.set_inputs(inv)
    for zz in z:
        nw2comp.next_round(zz)
    nw2val = nw2comp.prevPassValue

    assert oldval == newval, "error for inputs (new) %s : %s" % (str(z),
                                                                 str(inv))
    assert oldval == nw2val, "error for inputs (nw2) %s : %s" % (str(z),
                                                                 str(inv))

    if squawk:
        print
        print "nbits: %d" % nbits
        print "OLD: %s" % str(oldrec)
        print "NEW: %s" % str(newrec)
        print "NW2: %s" % str(nw2rec)

    betacomp = VerifierIOMLExt.compute_beta(z)
    beta_lo = random.randint(0, 2**nbits - 1)
    beta_hi = random.randint(beta_lo, 2**nbits - 1)
    betacomp2 = VerifierIOMLExt.compute_beta(z, None, 1, beta_lo, beta_hi)
    # make sure that the right range was generated, and correctly
    assert len(betacomp) == len(betacomp2)
    assert all([b is None for b in betacomp2[:beta_lo]])
    assert all([b is not None for b in betacomp2[beta_lo:beta_hi + 1]])
    assert all([b is None for b in betacomp2[beta_hi + 1:]])
    assert all([
        b2 == b if b2 is not None else True
        for (b, b2) in zip(betacomp, betacomp2)
    ])

    return newrec.get_counts()
예제 #10
0
    def run(self, pf, _=None): # pylint: disable=arguments-differ
        """Verify the proof string ``pf``; raise AssertionError on failure.

        Steps:

        0. Read muxbits, inputs and outputs from the Fiat-Shamir transcript.
        1. Evaluate the multilinear extension of the outputs at a random
           point ``(z1, z2)``.
        2. For each layer, replay the sumcheck rounds, check the final claim
           against ``self.eval_tV``, and reduce it to a claim about the next
           layer.
        3. Compare the last claim with the multilinear extension of the
           inputs.

        Every proof-dependent check is an explicit ``raise AssertionError``
        rather than an ``assert`` statement. ``assert`` is removed under
        ``python -O``, which would make this verifier accept any proof. The
        exception type and messages are the same as before, so existing
        callers are unaffected.

        :param pf: proof string accepted by ``fs.FiatShamir.from_string``.
        :raises AssertionError: if any check fails.
        """
        self.fs = fs.FiatShamir.from_string(pf)

        ####
        # 0. Get i/o
        ####
        self.muxbits = self.fs.take(True)
        self.inputs = self.fs.take(True)
        self.outputs = self.fs.take(True)

        ####
        # 1. mlext of outs
        ####
        nOutBits = util.clog2(len(self.in0vv[-1]))
        if util.clog2(len(self.outputs)) != nOutBits + self.nCopyBits:
            raise AssertionError()

        # z1 and z2 vals
        z1 = [ self.fs.rand_scalar() for _ in xrange(0, nOutBits) ]
        z1_2 = None
        z2 = [ self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits) ]
        if Defs.track_fArith:
            self.sc_a.did_rng(nOutBits + self.nCopyBits)

        # instructions for P
        muls = None
        project_line = len(self.in0vv) == 1
        expectNext = VerifierIOMLExt(z1 + z2, self.out_a).compute(self.outputs)

        ####
        # 2. Simulate prover interactions
        ####
        for lay in xrange(0, len(self.in0vv)):
            nInBits = self.layInBits[lay]
            nOutBits = self.layOutBits[lay]

            w1 = []
            w2 = []
            w3 = []
            if Defs.track_fArith:
                self.sc_a.did_rng(2*nInBits + self.nCopyBits)

            ####
            # A. Sumcheck
            ####
            for rd in xrange(0, 2 * nInBits + self.nCopyBits):
                outs = self.fs.take()
                # outs are polynomial coefficients (they go to horner_eval
                # below), so outs[0] is the value at 0 and sum(outs) at 1
                gotVal = (outs[0] + sum(outs)) % Defs.prime
                if Defs.track_fArith:
                    self.sc_a.did_add(len(outs))

                if expectNext != gotVal:
                    raise AssertionError(
                        "Verification failed in round %d of layer %d" % (rd, lay))

                nrand = self.fs.rand_scalar()
                expectNext = util.horner_eval(outs, nrand, self.sc_a)
                # rounds bind nCopyBits copy variables (4 coefficients each),
                # then nInBits variables for w1 and nInBits for w2 (3 each)
                if rd < self.nCopyBits:
                    if len(outs) != 4:
                        raise AssertionError()
                    w3.append(nrand)
                else:
                    if len(outs) != 3:
                        raise AssertionError()
                    if rd < self.nCopyBits + nInBits:
                        w1.append(nrand)
                    else:
                        w2.append(nrand)

            outs = self.fs.take()

            if project_line:
                if len(outs) != 1 + nInBits:
                    raise AssertionError()
                v1 = outs[0] % Defs.prime
                v2 = sum(outs) % Defs.prime
                if Defs.track_fArith:
                    self.tV_a.did_add(len(outs)-1)
            else:
                if len(outs) != 2:
                    raise AssertionError()
                v1 = outs[0]
                v2 = outs[1]

            ####
            # B. mlext of wiring predicate
            ####
            tV_eval = self.eval_tV(lay, z1, z2, w1, w2, w3, v1, v2, z1_2, muls)
            if expectNext != tV_eval:
                raise AssertionError(
                    "Verification failed computing tV for layer %d" % lay)

            ####
            # C. Extend to next layer
            ####
            project_next = lay == len(self.in0vv) - 2

            if project_line:
                tau = self.fs.rand_scalar()
                muls = None
                expectNext = util.horner_eval(outs, tau, self.nlay_a)
                # z1 = w1 + ( w2 - w1 ) * tau
                z1 = [ (elm1 + (elm2 - elm1) * tau) % Defs.prime for (elm1, elm2) in izip(w1, w2) ]
                z1_2 = None
                if Defs.track_fArith:
                    self.nlay_a.did_sub(len(w1))
                    self.nlay_a.did_mul(len(w1))
                    self.nlay_a.did_add(len(w1))
                    self.sc_a.did_rng()
            else:
                # random linear combination of the two claims v1, v2
                muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
                expectNext = ( muls[0] * v1 + muls[1] * v2 ) % Defs.prime
                z1 = w1
                z1_2 = w2
                if Defs.track_fArith:
                    self.nlay_a.did_add()
                    self.nlay_a.did_mul(2)
                    self.sc_a.did_rng(2)

            project_line = project_next
            z2 = w3

        ####
        # 3. mlext of inputs
        ####
        input_mlext_eval = VerifierIOMLExt(z1 + z2, self.in_a).compute(self.inputs)

        if input_mlext_eval != expectNext:
            raise AssertionError("Verification failed checking input mlext")
예제 #11
0
파일: commit.py 프로젝트: matteocam/fennel
class WitnessLogCommitShort(_WCBase):
    """Witness commitment opened by a folding argument that halves vectors.

    Prover-side methods are ``set_rvals_p``, ``redc_init``, ``redc_cont_p``,
    ``fin_init`` and ``fin_finish``. Each ``redc_cont_p`` round halves
    ``avals``, ``bvals`` and ``gvals`` using a challenge ``c``.

    Verifier-side methods are ``set_rvals_v``, ``redc_cont_v`` and
    ``fin_check``. The verifier records the challenges and per-round values,
    then checks everything with multiexps at the end.

    ``_WCBase.__init__`` is deliberately not called (see the pylint disable
    below). Attributes read but never assigned here (``nbits``, ``v1bits``,
    ``tvals``, ``svals``) are presumably set by ``_WCBase`` methods -- TODO
    confirm.
    """
    # pylint: disable=super-init-not-called
    # Protocol state; class-level None means "not set yet" for an instance.
    avals = None
    bvals = None
    cvals = None
    gvals = None
    rvals = None
    r0val = None
    rPval = None
    rLval = None
    rRval = None
    Pvals = None
    dval = None
    rdelta = None
    rbeta = None

    def __init__(self, com, bitdiv=0):
        """Store the commitment context and normalize ``bitdiv``.

        :param com: commitment context; ``com.gops`` is cached as
            ``self.gops``.
        :param bitdiv: values below 2 are stored as 0, which makes
            ``set_rvals_v`` use ``v1bits = 0``.
        """
        self.com = com
        self.gops = com.gops
        if bitdiv < 2:
            self.bitdiv = 0
        else:
            self.bitdiv = bitdiv

    def set_rvals_p(self, rvals, r0val, rZval):
        """Prover side: prepare ``avals``, ``bvals`` and ``rPval`` for ``rvals``.

        :param rvals: evaluation point. Its first ``v1bits`` coordinates weight
            the rows of ``tvals``/``svals`` through their beta table; the
            remaining coordinates build ``bvals``.
        :param r0val: scaling argument passed to ``compute_beta`` for ``bvals``.
        :param rZval: added to ``rPval`` (no modular reduction here).
        """
        assert self.nbits == len(rvals)
        if self.v1bits > 0:
            mvals = VerifierIOMLExt.compute_beta(rvals[:self.v1bits],
                                                 self.com.rec_q)
            assert len(mvals) == len(self.tvals)
            assert len(mvals) == len(self.svals)
            self.avals = util.vector_times_matrix(self.tvals, mvals,
                                                  self.com.rec_q)
            self.rPval = util.dot_product(self.svals, mvals, self.com.rec_q)
        else:
            # no row bits: use the single row / single randomness value
            self.avals = self.tvals[0]
            self.rPval = self.svals[0]

        self.bvals = VerifierIOMLExt.compute_beta(rvals[self.v1bits:],
                                                  self.com.rec_q, r0val)
        self.rPval += rZval

        if self.com.rec:
            self.com.rec_q.did_add()

    def set_rvals_v(self, rvals, r0val, Avals, Zval, vxeval):
        """Verifier side: fold the row commitments ``Avals`` for ``rvals``.

        ``v1bits`` is ``int(nbits / bitdiv)``, or 0 when ``bitdiv`` is 0.
        The folded value times ``gops.maul(Zval, vxeval)`` becomes the first
        entry of ``self.Pvals``. ``self.cvals`` is reset so that
        ``redc_cont_v`` can append the challenges to it.
        """
        self.nbits = len(rvals)
        if self.bitdiv == 0:
            self.v1bits = 0
        else:
            self.v1bits = int(self.nbits / self.bitdiv)
        self.v2bits = self.nbits - self.v1bits

        self.rvals = rvals[self.v1bits:]
        self.r0val = r0val
        if self.v1bits == 0:
            Pval = Avals[0]
        else:
            mvals = VerifierIOMLExt.compute_beta(rvals[:self.v1bits],
                                                 self.com.rec_q)
            assert len(Avals) == len(mvals)
            Pval = self.gops.multiexp(Avals, mvals)
            if self.com.rec:
                self.com.rec_p.did_mexp(len(mvals))

        self.Pvals = [self.gops.mul(Pval, self.gops.maul(Zval, vxeval))]
        self.cvals = []

    def redc_init(self):
        """Prover: start one folding round and return ``(Lval, Rval)``.

        Draws fresh randomness ``rLval``/``rRval`` and computes two cross
        terms:

        - ``cL = <avals_lo, bvals_hi>``
        - ``cR = <avals_hi, bvals_lo>``

        Must not be called again until ``redc_cont_p`` has cleared
        ``rLval``/``rRval``.
        """
        assert self.rLval is None
        assert self.rRval is None
        assert len(self.avals) == len(self.bvals)

        nprime = len(self.avals) // 2
        self.rLval = self.gops.rand_scalar()
        self.rRval = self.gops.rand_scalar()

        cL = sum(
            util.mul_vecs(self.avals[:nprime], self.bvals[nprime:],
                          self.com.rec_q)) % self.gops.q
        cR = sum(
            util.mul_vecs(self.avals[nprime:], self.bvals[:nprime],
                          self.com.rec_q)) % self.gops.q

        # g2^a1, g1^a2
        if self.gvals is None:
            # gvals not yet materialized: generators are addressed by offset
            # NOTE(review): this branch hands q - cL / q - cR to gops.maul,
            # while the multiexp branch below uses +cL / +cR. Presumably
            # maul's semantics reconcile the two -- confirm.
            Lval = self.gops.maul(
                self.gops.pow_gih(self.avals[:nprime], self.rLval, nprime),
                self.gops.q - cL)
            Rval = self.gops.maul(
                self.gops.pow_gih(self.avals[nprime:], self.rRval, 0),
                self.gops.q - cR)
        else:
            Lval = self.gops.multiexp(
                self.gvals[nprime:] + [self.gops.g, self.gops.h],
                self.avals[:nprime] + [cL, self.rLval])
            Rval = self.gops.multiexp(
                self.gvals[:nprime] + [self.gops.g, self.gops.h],
                self.avals[nprime:] + [cR, self.rRval])

        if self.com.rec:
            self.com.rec_q.did_rng(2)
            self.com.rec_p.did_mexps([2 + nprime] * 2)

        return (Lval, Rval)

    def _collapse_vec(self, v, c, c2):
        """Return ``[lo[i] * c + hi[i] * c2 mod q]`` for the halves of ``v``."""
        nprime = len(v) // 2
        ret = []
        for (v1, v2) in izip(v[:nprime], v[nprime:]):
            ret.append((v1 * c + v2 * c2) % self.gops.q)

        assert len(ret) == nprime
        if self.com.rec:
            self.com.rec_q.did_mul(len(v))
            self.com.rec_q.did_add(nprime)

        return ret

    def _collapse_gvec(self, v, c, c2):
        """Fold a generator vector: ``lo[i]^c * hi[i]^c2`` for each pair.

        When ``v`` is None the generators are taken by index via
        ``gops.pow_gij``, with half-length ``2**(v2bits - 1)``.
        """
        ret = []
        if v is None:
            nprime = 2**(self.v2bits - 1)
            for idx in xrange(0, nprime):
                ret.append(self.gops.pow_gij(idx, idx + nprime, c, c2))

        else:
            nprime = len(v) // 2
            for (v1, v2) in izip(v[:nprime], v[nprime:]):
                ret.append(self.gops.multiexp([v1, v2], [c, c2]))

        assert len(ret) == nprime
        if self.com.rec:
            self.com.rec_p.did_mexps([2] * nprime)

        return ret

    def redc_cont_p(self, c):
        """Prover: finish a folding round with challenge ``c``.

        Halves ``avals`` (weights c, c^-1), ``bvals`` (c^-1, c) and
        ``gvals`` (c^-1, c), then updates
        ``rPval += rLval*c^2 + rRval*c^-2 (mod q)``.

        :returns: True while more than one generator remains, i.e. another
            round is needed.
        """
        assert self.rLval is not None
        assert self.rRval is not None
        assert len(self.avals) == len(self.bvals)
        assert self.rPval is not None
        assert self.gvals is None or len(self.gvals) == len(self.avals)

        cm1 = util.invert_modp(c, self.gops.q, self.com.rec_q)

        # compute new avals and bvals
        self.avals = self._collapse_vec(self.avals, c, cm1)
        self.bvals = self._collapse_vec(self.bvals, cm1, c)

        # compute new gvals
        self.gvals = self._collapse_gvec(self.gvals, cm1, c)

        # fold this round's L/R randomness into rPval
        self.rPval += self.rLval * c * c + self.rRval * cm1 * cm1
        self.rPval %= self.gops.q

        self.rLval = None
        self.rRval = None

        if self.com.rec:
            self.com.rec_q.did_inv()
            self.com.rec_q.did_mul(4)
            self.com.rec_q.did_add(2)

        return len(self.gvals) > 1

    def redc_cont_v(self, c, LRval):
        """Verifier: record challenge ``c`` and append ``LRval`` to ``Pvals``.

        ``LRval`` is presumably the ``(Lval, Rval)`` pair from ``redc_init``
        -- confirm against the caller.

        :returns: True until ``v2bits`` challenges have been recorded.
        """
        assert self.Pvals is not None
        assert self.bvals is None
        assert self.gvals is None

        # record c, Aval, and Zval
        self.cvals.append(c)
        self.Pvals.extend(LRval)

        return self.v2bits != len(self.cvals)

    def fin_init(self):
        """Prover: commit to the final-round masks; return ``(delta, beta)``."""
        assert len(self.gvals) == 1
        assert len(self.bvals) == 1
        assert len(self.avals) == 1

        self.dval = self.gops.rand_scalar()
        self.rdelta = self.gops.rand_scalar()
        self.rbeta = self.gops.rand_scalar()

        # delta and beta are g'^d and g^d, respectively
        delta = self.gops.multiexp([self.gvals[0], self.gops.h],
                                   [self.dval, self.rdelta])
        beta = self.gops.pow_gh(self.dval, self.rbeta)

        if self.com.rec:
            self.com.rec_q.did_rng(3)
            self.com.rec_p.did_mexps([2, 2])

        return (delta, beta)

    def fin_finish(self, c):
        """Prover: answer the final challenge ``c``.

        :returns: ``(z1val, z2val)`` with ``z1 = c*a*b + d`` and
            ``z2 = (c*rP + rbeta)*b + rdelta``, both mod q.
        """
        z1val = (c * self.avals[0] * self.bvals[0] + self.dval) % self.gops.q
        z2val = ((c * self.rPval + self.rbeta) * self.bvals[0] +
                 self.rdelta) % self.gops.q

        if self.com.rec:
            self.com.rec_q.did_add(3)
            self.com.rec_q.did_mul(4)

        return (z1val, z2val)

    def fin_check(self, c, (delta, beta), (z1val, z2val)):
        """Verifier: final multiexp check of the whole folding transcript.

        Uses Python 2 tuple-parameter unpacking. ``reduce`` is called without
        an initializer, so an empty ``self.cvals`` would raise TypeError.

        :returns: True iff the two multiexp sides are equal.
        """
        # compute inverses
        # cinvs[i] = (product of the other challenges) * (product of all)^-1
        #          = c_i^-1, at O(n^2) multiplications
        cprod = reduce(lambda x, y: (x * y) % self.gops.q, self.cvals)
        cprodinv = util.invert_modp(cprod, self.gops.q, self.com.rec_q)
        cinvs = [0] * len(self.cvals)
        for idx in xrange(0, len(self.cvals)):
            cvs = chain(self.cvals[:idx], self.cvals[idx + 1:])
            cinvs[idx] = reduce(lambda x, y: (x * y) % self.gops.q, cvs,
                                cprodinv)

        csqs = [(cval * cval) % self.gops.q for cval in self.cvals]
        cinvsqs = [(cval * cval) % self.gops.q for cval in cinvs]
        # compute powers for multiexps
        # each challenge doubles gpows: entry g -> [g, g * c^2]
        gpows = [cprodinv]
        for cval in csqs:
            new = [0] * 2 * len(gpows)
            for (idx, gpow) in enumerate(gpows):
                new[2 * idx] = gpow
                new[2 * idx + 1] = (gpow * cval) % self.gops.q
            gpows = new

        # compute powers for P commitments
        # exponent order matches Pvals: [P0, L1, R1, L2, R2, ...] scaled by
        # bval * c, with c_i^2 for L entries and c_i^-2 for R entries
        bval = (VerifierIOMLExt(self.rvals, self.com.rec_q).compute(gpows) *
                self.r0val) % self.gops.q
        bc = (bval * c) % self.gops.q
        azpows = [bc] + [(bc * cval) % self.gops.q
                         for cval in chain.from_iterable(izip(csqs, cinvsqs))]

        # now compute the check values themselves
        gval = self.gops.pow_gi(gpows, 0)
        lhs = self.gops.multiexp(self.Pvals + [beta, delta],
                                 azpows + [bval, 1])
        rhs = self.gops.multiexp([gval, self.gops.g, self.gops.h],
                                 [z1val, (z1val * bval) % self.gops.q, z2val])

        if self.com.rec:
            clen = len(self.cvals)
            self.com.rec_p.did_mexps([3, 2 + len(self.Pvals), len(gpows)])
            self.com.rec_q.did_mul(
                len(gpows) + (clen + 1) * (clen - 1) + 4 * clen + 2)

        return lhs == rhs
예제 #12
0
파일: commit.py 프로젝트: matteocam/fennel
        # NOTE(review): partial method body. The enclosing def, the code that
        # binds c, cinvs, cprodinv, delta, beta, zvals and zdelta, and the
        # lines after lhs2 are not visible here.
        # compute powers for multiexps
        # azpows = [c, c*cinvs[0], c*cvals[0], c*cinvs[1], c*cvals[1], ...]
        azpows = [c] + [
            (c * cval) % self.gops.q
            for cval in chain.from_iterable(izip(cinvs, self.cvals))
        ]
        # gpows starts at cprodinv^2 and doubles per challenge: g -> [g*c, g]
        gpows = [(cprodinv * cprodinv) % self.gops.q]
        for cval in self.cvals:
            new = []
            for gpow in gpows:
                new.extend([(gpow * cval) % self.gops.q, gpow])
            gpows = new

        # compute bvals
        # mlext over rvals[stopbits:] at gpows, times r0val, scales the beta
        # table of the first stopbits coordinates
        stopbits = util.clog2(self.stoplen)
        bvinit = (VerifierIOMLExt(self.rvals[stopbits:],
                                  self.com.rec_q).compute(gpows) *
                  self.r0val) % self.gops.q
        bvals = VerifierIOMLExt.compute_beta(self.rvals[:stopbits],
                                             self.com.rec_q, bvinit)

        # now compute the check values themselves
        # one combined generator per position idx in [0, stoplen)
        gvals = [
            self.gops.pow_gi(gpows, idx, self.stoplen)
            for idx in xrange(0, self.stoplen)
        ]
        lhs1 = self.gops.multiexp(self.Avals + [delta], azpows + [1])
        rhs1 = self.gops.multiexp(gvals + [self.gops.h], zvals + [zdelta])

        # inner product <bvals, zvals> mod q
        prod_bz = sum(util.mul_vecs(bvals, zvals,
                                    self.com.rec_q)) % self.gops.q
        lhs2 = self.gops.multiexp(self.Zvals + [beta], azpows + [1])
예제 #13
0
    def run(self, pf, _=None):  # pylint: disable=arguments-differ
        """Verify the commitment-based proof string ``pf``.

        Steps:

        0. Read muxbits and inputs from the transcript. If ``fs.ndb`` is set,
           check proofs of knowledge for the commitments to the
           nondeterministic inputs. If ``fs.rvstart``/``fs.rvend`` are set,
           splice verifier-chosen random values into ``self.inputs``. Then
           read the outputs.
        1. Evaluate the multilinear extension of the outputs at random
           ``(z1, z2)``.
        2. For each layer, check the committed sumcheck rounds and the
           product proof of knowledge, tie the final commitment to the
           wiring mlext, and reduce to a claim about the next layer.
        3. Check the final commitment against the input mlext. The
           committed nondeterministic inputs are folded in when present.

        :raises ValueError: when a PoK, sumcheck, mlext or input check fails.
        :raises AssertionError: on a prime/group-order mismatch or a wrongly
            sized output vector.
        """
        assert Defs.prime == self.com.gops.q
        self.fs = fs.FiatShamir.from_string(pf)
        assert Defs.prime == self.fs.q

        ####
        # 0. Get i/o
        ####
        self.muxbits = self.fs.take(True)
        self.inputs = self.fs.take(True)

        # get witness commitments
        nd_cvals = []
        if self.fs.ndb is not None:
            num_vals = 2**(self.nInBits - self.fs.ndb)
            nCopies = 1
            if self.rdl is None:
                nCopies = self.nCopies
            for copy in xrange(0, nCopies):
                (cvals, is_ok) = self.check_pok(num_vals)
                if not is_ok:
                    raise ValueError(
                        "Failed getting commitments to input for copy %d" %
                        copy)
                if self.rdl is None:
                    nd_cvals.append(cvals)
                else:
                    nd_cvals.extend(cvals)

        # now generate rvals
        if self.fs.rvstart is not None and self.fs.rvend is not None:
            r_values = [
                self.fs.rand_scalar()
                for _ in xrange(self.fs.rvstart, self.fs.rvend + 1)
            ]
            nCopies = 1
            if self.rdl is None:
                nCopies = self.nCopies
            for idx in xrange(0, nCopies):
                first = idx * (2**self.nInBits) + self.fs.rvstart
                last = first + self.fs.rvend - self.fs.rvstart + 1
                self.inputs[first:last] = r_values

        # finally, get outputs
        self.outputs = self.fs.take(True)

        ####
        # 1. mlext of outs
        ####
        nOutBits = util.clog2(len(self.in0vv[-1]))
        # explicit raise instead of assert: the outputs come from the proof,
        # and an assert statement is stripped under `python -O`
        if util.clog2(len(self.outputs)) != nOutBits + self.nCopyBits:
            raise AssertionError()

        # z1 and z2 vals
        z1 = [self.fs.rand_scalar() for _ in xrange(0, nOutBits)]
        z1_2 = None
        z2 = [self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits)]
        if Defs.track_fArith:
            self.sc_a.did_rng(nOutBits + self.nCopyBits)

        # instructions for P
        muls = None
        project_line = len(self.in0vv) == 1
        expectNext = VerifierIOMLExt(z1 + z2, self.out_a).compute(self.outputs)
        prev_cval = None

        ####
        # 2. Simulate prover interactions
        ####
        for lay in xrange(0, len(self.in0vv)):
            nInBits = self.layInBits[lay]
            nOutBits = self.layOutBits[lay]

            w1 = []
            w2 = []
            w3 = []
            if Defs.track_fArith:
                self.sc_a.did_rng(2 * nInBits + self.nCopyBits)

            ####
            # A. Sumcheck
            ####
            for rd in xrange(0, 2 * nInBits + self.nCopyBits):
                if rd < self.nCopyBits:
                    nelms = 4
                else:
                    nelms = 3

                (cvals, is_ok) = self.check_pok(nelms)
                if not is_ok:
                    raise ValueError(
                        "PoK failed for commits in round %d of layer %d" %
                        (rd, lay))

                # the very first round is checked against the plain value
                # expectNext; later rounds against the previous commitment
                ncom = self.com.zero_plus_one_eval(cvals)
                if prev_cval is None:
                    is_ok = self.check_val_proof(ncom, expectNext)
                else:
                    is_ok = self.check_eq_proof(prev_cval, ncom)
                if not is_ok:
                    raise ValueError(
                        "Verification failed in round %d of layer %d" %
                        (rd, lay))

                nrand = self.fs.rand_scalar()
                prev_cval = self.com.horner_eval(cvals, nrand)

                if rd < self.nCopyBits:
                    w3.append(nrand)
                elif rd < self.nCopyBits + nInBits:
                    w1.append(nrand)
                else:
                    w2.append(nrand)

            ####
            # B. Extend to next layer
            ####
            if project_line:
                assert lay == len(self.in0vv) - 1
                (cvals, c2val, c3val,
                 is_ok) = self.check_final_prod_pok(nInBits)
                if not is_ok:
                    raise ValueError(
                        "Verification of final product PoK failed")
                pr_cvals = (cvals[0], c2val, c3val)
            else:
                (pr_cvals, is_ok) = self.check_prod_pok()
                if not is_ok:
                    raise ValueError(
                        "Verification of product PoK failed in layer %d" % lay)

            # check final val with mlext eval
            (mlext_evals, mlx_z2) = self.eval_mlext(lay, z1, z2, w1, w2, w3,
                                                    z1_2, muls)
            tV_cval = self.com.tV_eval(pr_cvals, mlext_evals, mlx_z2)
            is_ok = self.check_eq_proof(prev_cval, tV_cval)
            if not is_ok:
                raise ValueError(
                    "Verification of mlext eq proof failed in layer %d" % lay)

            project_next = lay == len(self.in0vv) - 2
            if project_line:
                tau = self.fs.rand_scalar()
                muls = None
                prev_cval = self.com.horner_eval(cvals, tau)
                z1 = [(elm1 + (elm2 - elm1) * tau) % Defs.prime
                      for (elm1, elm2) in izip(w1, w2)]
                z1_2 = None
                if Defs.track_fArith:
                    self.nlay_a.did_sub(len(w1))
                    self.nlay_a.did_mul(len(w1))
                    self.nlay_a.did_add(len(w1))
                    self.sc_a.did_rng()
            else:
                muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
                prev_cval = self.com.muls_eval(pr_cvals, muls)
                z1 = w1
                z1_2 = w2
                if Defs.track_fArith:
                    self.sc_a.did_rng(2)

            project_line = project_next
            z2 = w3

        ####
        # 3. mlext of inputs
        ####
        if self.rdl is None:
            fin_inputs = self.inputs
        else:
            # reorder/duplicate inputs as described by the rdl index lists
            fin_inputs = []
            for r_ents in self.rdl:
                fin_inputs.extend(self.inputs[r_ent] for r_ent in r_ents)
        input_mlext_eval = VerifierIOMLExt(z1 + z2,
                                           self.in_a).compute(fin_inputs)

        # `not nd_cvals` instead of `len(...) is 0`: identity comparison with
        # an int literal is unreliable (and a SyntaxWarning on Python 3.8+)
        if not nd_cvals or self.fs.ndb is None:
            is_ok = self.check_val_proof(prev_cval, input_mlext_eval)
        elif self.rdl is None:
            copy_vals = VerifierIOMLExt.compute_beta(z2, self.in_a)
            loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
            hiIdx = (2**self.nInBits) - 1
            gate_vals = VerifierIOMLExt.compute_beta(z1, self.in_a, 1, loIdx,
                                                     hiIdx)
            num_nd = 2**(self.nInBits - self.fs.ndb)

            # accumulate each committed nondet input weighted by
            # copy_vals[copy] * gate_vals[position]
            cval_acc = self.com.accum_init(input_mlext_eval)
            for (cidx, vals) in enumerate(nd_cvals):
                copy_mul = copy_vals[cidx]
                assert len(vals) == num_nd
                for (gidx, val) in enumerate(vals, start=loIdx):
                    exp = (copy_mul * gate_vals[gidx]) % Defs.prime
                    cval_acc = self.com.accum_add(cval_acc, val, exp)

                if Defs.track_fArith:
                    self.com_q_a.did_mul(len(vals))

            fin_cval = self.com.accum_finish(cval_acc)
            is_ok = self.check_eq_proof(prev_cval, fin_cval)
        else:
            beta_vals = VerifierIOMLExt.compute_beta(z1 + z2, self.in_a)
            loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
            perCkt = 2**self.nCktBits

            # last base/exponent pair carries the public part of the mlext
            nd_cvals.append(self.com.gops.g)
            exps = [0] * len(nd_cvals)
            exps[-1] = input_mlext_eval

            for (cidx, r_ents) in enumerate(self.rdl):
                for (gidx, r_ent) in enumerate(r_ents):
                    if r_ent >= loIdx:
                        exps[r_ent - loIdx] += beta_vals[cidx * perCkt + gidx]
                        exps[r_ent - loIdx] %= Defs.prime

            fin_cval = self.com.gops.multiexp(nd_cvals, exps)
            is_ok = self.check_eq_proof(prev_cval, fin_cval)

        if not is_ok:
            raise ValueError("Verification failed checking input mlext")
예제 #14
0
    def run(self, inputs, muxbits=None):
        """Run the prover side of the protocol and return the transcript.

        Builds a fresh prover, evaluates the arithmetic circuit (AC) on
        ``inputs``, and drives the layer-by-layer sumcheck interaction,
        writing commitments and proofs into ``self.fs``. Verifier
        challenges come from ``self.fs.rand_scalar()``. These are
        presumably derived from the transcript state (Fiat-Shamir), so the
        order of the ``self.fs`` calls below is likely significant.

        Args:
            inputs: sequence of per-copy input rows. It is first passed
                through ``self.nondet_gen``. Each row is zero-padded to
                ``2**self.nInBits`` entries for the transcript. When
                ``self.rdl`` is set, exactly one row is expected. Rows of
                the sequence returned by ``self.nondet_gen`` are mutated in
                place when verifier-chosen r-values are filled in.
            muxbits: optional mux selector bits. They are forwarded to
                ``self.prover.set_muxbits`` when not None, and are always
                written to the transcript, even when None.

        Returns:
            The transcript serialized by ``self.fs.to_string()``.

        Raises:
            AssertionError: if any internal consistency check fails.
        """
        self.build_prover()
        self.prover_fresh = False

        # all arithmetic below is mod Defs.prime; it must agree with the
        # commitment group's q
        assert Defs.prime == self.com.gops.q

        ######################
        # 0. Run computation #
        ######################
        assert self.prover is not None

        # generate any nondet inputs
        inputs = self.nondet_gen(inputs, muxbits)

        # set muxbits (if any) and always dump them into the transcript
        if muxbits is not None:
            self.prover.set_muxbits(muxbits)
        self.fs.put(muxbits, True)

        # flatten the per-copy public inputs for the transcript (splitting
        # off any nondet inputs); the AC itself is evaluated further below
        invals = []
        invals_nd = []
        for ins in inputs:
            # zero-pad this copy's input up to 2**nInBits entries
            ins = list(ins) + [0] * (2**self.nInBits - len(ins))
            if self.fs.ndb is not None:
                # the nondet region of each copy is the top
                # 2**(nInBits - ndb) slots, i.e. [loIdx, 2**nInBits)
                loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
                # zero the verifier-chosen r-value slots; their real values
                # are only drawn after the nondet commitment below
                if self.fs.rvend is not None and self.fs.rvstart is not None:
                    ins[self.fs.rvstart:self.fs.rvend +
                        1] = [0] * (self.fs.rvend - self.fs.rvstart + 1)
                # split off the nondet region and zero it in the public copy
                ins_nd = ins[loIdx:]
                ins[loIdx:] = [0] * (2**(self.nInBits - self.fs.ndb))
                invals_nd.append(ins_nd)
            invals.extend(ins)

        # need to pad up to nCopies if we're not using an RDL
        # (total length becomes 2**(nInBits + nCopyBits))
        if self.rdl is None:
            assert util.clog2(len(invals)) == self.nInBits + self.nCopyBits
            invals += [0] * (2**(self.nInBits + self.nCopyBits) - len(invals))
        self.fs.put(invals, True)

        # commit to nondet inputs from prover
        # nd_rvals collects the r-values (presumably commitment blinding
        # factors) returned by create_pok for each nondet value. It is
        # zero-prefilled for the public slots [0, loIdx) of every copy, so
        # in the non-RDL case it lines up index-for-index with invals
        # (assuming create_pok returns one r-value per input value).
        nd_rvals = []
        if self.fs.ndb is not None:
            loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
            prefill = [0] * loIdx
            for nd in invals_nd:
                nd_rvals.extend(prefill + self.create_pok(nd))
            if self.rdl is None:
                assert len(nd_rvals) == self.nCopies * (2**self.nInBits)
                nd_rvals += [0] * (2**(self.nInBits + self.nCopyBits) -
                                   len(nd_rvals))
            else:
                assert len(nd_rvals) == 2**self.nInBits

        # now V sets r_values if necessary
        # They are drawn only after the nondet inputs have been committed,
        # and are written into the input rows in place.
        if self.fs.rvstart is not None and self.fs.rvend is not None:
            r_values = [
                self.fs.rand_scalar()
                for _ in xrange(self.fs.rvstart, self.fs.rvend + 1)
            ]
            if self.rdl is None:
                assert len(inputs) == self.nCopies
                for inp in inputs:
                    inp[self.fs.rvstart:self.fs.rvend + 1] = r_values
            else:
                assert len(inputs) == 1
                inputs[0][self.fs.rvstart:self.fs.rvend + 1] = r_values

        if self.rdl is None:
            self.prover.set_inputs(inputs)
        else:
            # Remap the single input row through the RDL. Each entry of
            # self.rdl lists the indices of inputs[0] that make up one
            # copy's input. nd_rvals is permuted the same way, and each copy
            # is zero-padded to 2**nCktBits entries.
            # NOTE(review): when self.fs.ndb is None, nd_rvals is [] here
            # and nd_rvals[r_ent] raises IndexError. Presumably an RDL is
            # only used together with nondet inputs -- confirm.
            assert len(inputs) == 1
            rdl_inputs = []
            nd_rvals_new = []
            for r_ents in self.rdl:
                rdl_inputs.append([inputs[0][r_ent] for r_ent in r_ents])
                nd_rvals_new.extend(nd_rvals[r_ent] for r_ent in r_ents)
                nd_rvals_new.extend(
                    0 for _ in xrange((2**self.nCktBits) - len(r_ents)))
            self.prover.set_inputs(rdl_inputs)
            nd_rvals = nd_rvals_new
            assert len(nd_rvals) == len(self.rdl) * 2**self.nCktBits

        # evaluate the AC and put the outputs in the transcript
        # (zero-padded to 2**(nOutBits + nCopyBits) entries)
        outvals = util.flatten(self.prover.ckt_outputs)
        nOutBits = util.clog2(len(self.in0vv[-1]))
        assert util.clog2(len(outvals)) == nOutBits + self.nCopyBits
        outvals += [0] * (2**(nOutBits + self.nCopyBits) - len(outvals))
        self.fs.put(outvals, True)

        # generate random point in (z1, z2) \in F^{nOutBits + nCopyBits}
        # z1 covers the output-gate bits, z2 the copy bits
        z1 = [self.fs.rand_scalar() for _ in xrange(0, nOutBits)]
        z1_2 = None
        z2 = [self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits)]
        if Defs.track_fArith:
            self.sc_a.did_rng(nOutBits + self.nCopyBits)

        # no previous commitment yet: with prev_rval None, the first
        # sumcheck claim is checked against the mlext of the outputs (see
        # step A.2 below)
        prev_rval = None
        # no random combination of two claims yet (set per layer in step B)
        muls = None
        # if the AC has only one layer, tell P to give us H(.)
        project_line = len(self.in0vv) == 1
        self.prover.set_z(z1, z2, None, None, project_line)

        ##########################################
        # 1. Interact with prover for each layer #
        ##########################################
        for lay in xrange(0, len(self.in0vv)):
            nInBits = self.layInBits[lay]
            nOutBits = self.layOutBits[lay]

            w1 = []  # challenges from the first nInBits input-bit rounds
            w2 = []  # challenges from the second nInBits input-bit rounds
            w3 = []  # challenges from the nCopyBits copy-bit rounds
            if Defs.track_fArith:
                self.sc_a.did_rng(2 * nInBits + self.nCopyBits)

            ###################
            ### A. Sumcheck ###
            ###################
            # Round order: nCopyBits copy rounds (4 values each), then
            # nInBits rounds for w1, then nInBits rounds for w2 (3 values
            # each).
            for rd in xrange(0, 2 * nInBits + self.nCopyBits):
                # get output from prv and check against expected value
                outs = self.prover.get_outputs()

                # 1. commitments to each val in the transcript
                outs_rvals = self.create_pok(outs)

                # 2. prove equality of poly(0) + poly(1) to prev comm value (or out mlext)
                # outs appear to be coefficients (they are evaluated with
                # horner_eval below), so poly(0) + poly(1) = c0 + sum(c_i);
                # the same combination is applied to their r-values
                zp1_rval = (sum(outs_rvals) + outs_rvals[0]) % Defs.prime
                self.create_eq_proof(prev_rval, zp1_rval)
                if Defs.track_fArith:
                    self.sc_a.did_add(len(outs_rvals))

                # 3. compute new prev_rval and go to next round
                nrand = self.fs.rand_scalar()
                self.prover.next_round(nrand)
                # compute comm to eval of poly(.) that V will use
                prev_rval = util.horner_eval(outs_rvals, nrand, self.sc_a)

                if rd < self.nCopyBits:
                    assert len(outs) == 4
                    w3.append(nrand)
                else:
                    assert len(outs) == 3
                    if rd < self.nCopyBits + nInBits:
                        w1.append(nrand)
                    else:
                        w2.append(nrand)

            ###############################
            ### B. Extend to next layer ###
            ###############################
            outs = self.prover.get_outputs()

            if project_line:
                assert len(outs) == 1 + nInBits
                assert lay == len(self.in0vv) - 1
                # (1) commit to all values plus their sum
                # (2) figure out c2val, r2val from above and outs[0] com
                # (3) create prod com
                # (4) send PoK of product for outs[0], c2val, prod
                (outs_rvals, pr_rvals) = self.create_final_prod_pok(outs)
            else:
                # just need to do product PoK since we're sending tV(r1) and tV(r2)
                assert len(outs) == 2
                pr_rvals = self.create_prod_pok(outs)

            # prove final value in mlext eval
            # need mlext evals to do PoK
            (mlext_evals, mlx_z2) = self.eval_mlext(lay, z1, z2, w1, w2, w3,
                                                    z1_2, muls)
            # Combine the per-gate-type mlext evals with the product-PoK
            # r-values via GateFunctionsPC, scale by mlx_z2, and prove the
            # result equals the claim left by the last sumcheck round.
            # mul gate is special, rest are OK
            tV_rval = 0
            for (idx, elm) in enumerate(mlext_evals):
                tV_rval += elm * GateFunctionsPC[idx](pr_rvals[0], pr_rvals[1],
                                                      pr_rvals[2], self.tV_a)
                tV_rval %= Defs.prime
            tV_rval *= mlx_z2
            tV_rval %= Defs.prime
            self.create_eq_proof(prev_rval, tV_rval)
            if Defs.track_fArith:
                self.tV_a.did_add(len(mlext_evals) - 1)
                self.tV_a.did_mul(len(mlext_evals) + 1)

            # the next layer is the last one iff lay == len(self.in0vv) - 2
            project_next = lay == len(self.in0vv) - 2
            if project_line:
                # move to the point on the line through w1 and w2 at a random
                # tau: z1 = w1 + tau * (w2 - w1)
                tau = self.fs.rand_scalar()
                muls = None
                # NOTE(review): unlike the horner_eval call in step A.3, no
                # arithmetic tracker is passed here -- confirm intended.
                prev_rval = util.horner_eval(outs_rvals, tau)
                z1 = [(elm1 + (elm2 - elm1) * tau) % Defs.prime
                      for (elm1, elm2) in izip(w1, w2)]
                z1_2 = None
                if Defs.track_fArith:
                    self.nlay_a.did_sub(len(w1))
                    self.nlay_a.did_mul(len(w1))
                    self.nlay_a.did_add(len(w1))
                    self.sc_a.did_rng()
            else:
                # Combine the r-values of the two product-PoK claims with
                # random muls; the next layer's sumcheck then runs over both
                # points (z1 = w1, z1_2 = w2).
                muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
                self.prover.next_layer(muls, project_next)
                tau = None
                prev_rval = (muls[0] * pr_rvals[0] +
                             muls[1] * pr_rvals[1]) % Defs.prime
                z1 = w1
                z1_2 = w2
                if Defs.track_fArith:
                    self.nlay_a.did_add()
                    self.nlay_a.did_mul(2)
                    self.sc_a.did_rng(2)

            project_line = project_next
            # the copy-bit challenges become the next layer's copy point
            z2 = w3

        #############################
        # 2. Proof of eq with input #
        #############################
        if nd_rvals:
            # The remaining claim is about the input mlext at (z1 + z2).
            # Prove that its r-value equals the mlext of the nondet r-values
            # at that point.
            rval_mlext_eval = VerifierIOMLExt(z1 + z2,
                                              self.in_a).compute(nd_rvals)
            self.create_eq_proof(prev_rval, rval_mlext_eval)
            # Sanity check: public inputs were zeroed wherever nd_rvals may
            # be nonzero, so the elementwise products should sum to 0.
            # NOTE(review): after the RDL remap, nd_rvals is in per-copy
            # layout while invals is not, and izip truncates to the shorter
            # list -- confirm this check is meaningful in the RDL case.
            assert sum(val1 * val2
                       for (val1, val2) in izip(nd_rvals, invals)) == 0

        else:
            # no nondet inputs were committed; presumably the verifier
            # evaluates the public input mlext itself, so only prev_rval is
            # needed for the final equality proof
            self.create_eq_proof(None, prev_rval)

        ########################
        # 3. Return transcript #
        ########################
        return self.fs.to_string()