Ejemplo n.º 1
0
def run_one_test(nInBits, nCopies):
    """Randomized end-to-end check of a single-layer CircuitProver.

    Builds a random one-layer circuit with ``nInBits`` input bits (and the same
    number of output bits) and evaluates it on ``nCopies`` random input copies.
    It then drives the prover through every sumcheck round and checks two things:

    * the first-round claim equals the multilinear extension (mlext) of the
      circuit outputs at the random point ``z1 + z2``;
    * the last-round outputs equal the mlext of the inputs at ``w1 + w3``
      and ``w2 + w3``, the challenges used during the sumcheck.

    Failures are reported with ``assert``.
    """
    nOutBits = nInBits

    (in0v, in1v, typv) = randutil.rand_ckt(nOutBits, nInBits)
    inputs = randutil.rand_inputs(nInBits, nCopies)

    circuit = CircuitProver(nCopies, 2**nInBits, [in0v], [in1v], [typv])
    circuit.set_inputs(inputs)

    z1 = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    z2 = [Defs.gen_random() for _ in xrange(0, circuit.nCopyBits)]

    circuit.set_z(z1, z2, None, None, True)

    def mlext(nBits, point, values):
        """Return sum(beta(point) * values) mod Defs.prime, i.e. the mlext of
        ``values`` at ``point`` computed directly via LayerComputeBeta."""
        mults = LayerComputeBeta(nBits, point)
        assert len(values) == len(mults.outputs)
        return sum(util.mul_vecs(mults.outputs, values)) % Defs.prime

    # mlExt of outputs
    inLayerExt = mlext(nOutBits + circuit.nCopyBits, z1 + z2,
                       util.flatten(circuit.ckt_outputs))

    # draw order (w1, w2, w3) kept identical to the original test
    w1 = [Defs.gen_random() for _ in xrange(0, nInBits)]
    w2 = [Defs.gen_random() for _ in xrange(0, nInBits)]
    w3 = [Defs.gen_random() for _ in xrange(0, circuit.nCopyBits)]

    initOutputs = circuit.get_outputs()

    # Treating the outputs as coefficients of the round polynomial p,
    # p(0) = outs[0] and p(1) = sum(outs); their sum must match the claim.
    assert inLayerExt == (initOutputs[0] + sum(initOutputs)) % Defs.prime

    # Rounds bind the copy variables first (w3), then w1, then w2. Keep the
    # outputs of the last round that ran; starting from initOutputs keeps the
    # name bound even when there are no w2 rounds (nInBits == 0).
    finalOutputs = initOutputs
    for challenge in w3 + w1 + w2:
        circuit.next_round(challenge)
        finalOutputs = circuit.get_outputs()

    # check the outputs by computing mlext of layer input directly
    inflat = util.flatten(inputs)
    inBits = circuit.layer_inbits[0] + circuit.nCopyBits

    v1 = mlext(inBits, w1 + w3, inflat)
    v2 = mlext(inBits, w2 + w3, inflat)

    assert v1 == finalOutputs[0]
    assert v2 == sum(finalOutputs) % Defs.prime
Ejemplo n.º 2
0
    def __init__(self, rdl, nInBits):
        """Initialise prover state for the flattened ``rdl``.

        ``rdl`` is flattened with ``util.flatten``. The number of output bits is
        ``util.clog2`` of its length. One ``PassGateProver`` is created per
        flattened entry, taking input index ``self.rdl[i]`` to output ``i``.
        """
        # pylint: disable=protected-access
        self.rdl = util.flatten(rdl)
        self.nOutBits = util.clog2(len(self.rdl))
        self.nInBits = nInBits
        self.roundNum = 0

        self.muls = None

        # Anonymous attribute holder for the optional fArith categories; the
        # class-level defaults are all None.
        counters = type('', (object, ), {
            'comp_chi': None,
            'comp_v_fin': None,
            'comp_out': None
        })()
        self.circuit = counters

        if not Defs.track_fArith:
            counters.comp_chi = counters.comp_v_fin = counters.comp_out = None
        else:
            tracker = Defs.fArith()
            tag = hash(self)
            counters.comp_chi = tracker.new_cat("p_rdl_comp_chi_%d" % tag)
            counters.comp_v_fin = tracker.new_cat("p_rdl_comp_v_fin_%d" % tag)
            counters.comp_out = tracker.new_cat("p_rdl_comp_out_%d" % tag)

        self.inputs = []
        self.output = []

        # Placeholders for the z1chi computation subcircuits (these rely on
        # V's fast beta evaluation in set_z).
        self.compute_z1chi = None
        self.compute_z1chi_2 = None

        # Everything here is second order, so three evaluation points are
        # enough; the third one is registered as an extra factor.
        v_final = LayerComputeV(nInBits, counters.comp_v_fin)
        self.compute_v_final = v_final
        v_final.set_other_factors([util.THIRD_EVAL_POINT])

        # Per-gate provers used in the "early" rounds.
        self.gates = []
        for out_idx, in_idx in enumerate(self.rdl):
            gate = gateprover.PassGateProver(False, in_idx, 0, out_idx, self, 0)
            self.gates.append(gate)
Ejemplo n.º 3
0
    def run(self, inputs, muxbits=None):
        """Run the prover and return its Fiat-Shamir transcript.

        The muxbits, the zero-padded inputs and the zero-padded outputs are
        written to ``self.fs``. For every layer, each sumcheck message from the
        prover is appended to the transcript, and each round challenge is drawn
        from ``self.fs``. Nothing is checked here beyond length sanity asserts.

        Returns:
            ``self.fs.to_string()``.
        """
        self.build_prover()
        self.prover_fresh = False

        # -- 0. run the computation --
        assert self.prover is not None

        # muxbits always go into the transcript, even when they are None
        if muxbits is not None:
            self.prover.set_muxbits(muxbits)
        self.fs.put(muxbits, True)

        # evaluate the AC, then record the zero-padded inputs ...
        self.prover.set_inputs(inputs)
        copy_width = 2 ** self.nInBits
        invals = []
        for copy_in in inputs:
            invals.extend(copy_in + [0] * (copy_width - len(copy_in)))
        total_in_bits = self.nInBits + self.nCopyBits
        assert util.clog2(len(invals)) == total_in_bits
        invals.extend([0] * (2 ** total_in_bits - len(invals)))
        self.fs.put(invals, True)

        # ... and the zero-padded outputs
        outvals = util.flatten(self.prover.ckt_outputs)
        nOutBits = util.clog2(len(self.in0vv[-1]))
        total_out_bits = nOutBits + self.nCopyBits
        assert util.clog2(len(outvals)) == total_out_bits
        outvals += [0] * (2 ** total_out_bits - len(outvals))
        self.fs.put(outvals, True)

        # random point (z1, z2) \in F^{nOutBits + nCopyBits}, from the transcript
        z1 = [self.fs.rand_scalar() for _ in xrange(nOutBits)]
        z2 = [self.fs.rand_scalar() for _ in xrange(self.nCopyBits)]
        if Defs.track_fArith:
            self.sc_a.did_rng(total_out_bits)

        # a single-layer AC asks P for H(.) straight away
        project_line = len(self.in0vv) == 1
        self.prover.set_z(z1, z2, None, None, project_line)

        # -- 1. interact with the prover layer by layer --
        last_lay = len(self.in0vv) - 1
        for lay in xrange(len(self.in0vv)):
            nInBits = self.layInBits[lay]
            n_rounds = 2 * nInBits + self.nCopyBits
            if Defs.track_fArith:
                self.sc_a.did_rng(n_rounds)

            # A. sumcheck: copy-variable rounds carry 4 values, the rest 3
            for rd in xrange(n_rounds):
                outs = self.prover.get_outputs()
                self.fs.put(outs)
                assert len(outs) == (4 if rd < self.nCopyBits else 3)
                self.prover.next_round(self.fs.rand_scalar())

            # closing message: H(.) when projecting, otherwise two claims
            outs = self.prover.get_outputs()
            self.fs.put(outs)
            assert len(outs) == ((1 + nInBits) if project_line else 2)

            # B. extend to the next layer
            project_next = lay == last_lay - 1
            if not project_line:
                muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
                self.prover.next_layer(muls, project_next)
                if Defs.track_fArith:
                    self.sc_a.did_rng(2)

            project_line = project_next

        # -- 2. hand back the transcript --
        return self.fs.to_string()
Ejemplo n.º 4
0
    def run(self, inputs, muxbits=None):
        """Run the prover on ``inputs`` and interactively verify its answers.

        The verifier draws its own challenges with ``Defs.gen_random``. It walks
        the prover through the sumcheck for every layer in ``self.in0vv``,
        checks each layer's final claim against ``eval_tV``, and compares the
        last claim with the multilinear extension of the zero-padded inputs.

        Args:
            inputs: one input list per copy; each is zero-padded to
                ``2**self.nInBits`` entries.
            muxbits: optional mux settings forwarded to the prover.

        Raises:
            AssertionError: if a verification check fails. Checks that decide
                acceptance raise explicitly instead of using ``assert``, so
                they are not stripped when Python runs with ``-O``.
        """
        self.build_prover()
        self.prover_fresh = False

        ############
        # 0. Setup #
        ############
        assert self.prover is not None

        # set muxbits
        self.muxbits = muxbits
        if muxbits is not None:
            self.prover.set_muxbits(muxbits)

        # set inputs and outputs
        self.prover.set_inputs(inputs)
        self.inputs = []
        for ins in inputs:
            self.inputs.extend(ins + [0] * (2**self.nInBits - len(ins)))

        # pad to power-of-2 number of copies
        assert util.clog2(len(self.inputs)) == self.nInBits + self.nCopyBits
        self.inputs += [0] * (2 ** (self.nInBits + self.nCopyBits) - len(self.inputs))

        ###############################################
        # 1. Compute multilinear extension of outputs #
        ###############################################
        self.outputs = util.flatten(self.prover.ckt_outputs)
        nOutBits = util.clog2(len(self.in0vv[-1]))
        assert util.clog2(len(self.outputs)) == nOutBits + self.nCopyBits

        # pad out to power-of-2 number of copies
        self.outputs += [0] * (2 ** (nOutBits + self.nCopyBits) - len(self.outputs))

        # generate random point in (z1, z2) \in F^{nOutBits + nCopyBits}
        z1 = [ Defs.gen_random() for _ in xrange(0, nOutBits) ]
        z1_2 = None
        z2 = [ Defs.gen_random() for _ in xrange(0, self.nCopyBits) ]
        if Defs.track_fArith:
            self.sc_a.did_rng(nOutBits + self.nCopyBits)

        # if the AC has only one layer, tell P to give us H(.)
        muls = None
        project_line = len(self.in0vv) == 1
        self.prover.set_z(z1, z2, None, None, project_line)

        # eval mlext of output at (z1,z2)
        expectNext = VerifierIOMLExt(z1 + z2, self.out_a).compute(self.outputs)

        ##########################################
        # 2. Interact with prover for each layer #
        ##########################################
        for lay in xrange(0, len(self.in0vv)):
            nInBits = self.layInBits[lay]
            nOutBits = self.layOutBits[lay]

            # random coins for this round
            w3 = [ Defs.gen_random() for _ in xrange(0, self.nCopyBits) ]
            w1 = [ Defs.gen_random() for _ in xrange(0, nInBits) ]
            w2 = [ Defs.gen_random() for _ in xrange(0, nInBits) ]
            if Defs.track_fArith:
                self.sc_a.did_rng(2*nInBits + self.nCopyBits)

            # convenience: rounds bind copy variables (w3) first, then w1, then w2
            ws = w3 + w1 + w2

            ###################
            ### A. Sumcheck ###
            ###################
            for rd in xrange(0, 2 * nInBits + self.nCopyBits):
                # get output from prv and check against expected value;
                # outs are the coefficients of the round polynomial p, so
                # p(0) + p(1) = outs[0] + sum(outs)
                outs = self.prover.get_outputs()
                gotVal = (outs[0] + sum(outs)) % Defs.prime
                if Defs.track_fArith:
                    self.sc_a.did_add(len(outs))

                if expectNext != gotVal:
                    raise AssertionError("Verification failed in round %d of layer %d" % (rd, lay))

                # go to next round; the next expected value is p(ws[rd])
                self.prover.next_round(ws[rd])
                expectNext = util.horner_eval(outs, ws[rd], self.sc_a)

            outs = self.prover.get_outputs()

            if project_line:
                # outs are the coefficients of H(.); H(0) and H(1) give the two claims
                v1 = outs[0] % Defs.prime
                v2 = sum(outs) % Defs.prime
                if Defs.track_fArith:
                    self.tV_a.did_add(len(outs)-1)
            else:
                if len(outs) != 2:
                    raise AssertionError("Verification failed: expected 2 claimed values in layer %d" % lay)
                v1 = outs[0]
                v2 = outs[1]

            ############################################
            ### B. Evaluate mlext of wiring predicates #
            ############################################
            tV_eval = self.eval_tV(lay, z1, z2, w1, w2, w3, v1, v2, z1_2 is not None, muls)

            # check that we got the correct value from the last round of the sumcheck
            if expectNext != tV_eval:
                raise AssertionError("Verification failed computing tV for layer %d" % lay)

            ###############################
            ### C. Extend to next layer ###
            ###############################
            project_next = lay == len(self.in0vv) - 2

            if project_line:
                # reduce to a single point on the line through (w1, w2)
                tau = Defs.gen_random()
                muls = None
                expectNext = util.horner_eval(outs, tau, self.nlay_a)
                # z1 = w1 + ( w2 - w1 ) * tau
                z1 = [ (elm1 + (elm2 - elm1) * tau) % Defs.prime for (elm1, elm2) in izip(w1, w2) ]
                z1_2 = None
                if Defs.track_fArith:
                    self.nlay_a.did_sub(len(w1))
                    self.nlay_a.did_mul(len(w1))
                    self.nlay_a.did_add(len(w1))
                    self.sc_a.did_rng()
            else:
                # combine the two claims with a random linear combination
                muls = (Defs.gen_random(), Defs.gen_random())
                tau = None
                expectNext = ( muls[0] * v1 + muls[1] * v2 ) % Defs.prime
                self.prover.next_layer(muls, project_next)
                z1 = w1
                z1_2 = w2
                if Defs.track_fArith:
                    self.nlay_a.did_add()
                    self.nlay_a.did_mul(2)
                    self.sc_a.did_rng(2)

            project_line = project_next
            z2 = w3

        ##############################################
        # 3. Compute multilinear extension of inputs #
        ##############################################
        # Finally, evaluate mlext of input at z1, z2
        input_mlext = VerifierIOMLExt(z1 + z2, self.in_a)
        input_mlext_eval = input_mlext.compute(self.inputs)

        if input_mlext_eval != expectNext:
            raise AssertionError("Verification failed checking input mlext")
Ejemplo n.º 5
0
def run_one_test(nInBits, nCopies):
    """Randomized check of one LayerProver against a directly evaluated circuit.

    A random one-layer circuit is evaluated copy by copy with ArithCircuit.
    A LayerProver for the same circuit is then run through all sumcheck rounds.
    The first-round claim is compared with the mlext of the outputs at
    ``z1 + z2``, and the final outputs with the mlext of the inputs at
    ``w1 + w3`` / ``w2 + w3``. Failures are reported with ``assert``.
    """
    nOutBits = nInBits

    circuit = _DummyCircuitProver(nCopies)

    (in0v, in1v, typv) = randutil.rand_ckt(nOutBits, nInBits)
    typc = [gate_type.cgate for gate_type in typv]
    inputs = randutil.rand_inputs(nInBits, nCopies)

    # reference evaluation: run the plain circuit on every copy
    ckt = ArithCircuit()
    inCktLayer = ArithCircuitInputLayer(ckt, nOutBits)
    outCktLayer = ArithCircuitLayer(ckt, inCktLayer, in0v, in1v, typc)
    ckt.layers = [inCktLayer, outCktLayer]
    outputs = []
    for copy_in in inputs:
        ckt.run(copy_in)
        outputs.append(ckt.outputs)

    z1 = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    z2 = [Defs.gen_random() for _ in xrange(0, circuit.nCopyBits)]

    outLayer = LayerProver(nInBits, circuit, in0v, in1v, typv)
    outLayer.set_inputs(inputs)
    outLayer.set_z(z1, z2, None, None, True)

    def mlext_at(n_bits, point, flat_vals):
        """Dot the LayerComputeBeta outputs for ``point`` with ``flat_vals``, mod prime."""
        beta = LayerComputeBeta(n_bits, point)
        assert len(flat_vals) == len(beta.outputs)
        return sum(util.mul_vecs(beta.outputs, flat_vals)) % Defs.prime

    # mlext of the reference outputs at (z1, z2)
    out_ext = mlext_at(nOutBits + outLayer.circuit.nCopyBits, z1 + z2,
                       util.flatten(outputs))

    w3 = [Defs.gen_random() for _ in xrange(0, circuit.nCopyBits)]
    w1 = [Defs.gen_random() for _ in xrange(0, nInBits)]
    w2 = [Defs.gen_random() for _ in xrange(0, nInBits)]

    outLayer.compute_outputs()
    first = outLayer.output
    assert out_ext == (first[0] + sum(first)) % Defs.prime

    # copy-variable rounds first, then w1, then w2
    for challenge in w3 + w1 + w2:
        outLayer.next_round(challenge)
        outLayer.compute_outputs()

    last = outLayer.output

    # recompute both final claims straight from the inputs
    in_flat = util.flatten(inputs)
    in_bits = outLayer.nInBits + outLayer.circuit.nCopyBits
    v1 = mlext_at(in_bits, w1 + w3, in_flat)
    v2 = mlext_at(in_bits, w2 + w3, in_flat)

    assert v1 == last[0]
    assert v2 == sum(last) % Defs.prime
Ejemplo n.º 6
0
    def run(self, inputs, muxbits=None):
        """Run the commitment-based prover and return the transcript string.

        Only compressed commitments to the sumcheck messages go into the
        Fiat-Shamir transcript ``self.fs``, never the raw values. Steps:

          0. Generate nondeterministic inputs (``nondet_gen``), evaluate the AC
             and put the muxbits, the zero-padded public inputs and the outputs
             in the transcript. Nondeterministic input words, if any, are
             committed with ``create_witness_comm``.
          1. For each layer, run the sumcheck (committing to each round
             message), then the product PoK and the vector PoK.
          1.5. If ``self.rdl`` is set, run an extra sumcheck for the RDL.
          2. Tie the final claim to the inputs: a witness proof when there are
             nondeterministic inputs, otherwise an equality proof.

        Args:
            inputs: per-copy input lists. With an RDL only ``inputs[0]`` is used.
            muxbits: optional mux settings forwarded to the prover.

        Returns:
            ``self.fs.to_string()``.
        """
        self.build_prover()
        self.build_wcom(True)
        self.prover_fresh = False

        assert Defs.prime == self.com.gops.q

        ######################
        # 0. Run computation #
        ######################
        assert self.prover is not None

        # generate any nondet inputs
        inputs = self.nondet_gen(inputs, muxbits)

        # set muxbits and dump into transcript
        if muxbits is not None:
            self.prover.set_muxbits(muxbits)
        self.fs.put(muxbits, True)

        # split each copy into public inputs (with r-value and nondet slots
        # zeroed) and the nondeterministic tail
        invals = []
        invals_nd = []
        for ins in inputs:
            ins = list(ins) + [0] * (2**self.nInBits - len(ins))
            if self.fs.ndb is not None:
                loIdx = (2 ** self.nInBits) - (2 ** (self.nInBits - self.fs.ndb))
                if self.fs.rvend is not None and self.fs.rvstart is not None:
                    ins[self.fs.rvstart:self.fs.rvend+1] = [0] * (self.fs.rvend - self.fs.rvstart + 1)
                ins_nd = ins[loIdx:]
                ins[loIdx:] = [0] * (2 ** (self.nInBits - self.fs.ndb))
                invals_nd.extend(ins_nd)
            invals.extend(ins)

        # need to pad up to nCopies if we're not using an RDL
        if self.rdl is None:
            assert util.clog2(len(invals)) == self.nInBits + self.nCopyBits
            invals += [0] * (2 ** (self.nInBits + self.nCopyBits) - len(invals))
        self.fs.put(invals, True)

        # commit to nondet inputs from prover
        if invals_nd:
            self.create_witness_comm(invals_nd)

        # now V sets r_values if necessary
        if self.fs.rvstart is not None and self.fs.rvend is not None:
            r_values = [ self.fs.rand_scalar() for _ in xrange(self.fs.rvstart, self.fs.rvend + 1) ]
            if self.rdl is None:
                assert len(inputs) == self.nCopies
                for inp in inputs:
                    inp[self.fs.rvstart:self.fs.rvend+1] = r_values
            else:
                assert len(inputs) == 1
                inputs[0][self.fs.rvstart:self.fs.rvend+1] = r_values

        if self.rdl is None:
            self.prover.set_inputs(inputs)
        else:
            # with an RDL, each AC copy's inputs are picked out of inputs[0]
            rdl_inputs = []
            for r_ents in self.rdl:
                rdl_inputs.append([ inputs[0][r_ent] for r_ent in r_ents ])
            self.prover.set_inputs(rdl_inputs)

        # now evaluate the AC and put the outputs in the transcript
        outvals = util.flatten(self.prover.ckt_outputs)
        nOutBits = util.clog2(len(self.in0vv[-1]))
        assert util.clog2(len(outvals)) == nOutBits + self.nCopyBits
        outvals += [0] * (2 ** (nOutBits + self.nCopyBits) - len(outvals))
        self.fs.put(outvals, True)

        # generate random point in (z1, z2) \in F^{nOutBits + nCopyBits}
        z1 = [ self.fs.rand_scalar() for _ in xrange(0, nOutBits) ]
        z1_2 = None
        z2 = [ self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits) ]
        if Defs.track_fArith:
            self.sc_a.did_rng(nOutBits + self.nCopyBits)

        # to start, we reconcile with mlext of output
        # V knows it, so computes g^{mlext}, i.e., Com(mlext; 0)
        prev_rval = 0
        muls = None
        # If the AC has only one layer, tell P to give us H(.) -- but only when
        # there is no RDL, matching the per-layer project_next rule below. With
        # an RDL the last AC layer must end with two claims (z1, z1_2 combined
        # via muls), because the RDL sumcheck calls
        # set_z(z1 + z2, z1_2 + z2, muls); projecting would leave z1_2 = None.
        project_line = (len(self.in0vv) == 1) and (self.rdl is None)
        self.prover.set_z(z1, z2, None, None, project_line)

        ##########################################
        # 1. Interact with prover for each layer #
        ##########################################
        for lay in xrange(0, len(self.in0vv)):
            nInBits = self.layInBits[lay]
            nOutBits = self.layOutBits[lay]

            w1 = []
            w2 = []
            w3 = []
            self.com.reset()
            if Defs.track_fArith:
                self.sc_a.did_rng(2*nInBits + self.nCopyBits)

            ###################
            ### A. Sumcheck ###
            ###################
            for rd in xrange(0, 2 * nInBits + self.nCopyBits):
                # get output from prv
                outs = self.prover.get_outputs()

                # 1. commit to these values
                self.fs.put(self.com.compress(self.com.commitvec(outs)))

                # 2. compute new rand value and go to next round
                nrand = self.fs.rand_scalar()
                self.prover.next_round(nrand)

                # copy-variable rounds come first (w3), then w1, then w2
                if rd < self.nCopyBits:
                    assert len(outs) == 4
                    w3.append(nrand)
                else:
                    assert len(outs) == 3
                    if rd < self.nCopyBits + nInBits:
                        w1.append(nrand)
                    else:
                        w2.append(nrand)

            ###############################
            ### B. Extend to next layer ###
            ###############################
            outs = self.prover.get_outputs()

            if project_line:
                assert len(outs) == 1 + nInBits
                assert lay == len(self.in0vv) - 1
                # (1) commit to all values plus their sum
                # (2) figure out c2val, r2val from above and outs[0] com
                # (3) create prod com
                # (4) send PoK of product for outs[0], c2val, prod
                (outs_rvals, pr_rvals) = self.create_final_prod_pok(outs)
            else:
                # just need to do product PoK since we're sending tV(r1) and tV(r2)
                assert len(outs) == 2
                pr_rvals = self.create_prod_pok(outs)

            ### all claimed values are now in the transcript, so we can do vector proof
            # put delvals in the transcript
            self.fs.put([ self.com.compress(delval) for delval in self.com.vecpok_init() ])

            # now we need the vector of J values. first, generate the per-row Js
            j1val = self.fs.rand_scalar()
            jvals = [ self.fs.rand_scalar() for _ in xrange(0, 2 * nInBits + self.nCopyBits) ]

            # next, compute Jvec and put corresponding element in proof
            Jvec = _compute_Jvec(j1val, jvals, w3 + w1 + w2, self.nCopyBits, nInBits, False, self.com_q_a)
            self.fs.put(self.com.compress(self.com.vecpok_cont(Jvec)))

            # next, need mlext evals to do PoK
            (mlext_evals, mlx_z2) = self.eval_mlext(lay, z1, z2, w1, w2, w3, z1_2, muls)
            xyzvals = [0, 0, 0, 0]
            for (idx, elm) in enumerate(mlext_evals):
                GateFunctionsPVC[idx](elm, jvals[-1], xyzvals, self.tV_a)
            xyzvals = [ (mlx_z2 * v) % Defs.prime for v in xyzvals ]

            # finally, run vecpok_finish to put zvals in transcript
            chal = self.fs.rand_scalar()
            self.fs.put(self.com.vecpok_finish(j1val, prev_rval, xyzvals, pr_rvals, chal))

            if Defs.track_fArith:
                self.com_q_a.did_rng(2*nInBits + self.nCopyBits + 1)
                self.tV_a.did_mul(len(xyzvals))
                self.com_q_a.did_rng()

            project_next = (lay == len(self.in0vv) - 2) and (self.rdl is None)
            if project_line:
                tau = self.fs.rand_scalar()
                muls = None
                prev_rval = util.horner_eval(outs_rvals, tau)
                z1 = [ (elm1 + (elm2 - elm1) * tau) % Defs.prime for (elm1, elm2) in izip(w1, w2) ]
                z1_2 = None
                if Defs.track_fArith:
                    self.nlay_a.did_sub(len(w1))
                    self.nlay_a.did_mul(len(w1))
                    self.nlay_a.did_add(len(w1))
                    self.sc_a.did_rng()
            else:
                muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
                tau = None
                prev_rval = (muls[0] * pr_rvals[0] + muls[1] * pr_rvals[1]) % Defs.prime
                z1 = w1
                z1_2 = w2
                if Defs.track_fArith:
                    self.nlay_a.did_add()
                    self.nlay_a.did_mul(2)
                    self.sc_a.did_rng(2)

                # after the last AC layer there is no next layer (only the RDL)
                if lay < len(self.in0vv) - 1:
                    self.prover.next_layer(muls, project_next)

            project_line = project_next
            z2 = w3

        self.prover = None  # don't need this anymore
        #############################
        # 1.5. Run RDL if necessary #
        #############################
        if self.rdl is not None:
            self.rdl_prover = RDLProver(self.rdl, self.nInBits)
            self.rdl_prover.set_inputs(inputs)
            self.rdl_prover.set_z(z1 + z2, z1_2 + z2, muls)

            w1 = []
            self.com.reset()
            if Defs.track_fArith:
                self.rdl_sc_a.did_rng(self.nInBits)

            ####################
            # Sumcheck for RDL #
            ####################
            for _ in xrange(0, self.nInBits):
                # get outputs
                outs = self.rdl_prover.compute_outputs()

                # commit to these values
                self.fs.put(self.com.compress(self.com.commitvec(outs)))

                # compute new value and go to next round
                nrand = self.fs.rand_scalar()
                w1.append(nrand)
                self.rdl_prover.next_round(nrand)

            #######################
            # Finish RDL sumcheck #
            #######################
            outs = self.rdl_prover.compute_outputs()
            self.rdl_prover = None      # don't need this any more; save the memory

            # in this case, output is just claimed eval of V_0
            assert len(outs) == 1
            pr_rvals = self.create_pok(outs)

            # all claimed values are now in the transcript, so we can do a vector proof
            self.fs.put([ self.com.compress(delval) for delval in self.com.vecpok_init() ])

            # now need vector of J values; generate per-row Js
            j1val = self.fs.rand_scalar()
            jvals = [ self.fs.rand_scalar() for _ in xrange(0, self.nInBits) ]

            # compute Jvec and put corresponding element in proof
            Jvec = _compute_Jvec(j1val, jvals, w1, 0, self.nInBits, True, self.com_q_a)
            self.fs.put(self.com.compress(self.com.vecpok_cont(Jvec)))

            # next, need mlext eval for PASS to do PoK
            mlext_pass = self.eval_mlext_pass(z1, z1_2, z2, w1, muls)
            xyzvals = [(mlext_pass * jvals[-1]) % Defs.prime]

            # run vecpok_finish to put zvals in transcript
            chal = self.fs.rand_scalar()
            self.fs.put(self.com.vecpok_finish(j1val, prev_rval, xyzvals, pr_rvals, chal))

            if Defs.track_fArith:
                self.com_q_a.did_rng(self.nInBits + 1)
                self.tP_a.did_mul()
                self.com_q_a.did_rng()

            # prepare variables for final check
            muls = None
            tau = None
            prev_rval = pr_rvals[0]
            z1 = w1
            z1_2 = None
            z2 = []

        #############################
        # 2. Proof of eq with input #
        #############################
        if invals_nd:
            # do witness proof: the last ndb coordinates of z1 select the
            # nondeterministic region; their product is the r0 scaling factor
            r0val = reduce(lambda x, y: (x * y) % Defs.prime, z1[len(z1)-self.fs.ndb:], 1)
            rvals = z1[:len(z1)-self.fs.ndb] + z2

            self.create_witness_proof(rvals, r0val, prev_rval)

            if Defs.track_fArith:
                self.com_q_a.did_mul(self.fs.ndb)

        else:
            self.create_eq_proof(None, prev_rval)

        ########################
        # 3. Return transcript #
        ########################
        return self.fs.to_string()
Ejemplo n.º 7
0
    def run(self, inputs, muxbits=None):
        self.build_prover()
        self.prover_fresh = False

        assert Defs.prime == self.com.gops.q

        ######################
        # 0. Run computation #
        ######################
        assert self.prover is not None

        # generate any nondet inputs
        inputs = self.nondet_gen(inputs, muxbits)

        # set muxbits and dump into transcript
        if muxbits is not None:
            self.prover.set_muxbits(muxbits)
        self.fs.put(muxbits, True)

        # run AC, then dump inputs and outputs into the transcript
        invals = []
        invals_nd = []
        for ins in inputs:
            ins = list(ins) + [0] * (2**self.nInBits - len(ins))
            if self.fs.ndb is not None:
                loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
                if self.fs.rvend is not None and self.fs.rvstart is not None:
                    ins[self.fs.rvstart:self.fs.rvend +
                        1] = [0] * (self.fs.rvend - self.fs.rvstart + 1)
                ins_nd = ins[loIdx:]
                ins[loIdx:] = [0] * (2**(self.nInBits - self.fs.ndb))
                invals_nd.append(ins_nd)
            invals.extend(ins)

        # need to pad up to nCopies if we're not using an RDL
        if self.rdl is None:
            assert util.clog2(len(invals)) == self.nInBits + self.nCopyBits
            invals += [0] * (2**(self.nInBits + self.nCopyBits) - len(invals))
        self.fs.put(invals, True)

        # commit to nondet inputs from prover
        nd_rvals = []
        if self.fs.ndb is not None:
            loIdx = (2**self.nInBits) - (2**(self.nInBits - self.fs.ndb))
            prefill = [0] * loIdx
            for nd in invals_nd:
                nd_rvals.extend(prefill + self.create_pok(nd))
            if self.rdl is None:
                assert len(nd_rvals) == self.nCopies * (2**self.nInBits)
                nd_rvals += [0] * (2**(self.nInBits + self.nCopyBits) -
                                   len(nd_rvals))
            else:
                assert len(nd_rvals) == 2**self.nInBits

        # now V sets r_values if necessary
        if self.fs.rvstart is not None and self.fs.rvend is not None:
            r_values = [
                self.fs.rand_scalar()
                for _ in xrange(self.fs.rvstart, self.fs.rvend + 1)
            ]
            if self.rdl is None:
                assert len(inputs) == self.nCopies
                for inp in inputs:
                    inp[self.fs.rvstart:self.fs.rvend + 1] = r_values
            else:
                assert len(inputs) == 1
                inputs[0][self.fs.rvstart:self.fs.rvend + 1] = r_values

        if self.rdl is None:
            self.prover.set_inputs(inputs)
        else:
            assert len(inputs) == 1
            rdl_inputs = []
            nd_rvals_new = []
            for r_ents in self.rdl:
                rdl_inputs.append([inputs[0][r_ent] for r_ent in r_ents])
                nd_rvals_new.extend(nd_rvals[r_ent] for r_ent in r_ents)
                nd_rvals_new.extend(
                    0 for _ in xrange((2**self.nCktBits) - len(r_ents)))
            self.prover.set_inputs(rdl_inputs)
            nd_rvals = nd_rvals_new
            assert len(nd_rvals) == len(self.rdl) * 2**self.nCktBits

        # evaluate the AC and put the outputs in the transcript
        outvals = util.flatten(self.prover.ckt_outputs)
        nOutBits = util.clog2(len(self.in0vv[-1]))
        assert util.clog2(len(outvals)) == nOutBits + self.nCopyBits
        outvals += [0] * (2**(nOutBits + self.nCopyBits) - len(outvals))
        self.fs.put(outvals, True)

        # generate random point in (z1, z2) \in F^{nOutBits + nCopyBits}
        z1 = [self.fs.rand_scalar() for _ in xrange(0, nOutBits)]
        z1_2 = None
        z2 = [self.fs.rand_scalar() for _ in xrange(0, self.nCopyBits)]
        if Defs.track_fArith:
            self.sc_a.did_rng(nOutBits + self.nCopyBits)

        # to start, we reconcile with mlext of input
        prev_rval = None
        muls = None
        # if the AC has only one layer, tell P to give us H(.)
        project_line = len(self.in0vv) == 1
        self.prover.set_z(z1, z2, None, None, project_line)

        ##########################################
        # 1. Interact with prover for each layer #
        ##########################################
        for lay in xrange(0, len(self.in0vv)):
            nInBits = self.layInBits[lay]
            nOutBits = self.layOutBits[lay]

            w1 = []
            w2 = []
            w3 = []
            if Defs.track_fArith:
                self.sc_a.did_rng(2 * nInBits + self.nCopyBits)

            ###################
            ### A. Sumcheck ###
            ###################
            for rd in xrange(0, 2 * nInBits + self.nCopyBits):
                # get output from prv and check against expected value
                outs = self.prover.get_outputs()

                # 1. commitments to each val in the transcript
                outs_rvals = self.create_pok(outs)

                # 2. prove equality of poly(0) + poly(1) to prev comm value (or out mlext)
                zp1_rval = (sum(outs_rvals) + outs_rvals[0]) % Defs.prime
                self.create_eq_proof(prev_rval, zp1_rval)
                if Defs.track_fArith:
                    self.sc_a.did_add(len(outs_rvals))

                # 3. compute new prev_rval and go to next round
                nrand = self.fs.rand_scalar()
                self.prover.next_round(nrand)
                # compute comm to eval of poly(.) that V will use
                prev_rval = util.horner_eval(outs_rvals, nrand, self.sc_a)

                if rd < self.nCopyBits:
                    assert len(outs) == 4
                    w3.append(nrand)
                else:
                    assert len(outs) == 3
                    if rd < self.nCopyBits + nInBits:
                        w1.append(nrand)
                    else:
                        w2.append(nrand)

            ###############################
            ### B. Extend to next layer ###
            ###############################
            outs = self.prover.get_outputs()

            if project_line:
                assert len(outs) == 1 + nInBits
                assert lay == len(self.in0vv) - 1
                # (1) commit to all values plus their sum
                # (2) figure out c2val, r2val from above and outs[0] com
                # (3) create prod com
                # (4) send PoK of product for outs[0], c2val, prod
                (outs_rvals, pr_rvals) = self.create_final_prod_pok(outs)
            else:
                # just need to do product PoK since we're sending tV(r1) and tV(r2)
                assert len(outs) == 2
                pr_rvals = self.create_prod_pok(outs)

            # prove final value in mlext eval
            # need mlext evals to do PoK
            (mlext_evals, mlx_z2) = self.eval_mlext(lay, z1, z2, w1, w2, w3,
                                                    z1_2, muls)
            # mul gate is special, rest are OK
            tV_rval = 0
            for (idx, elm) in enumerate(mlext_evals):
                tV_rval += elm * GateFunctionsPC[idx](pr_rvals[0], pr_rvals[1],
                                                      pr_rvals[2], self.tV_a)
                tV_rval %= Defs.prime
            tV_rval *= mlx_z2
            tV_rval %= Defs.prime
            self.create_eq_proof(prev_rval, tV_rval)
            if Defs.track_fArith:
                self.tV_a.did_add(len(mlext_evals) - 1)
                self.tV_a.did_mul(len(mlext_evals) + 1)

            project_next = lay == len(self.in0vv) - 2
            if project_line:
                tau = self.fs.rand_scalar()
                muls = None
                prev_rval = util.horner_eval(outs_rvals, tau)
                z1 = [(elm1 + (elm2 - elm1) * tau) % Defs.prime
                      for (elm1, elm2) in izip(w1, w2)]
                z1_2 = None
                if Defs.track_fArith:
                    self.nlay_a.did_sub(len(w1))
                    self.nlay_a.did_mul(len(w1))
                    self.nlay_a.did_add(len(w1))
                    self.sc_a.did_rng()
            else:
                muls = [self.fs.rand_scalar(), self.fs.rand_scalar()]
                self.prover.next_layer(muls, project_next)
                tau = None
                prev_rval = (muls[0] * pr_rvals[0] +
                             muls[1] * pr_rvals[1]) % Defs.prime
                z1 = w1
                z1_2 = w2
                if Defs.track_fArith:
                    self.nlay_a.did_add()
                    self.nlay_a.did_mul(2)
                    self.sc_a.did_rng(2)

            project_line = project_next
            z2 = w3

        #############################
        # 2. Proof of eq with input #
        #############################
        if nd_rvals:
            rval_mlext_eval = VerifierIOMLExt(z1 + z2,
                                              self.in_a).compute(nd_rvals)
            self.create_eq_proof(prev_rval, rval_mlext_eval)
            assert sum(val1 * val2
                       for (val1, val2) in izip(nd_rvals, invals)) == 0

        else:
            self.create_eq_proof(None, prev_rval)

        ########################
        # 3. Return transcript #
        ########################
        return self.fs.to_string()