Esempio n. 1
0
def run_one_test(nInBits, nCopies):
    """Drive a random one-layer CircuitProver and cross-check its outputs.

    A random layer with ``nInBits`` input bits and the same number of output
    bits is generated together with ``nCopies`` random input vectors.  Two
    consistency checks are made:

    * before the first round, the prover's outputs must agree with the
      mlext of the circuit outputs evaluated at ``z1 + z2``;
    * after the last round, they must agree with the mlext of the circuit
      inputs evaluated at ``w1 + w3`` and at ``w2 + w3``.

    Raises:
        AssertionError: if any of the checks fails.
    """
    nOutBits = nInBits

    def coins(count):
        # `count` fresh random values from Defs.gen_random()
        return [Defs.gen_random() for _ in xrange(0, count)]

    def mlext(nbits, point, values):
        # weight `values` by the LayerComputeBeta outputs at `point` and sum
        beta = LayerComputeBeta(nbits, point)
        assert len(values) == len(beta.outputs)
        return sum(util.mul_vecs(beta.outputs, values)) % Defs.prime

    (in0v, in1v, typv) = randutil.rand_ckt(nOutBits, nInBits)
    inputs = randutil.rand_inputs(nInBits, nCopies)

    circuit = CircuitProver(nCopies, 2**nInBits, [in0v], [in1v], [typv])
    circuit.set_inputs(inputs)

    z1 = coins(nOutBits)
    z2 = coins(circuit.nCopyBits)
    circuit.set_z(z1, z2, None, None, True)

    # mlExt of outputs
    outflat = util.flatten(circuit.ckt_outputs)
    inLayerExt = mlext(nOutBits + circuit.nCopyBits, z1 + z2, outflat)

    w1 = coins(nInBits)
    w2 = coins(nInBits)
    w3 = coins(circuit.nCopyBits)

    initOutputs = circuit.get_outputs()
    assert inLayerExt == (initOutputs[0] + sum(initOutputs)) % Defs.prime

    # rounds are bound in the order w3, w1, w2
    for coin in w3 + w1:
        circuit.next_round(coin)
        circuit.get_outputs()

    for coin in w2:
        circuit.next_round(coin)
        finalOutputs = circuit.get_outputs()

    # check the outputs by computing mlext of layer input directly
    inflat = util.flatten(inputs)
    v1 = mlext(circuit.layer_inbits[0] + circuit.nCopyBits, w1 + w3, inflat)
    v2 = mlext(circuit.layer_inbits[0] + circuit.nCopyBits, w2 + w3, inflat)

    assert v1 == finalOutputs[0]
    assert v2 == sum(finalOutputs) % Defs.prime
Esempio n. 2
0
def run_test():
    """Compare the module-level ``lcv`` against a reference folding model.

    Reads the module globals ``lcv`` and ``nOutBits`` and (re)binds the
    module globals ``scratch`` and ``outputs``.  The reference table is
    built from ``util.chi`` over every ``nOutBits``-bit index; each round
    folds adjacent pairs with the round's ``tau`` and the result is compared
    with ``lcv`` before and after ``lcv.next_round``.

    Raises:
        AssertionError: if ``lcv`` ever disagrees with the reference model.
    """
    # pylint: disable=global-variable-undefined,redefined-outer-name
    global scratch
    global outputs

    tinputs = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    taus = [Defs.gen_random() for _ in xrange(0, nOutBits)]
    lcv.set_inputs(tinputs)
    assert lcv.outputs == VerifierIOMLExt.compute_beta(tinputs)

    inputs = [
        util.chi(util.numToBin(idx, nOutBits), tinputs)
        for idx in xrange(0, 2**nOutBits)
    ]

    scratch = list(inputs)
    outputs = list(inputs)

    def fold_once(tau):
        # halve the table: pair (2k, 2k+1) -> (1 - tau) * lo + tau * hi
        global scratch
        global outputs

        co_tau = (1 - tau) % Defs.prime
        scratch = [
            ((scratch[2 * k] * co_tau) +
             (scratch[2 * k + 1] * tau)) % Defs.prime
            for k in xrange(0, len(scratch) // 2)
        ]
        # outputs is the very same list as scratch after a fold
        outputs = scratch

    for tau in taus:
        assert lcv.inputs == inputs
        assert lcv.outputs == outputs
        assert lcv.scratch == scratch

        fold_once(tau)
        lcv.next_round(tau)

        assert outputs == lcv.outputs
        assert scratch == lcv.scratch

    final_value = scratch[0]
    assert lcv.prevPassValue == final_value
    assert all([lcv.prevPassValue == elm[0] for elm in lcv.outputs_fact])
Esempio n. 3
0
 def __init__(self, nCopies, nInputs, in0vv, in1vv, typvv, muxvv=None):
     """Initialise the parent verifier, then set up ``self.rdl_sc_a``.

     All arguments are forwarded unchanged to the parent constructor.
     ``self.rdl_sc_a`` is a new ``Defs.fArith()`` category when
     ``Defs.track_fArith`` is set, and ``None`` otherwise.
     """
     parent = super(CircuitVerifierVecWitNIZK, self)
     parent.__init__(nCopies, nInputs, in0vv, in1vv, typvv, muxvv)

     if not Defs.track_fArith:
         self.rdl_sc_a = None
     else:
         arith = Defs.fArith()
         self.rdl_sc_a = arith.new_cat("%s_rdl_sc_a_%d" %
                                       (self.cat_label, hash(self)))
Esempio n. 4
0
def run_test(nOutBits, nValues):
    """Compare the module-level ``lcv`` against a reference folding model.

    ``nValues`` random inputs are handed to ``lcv`` and then zero-padded to
    ``2 ** nOutBits`` entries for the reference model.  Each round folds
    adjacent pairs with that round's ``tau`` and re-expands the folded
    values so that ``outputs`` keeps its length.  Rebinds the module globals
    ``scratch`` and ``outputs``.

    Raises:
        AssertionError: if ``lcv`` ever disagrees with the reference model.
    """
    # pylint: disable=redefined-outer-name,global-variable-undefined
    global scratch
    global outputs

    inputs = [Defs.gen_random() for _ in xrange(0, nValues)]
    taus = [Defs.gen_random() for _ in xrange(0, nOutBits)]

    lcv.set_inputs(inputs)

    # pad in place, after lcv has already been given this list object
    inputs.extend([0] * (2**nOutBits - nValues))
    scratch = list(inputs)
    outputs = list(inputs)

    def fold_once(tau):
        # halve scratch, then repeat each folded value to refill outputs
        global scratch
        global outputs

        co_tau = (1 - tau) % Defs.prime
        scratch = [
            ((scratch[2 * k] * co_tau) +
             (scratch[2 * k + 1] * tau)) % Defs.prime
            for k in xrange(0, len(scratch) // 2)
        ]

        ndups = len(outputs) // len(scratch)
        outputs = [folded for folded in scratch for _ in xrange(0, ndups)]

    for (rnd, tau) in enumerate(taus):
        assert lcv.inputs == inputs
        assert lcv.outputs == outputs
        assert lcv.scratch == scratch

        fold_once(tau)
        lcv.next_round(tau)

        # the state after the final round is not compared
        if rnd < nOutBits - 1:
            assert outputs == lcv.outputs
            assert scratch == lcv.scratch

    assert lcv.prevPassValue == scratch[0]
Esempio n. 5
0
def rand_inputs(nInBits, nCopies, inLay=None):
    """Build ``nCopies`` input vectors, randomising the unspecified entries.

    Args:
        nInBits: log2 of the vector length.  Only used when ``inLay`` is
            None; otherwise it is recomputed as ``util.clog2(len(inLay))``.
        nCopies: number of input vectors to generate.
        inLay: optional template.  ``None`` entries are replaced by a fresh
            ``Defs.gen_random()`` value in every copy; every other entry is
            reduced modulo ``Defs.prime``.  The template is zero-padded to
            ``2 ** util.clog2(len(inLay))`` entries.  The caller's sequence
            is left unmodified.

    Returns:
        A list of ``nCopies`` lists, one per copy.
    """
    if inLay is None:
        template = [None] * (2 ** nInBits)
    else:
        nInBits = util.clog2(len(inLay))
        # pad a copy: ``inLay += ...`` would extend the caller's list in place
        template = list(inLay) + [0] * (2 ** nInBits - len(inLay))

    return [
        [Defs.gen_random() if elm is None else elm % Defs.prime
         for elm in template]
        for _ in xrange(0, nCopies)
    ]
Esempio n. 6
0
def speed_test(num_tests):
    """Time the three ``VerifierIOMLExt`` compute variants and print results.

    A bit width between 3 and 8 is drawn at random, ``num_tests`` random
    input vectors of length ``2 ** nbits`` are generated, and
    ``compute_nosavebits``, ``compute_savebits`` and ``compute_sqrtbits``
    are each run over all of them.  Wall-clock totals are printed.
    """
    nbits = random.randint(3, 8)
    taus = [Defs.gen_random() for _ in xrange(0, nbits)]
    inputs = [[Defs.gen_random() for _ in xrange(0, 2**nbits)]
              for _ in xrange(0, num_tests)]

    vim = VerifierIOMLExt(taus)

    def timed(evaluate):
        # total wall-clock time to run `evaluate` over every input vector
        began = time.time()
        for sample in inputs:
            evaluate(sample)
        return time.time() - began

    runtime = timed(vim.compute_nosavebits)
    runtime2 = timed(vim.compute_savebits)
    runtime3 = timed(vim.compute_sqrtbits)

    report = "nBits: %d\nnosavebits: %f\nsavebits: %f\nsqrtbits: %f\n" % (
        nbits, runtime, runtime2, runtime3)
    print(report)
Esempio n. 7
0
    def __init__(self, nCopies, nInputs, in0vv, in1vv, typevv, muxvv=None):
        """Record the circuit description and reset the prover state.

        Args:
            nCopies: number of copies; ``nCopyBits`` is its ``util.clog2``.
            nInputs: number of inputs; ``nInBits`` is its ``util.clog2``.
            in0vv, in1vv, typevv: per-layer lists, all of the same length.
            muxvv: optional per-layer list of the same length; defaults to
                a list of ``None``.

        Raises:
            AssertionError: if the per-layer lists differ in length.
        """
        assert len(in0vv) == len(in1vv)
        assert len(in0vv) == len(typevv)
        assert muxvv is None or len(in0vv) == len(muxvv)
        if muxvv is None:
            muxvv = [None] * len(in0vv)

        # sizes
        self.nCopies = nCopies
        self.nCopyBits = util.clog2(nCopies)
        self.nInputs = nInputs
        self.nInBits = util.clog2(nInputs)

        # run state
        self.ckt_inputs = None
        self.ckt_outputs = None
        self.layerNum = 0
        self.roundNum = 0
        self.arith_ckt = None

        # save circuit config for building layers later
        self.in0vv = in0vv
        self.in1vv = in1vv
        self.typevv = typevv
        self.muxvv = muxvv

        # one mux bit per index up to the largest one mentioned in muxvv
        largest_mux = max(
            0 if muxv is None else max(muxv) for muxv in muxvv)
        self.muxlen = largest_mux + 1
        self.muxbits = [0] * self.muxlen

        # build instrumentation
        counters = (
            ("comp_h", "p_comp_h_%d"),
            ("comp_b", "p_comp_b_%d"),
            ("comp_chi", "p_comp_chi_%d"),
            ("comp_v", "p_comp_v_%d"),
            ("comp_v_fin", "p_comp_v_fin_%d"),
            ("comp_out", "p_comp_out_%d"),
        )
        if Defs.track_fArith:
            fArith = Defs.fArith()
            for (attr, fmt) in counters:
                setattr(self, attr, fArith.new_cat(fmt % hash(self)))
        else:
            for (attr, _) in counters:
                setattr(self, attr, None)

        # layer provers
        self.layer = None
        self.layer_inbits = [self.nInBits]
        self.layer_inbits += [util.clog2(len(in0v)) for in0v in in0vv[:-1]]
Esempio n. 8
0
def speed_test(num_tests):
    """Time ``LayerComputeBeta.set_inputs`` against ``compute_beta``.

    A bit width between 3 and 8 is drawn at random and ``num_tests`` random
    points of that many coordinates are generated.  Each point is fed to a
    single ``LayerComputeBeta`` instance and then to
    ``VerifierIOMLExt.compute_beta``; both wall-clock totals are printed.
    """
    nBits = random.randint(3, 8)
    inputs = [[Defs.gen_random() for _ in xrange(0, nBits)]
              for _ in xrange(0, num_tests)]

    lcb = LayerComputeBeta(nBits)
    lcb.other_factors = []

    def timed(evaluate):
        # total wall-clock time to run `evaluate` over every point
        began = time.time()
        for point in inputs:
            evaluate(point)
        return time.time() - began

    runtime = timed(lcb.set_inputs)
    runtime2 = timed(VerifierIOMLExt.compute_beta)

    report = "nBits: %d\nLayerComputeBeta: %f\nVerifierIOMLExt: %f\n" % (
        nBits, runtime, runtime2)
    print(report)
Esempio n. 9
0
    def __init__(self, nCopies, nInputs, in0vv, in1vv, typvv, muxvv=None):
        """Record the circuit description and reset the verifier state.

        Args:
            nCopies: number of copies; ``nCopyBits`` is its ``util.clog2``.
            nInputs: number of inputs; ``nInBits`` is its ``util.clog2``.
            in0vv, in1vv, typvv, muxvv: per-layer circuit description,
                stored as given.

        The per-layer bit widths ``layOutBits`` and ``layInBits`` are derived
        from ``in0vv`` walked in reverse order.
        """
        # sizes
        self.nCopies = nCopies
        self.nCopyBits = util.clog2(nCopies)
        self.nInputs = nInputs
        self.nInBits = util.clog2(nInputs)
        self.nCktInputs = nInputs  # these don't change even with RDL
        self.nCktBits = self.nInBits

        # circuit description
        self.in0vv = in0vv
        self.in1vv = in1vv
        self.typvv = typvv
        self.muxvv = muxvv
        self.muxbits = None

        # run state
        self.prover = None
        self.prover_fresh = False
        self.inputs = []
        self.outputs = []
        self.mlx_w1 = []
        self.mlx_w2 = []

        # instrumentation counters, created in this order when tracking is on
        counters = (
            ("in_a", "%s_in_%d"),
            ("out_a", "%s_out_%d"),
            ("sc_a", "%s_sc_%d"),
            ("tV_a", "%s_tV_%d"),
            ("tP_a", "%s_tP_%d"),
            ("nlay_a", "%s_nlay_%d"),
        )
        if Defs.track_fArith:
            fArith = Defs.fArith()
            for (attr, fmt) in counters:
                setattr(self, attr,
                        fArith.new_cat(fmt % (self.cat_label, hash(self))))
        else:
            for (attr, _) in counters:
                setattr(self, attr, None)

        # nOutBits and nInBits for each layer
        self.layOutBits = []
        for lay in reversed(self.in0vv):
            self.layOutBits.append(util.clog2(len(lay)))
        # each layer's input width is the next entry's output width
        self.layInBits = self.layOutBits[1:] + [self.nInBits]
Esempio n. 10
0
    def __init__(self, nCopies, nInputs, in0vv, in1vv, typvv, muxvv=None):
        """Create the commitment object, then initialise the parent verifier.

        ``self.com`` is built from ``self.commit_type`` and ``set_field()``
        is called on it before the parent constructor runs.  The circuit
        arguments are forwarded unchanged to the parent constructor.
        """
        if not Defs.track_fArith:
            self.com_p_a = None
            self.com_q_a = None
            com_rec = None
        else:
            fArith = Defs.fArith()
            self.com_p_a = fArith.new_cat("%s_com_p_%d" %
                                          (self.cat_label, hash(self)))
            self.com_q_a = fArith.new_cat("%s_com_q_%d" %
                                          (self.cat_label, hash(self)))
            com_rec = (self.com_p_a, self.com_q_a)

        # set up Pedersen commitment first so the prime field is correct!
        commitment = self.commit_type(Defs.curve, com_rec)
        self.com = commitment
        commitment.set_field()

        parent = super(CircuitVerifierNIZK, self)
        parent.__init__(nCopies, nInputs, in0vv, in1vv, typvv, muxvv)
Esempio n. 11
0
    def __init__(self, rdl, nInBits):
        """Set up the state for the flattened ``rdl`` mapping.

        Args:
            rdl: nested sequence, flattened with ``util.flatten``; entry
                number ``outNum`` holds the input index ``inNum`` that the
                pass gate for that output reads.
            nInBits: number of input bits, handed to ``LayerComputeV``.
        """
        # pylint: disable=protected-access
        self.rdl = util.flatten(rdl)
        self.nOutBits = util.clog2(len(self.rdl))
        self.nInBits = nInBits
        self.roundNum = 0
        self.muls = None
        self.inputs = []
        self.output = []

        # anonymous record that only carries the three counters
        counters = type('', (object, ), {
            'comp_chi': None,
            'comp_v_fin': None,
            'comp_out': None
        })()
        self.circuit = counters

        if not Defs.track_fArith:
            counters.comp_chi = None
            counters.comp_v_fin = None
            counters.comp_out = None
        else:
            fArith = Defs.fArith()
            counters.comp_chi = fArith.new_cat("p_rdl_comp_chi_%d" %
                                               hash(self))
            counters.comp_v_fin = fArith.new_cat("p_rdl_comp_v_fin_%d" %
                                                 hash(self))
            counters.comp_out = fArith.new_cat("p_rdl_comp_out_%d" %
                                               hash(self))

        # z1chi computation subckt
        # this uses V's fast beta evaluation (in set_z)
        self.compute_z1chi = None
        self.compute_z1chi_2 = None

        # everything is 2nd order, so we only need three eval points
        final_v = LayerComputeV(nInBits, counters.comp_v_fin)
        self.compute_v_final = final_v
        final_v.set_other_factors([util.THIRD_EVAL_POINT])

        # pergate computation subckts for "early" rounds
        # (self.gates exists before the first gate is constructed)
        self.gates = []
        for (outNum, inNum) in enumerate(self.rdl):
            gate = gateprover.PassGateProver(False, inNum, 0, outNum, self, 0)
            self.gates.append(gate)
Esempio n. 12
0
    def set_inputs(self, inputs):
        """Evaluate the circuit on ``inputs`` and prime the first layer.

        Args:
            inputs: one entry per copy; its length must be ``self.nCopies``.

        Raises:
            AssertionError: if the number of copies is wrong or
                ``self.layerNum`` is not 0.
        """
        assert len(inputs) == self.nCopies
        assert self.layerNum == 0

        self.ckt_inputs = inputs
        self.ckt_outputs = []

        # build arith circuit
        builder = ArithCircuitIncrementalBuilder(
            self.nCopies, self.nInputs, self.in0vv, self.in1vv, self.typevv,
            self.muxvv)
        self.arith_ckt = builder
        builder.set_muxbits(self.muxbits)
        if Defs.track_fArith:
            arith = Defs.fArith()
            builder.set_rec(arith.new_cat("p_comp_arith_%d" % hash(self)))

        # record inputs to each layer prover and set inputs for each layer prover
        self.ckt_outputs = builder.run(inputs)
        self.build_layer()
        layer_inputs = self.arith_ckt.get_next()
        self.layer.set_inputs(layer_inputs)
Esempio n. 13
0
    def __init__(self, nCopies, nInputs, in0vv, in1vv, typvv, muxvv=None):
        """Create commitment and Fiat-Shamir state, then initialise the parent.

        ``self.com`` is built from ``self.commit_type`` and ``set_field()``
        is called on it before ``self.fs`` is created from ``Defs.prime``
        and before the parent constructor runs.  ``self.nondet_gen``
        defaults to a callable returning its first argument unchanged.
        """
        if not Defs.track_fArith:
            self.com_p_a = None
            self.com_q_a = None
            com_rec = None
        else:
            fArith = Defs.fArith()
            self.com_p_a = fArith.new_cat("%s_com_p_%d" %
                                          (self.cat_label, hash(self)))
            self.com_q_a = fArith.new_cat("%s_com_q_%d" %
                                          (self.cat_label, hash(self)))
            com_rec = (self.com_p_a, self.com_q_a)

        # set up Pedersen commit first so that the prime field is correct!
        commitment = self.commit_type(Defs.curve, com_rec)
        self.com = commitment
        commitment.set_field()

        self.fs = fs.FiatShamir(Defs.prime)
        self.nondet_gen = lambda inputs, _: inputs

        parent = super(CircuitProverNIZK, self)
        parent.__init__(nCopies, nInputs, in0vv, in1vv, typvv, muxvv)
Esempio n. 14
0
 def __init__(self, acrepr):
     """Store ``acrepr`` and install the default hook callables.

     ``nondet_gen`` returns its argument unchanged; ``get_default`` ignores
     its argument and returns a fresh ``Defs.gen_random()`` value.
     """
     # a constant-zero default (lambda _: 0) is the disabled alternative
     self.get_default = lambda _: Defs.gen_random()
     self.nondet_gen = lambda inputs: inputs
     self.acrepr = acrepr
Esempio n. 15
0
def run_fennel(verifier_info):
    """Generate and/or verify a proof as described by ``verifier_info``.

    The proof is produced by the prover class when
    ``verifier_info.vProofFile`` is None, and otherwise read (bz2-compressed)
    from that file.  It is then checked by the verifier class when
    ``verifier_info.pProofFile`` is None, and otherwise written
    (bz2-compressed) to that file.

    Side effects:
        * sets ``Defs.curve`` and the prime via ``util.set_prime``;
        * sets ``pStartTime``/``pEndTime`` and ``vStartTime``/``vEndTime`` on
          ``verifier_info`` around the prover and verifier runs;
        * logs the proof size, the verification result and a summary line
          through ``verifier_info.Log.log``.

    A verification failure is logged together with its traceback; it is not
    re-raised.

    Args:
        verifier_info: configuration object; the attributes read here are
            ``curve``, ``proofType``, ``pwsFile``, ``rdlFile``, ``nCopies``,
            ``vProofFile``, ``pProofFile``, ``ndBits``, ``ndGen``,
            ``rvStart``, ``rvEnd``, ``witnessDiv``, ``showPerf`` and ``Log``.
    """
    # set curve and prime
    Defs.curve = verifier_info.curve
    util.set_prime(libfennel.commit.MiraclEC.get_order(Defs.curve))

    # pylint doesn't seem to understand how classmethods are inherited from metaclasses
    p_from_pws = verifier_info.proofType.ProverClass.from_pws  # pylint: disable=no-member
    v_from_pws = verifier_info.proofType.VerifierClass.from_pws  # pylint: disable=no-member
    pFile = pypws.parse_pws(verifier_info.pwsFile, str(Defs.prime))

    # handle RDL
    use_rdl = verifier_info.rdlFile is not None
    r_input_layer = None
    rdl_map = None
    if use_rdl:
        rFile = pypws.parse_pws_unopt(verifier_info.rdlFile, str(Defs.prime))
        (r_input_layer,
         rdl_map) = libfennel.parse_pws.parse_rdl(rFile, verifier_info.nCopies,
                                                  pFile[0])

    # these remain None on the code paths that never create them, so the
    # summary at the end can tell what is available instead of raising
    input_layer = None
    ver = None

    # either generate or read in proof
    if verifier_info.vProofFile is None:
        (input_layer, prv) = p_from_pws(pFile, verifier_info.nCopies)
        prv.build_prover()

        # set up RDL
        if use_rdl:
            inputs = get_inputs(verifier_info, r_input_layer)
            prv.set_rdl(rdl_map, len(r_input_layer))
        else:
            inputs = get_inputs(verifier_info, input_layer)

        # handle nondeterminism options
        if verifier_info.ndBits is not None:
            prv.set_nondet_range(verifier_info.ndBits)
        if verifier_info.ndGen is not None:
            prv.set_nondet_gen(verifier_info.ndGen)
        if verifier_info.rvStart is not None and verifier_info.rvEnd is not None:
            prv.set_rval_range(verifier_info.rvStart, verifier_info.rvEnd)
        if verifier_info.witnessDiv is not None:
            prv.set_wdiv(verifier_info.witnessDiv)

        verifier_info.pStartTime = time.time()
        proof = prv.run(inputs)
        verifier_info.pEndTime = time.time()
    else:
        # bz2 data is binary: text mode can corrupt it on some platforms
        with open(verifier_info.vProofFile, 'rb') as fh:
            proof = bz2.decompress(fh.read())

    verifier_info.Log.log(
        "Proof size: %d elems, %d bytes" % FiatShamir.proof_size(proof), True)

    # either verify or write out proof
    if verifier_info.pProofFile is None:
        (v_input_layer, ver) = v_from_pws(pFile, verifier_info.nCopies)
        if input_layer is None:
            # proof was read from a file, so the prover never supplied this.
            # NOTE(review): assumes both from_pws classmethods return the
            # same input layer as their first element - confirm.
            input_layer = v_input_layer

        # set up RDL
        if use_rdl:
            ver.set_rdl(rdl_map, len(r_input_layer))

        verifier_info.vStartTime = time.time()
        try:
            ver.run(proof)
        except Exception as e:  # pylint: disable=broad-except
            verifier_info.Log.log("Verification failed: %s" % e, True)
            verifier_info.Log.log(traceback.format_exc(), True)
        else:
            verifier_info.Log.log("Verification succeeded.", True)
        verifier_info.vEndTime = time.time()
    else:
        with open(verifier_info.pProofFile, 'wb') as fh:
            fh.write(bz2.compress(proof))

    # summary line: report only the figures that exist on this code path
    stats = []
    sized_layer = r_input_layer if use_rdl else input_layer
    if sized_layer is not None:
        stats.append("nInBits: %d" % util.clog2(len(sized_layer)))
    stats.append("nCopies: %d" % verifier_info.nCopies)
    if ver is not None:
        # RDL contributes one extra layer; without the parentheses the
        # conditional swallowed the whole sum and reported 0 without RDL
        nLayers = len(ver.in0vv) + (1 if use_rdl else 0)
        stats.append("nLayers: %d" % nLayers)
    verifier_info.Log.log(", ".join(stats), verifier_info.showPerf)
    if Defs.track_fArith:
        verifier_info.Log.log(str(Defs.fArith()), verifier_info.showPerf)
Esempio n. 16
0
def run_one_test(nInBits, nCopies):
    """Drive a single LayerProver on a random layer and cross-check it.

    The expected circuit outputs are computed with a two-layer
    ``ArithCircuit``.  The prover's outputs must agree with the mlext of
    those outputs at ``z1 + z2`` before the first round, and with the mlext
    of the inputs at ``w1 + w3`` and ``w2 + w3`` after the last round.

    Raises:
        AssertionError: if any of the checks fails.
    """
    nOutBits = nInBits

    def coins(count):
        # `count` fresh random values from Defs.gen_random()
        return [Defs.gen_random() for _ in xrange(0, count)]

    def mlext(nbits, point, values):
        # weight `values` by the LayerComputeBeta outputs at `point` and sum
        beta = LayerComputeBeta(nbits, point)
        assert len(values) == len(beta.outputs)
        return sum(util.mul_vecs(beta.outputs, values)) % Defs.prime

    circuit = _DummyCircuitProver(nCopies)

    (in0v, in1v, typv) = randutil.rand_ckt(nOutBits, nInBits)
    typc = [tc.cgate for tc in typv]
    inputs = randutil.rand_inputs(nInBits, nCopies)

    # compute outputs
    ckt = ArithCircuit()
    inCktLayer = ArithCircuitInputLayer(ckt, nOutBits)
    outCktLayer = ArithCircuitLayer(ckt, inCktLayer, in0v, in1v, typc)
    ckt.layers = [inCktLayer, outCktLayer]
    outputs = []
    for copy_inputs in inputs:
        ckt.run(copy_inputs)
        outputs.append(ckt.outputs)

    z1 = coins(nOutBits)
    z2 = coins(circuit.nCopyBits)

    outLayer = LayerProver(nInBits, circuit, in0v, in1v, typv)
    outLayer.set_inputs(inputs)
    outLayer.set_z(z1, z2, None, None, True)

    # mlExt of outputs
    outflat = util.flatten(outputs)
    inLayerExt = mlext(nOutBits + outLayer.circuit.nCopyBits, z1 + z2,
                       outflat)

    w3 = coins(circuit.nCopyBits)
    w1 = coins(nInBits)
    w2 = coins(nInBits)

    outLayer.compute_outputs()
    initOutputs = outLayer.output

    assert inLayerExt == (initOutputs[0] + sum(initOutputs)) % Defs.prime

    # rounds are bound in the order w3, w1, w2
    for coin in w3 + w1 + w2:
        outLayer.next_round(coin)
        outLayer.compute_outputs()

    finalOutputs = outLayer.output

    # check the outputs by computing mlext of layer input directly
    inflat = util.flatten(inputs)
    v1 = mlext(outLayer.nInBits + outLayer.circuit.nCopyBits, w1 + w3, inflat)
    v2 = mlext(outLayer.nInBits + outLayer.circuit.nCopyBits, w2 + w3, inflat)

    assert v1 == finalOutputs[0]
    assert v2 == sum(finalOutputs) % Defs.prime
Esempio n. 17
0
def run_one_test(nbits, squawk, nbins, pattern):
    """Check three ways of evaluating the same mlext against each other.

    A random point ``z`` of ``nbits`` coordinates and an input vector of
    length ``2 ** nbits`` are generated.  The vector is evaluated with
    ``LayerComputeBeta`` ("old"), ``VerifierIOMLExt.compute`` ("new") and
    ``LayerComputeV`` ("nw2"), and the three results must be equal.  A
    ranged ``VerifierIOMLExt.compute_beta`` call is then checked against the
    full one.

    Sets ``Defs.track_fArith`` to True as a side effect.

    Args:
        nbits: number of coordinates of the evaluation point.
        squawk: when true, print the operation counters of each method.
        nbins: number of trailing entries of the input vector that hold
            patterned 0/1 values instead of random field elements.
        pattern: how those trailing entries are filled: 0 -> all zeros,
            1 -> all ones, 2 -> 0,1,0,1,..., 3 -> 1,0,1,0,..., anything
            else -> random bits.

    Returns:
        ``newrec.get_counts()``, the counters recorded for the "new" method.

    Raises:
        AssertionError: if the evaluations or the ranged beta disagree.
    """
    z = [Defs.gen_random() for _ in xrange(0, nbits)]

    inv = [Defs.gen_random() for _ in xrange(0, (2**nbits) - nbins)]
    # compare by value: `pattern is 0` only works through CPython's
    # small-integer caching and is not a guaranteed language behaviour
    if pattern == 0:
        inv += [0 for _ in xrange(0, nbins)]
    elif pattern == 1:
        inv += [1 for _ in xrange(0, nbins)]
    elif pattern == 2:
        inv += [(i % 2) for i in xrange(0, nbins)]
    elif pattern == 3:
        inv += [((i + 1) % 2) for i in xrange(0, nbins)]
    else:
        inv += [random.randint(0, 1) for _ in xrange(0, nbins)]

    assert len(inv) == (2**nbits)

    Defs.track_fArith = True
    fa = Defs.fArith()
    oldrec = fa.new_cat("old")
    newrec = fa.new_cat("new")
    nw2rec = fa.new_cat("nw2")

    # "old": dot product with the beta table; the multiplications and
    # additions of that dot product are recorded by hand
    oldbeta = LayerComputeBeta(nbits, z, oldrec)
    oldval = sum(util.mul_vecs(oldbeta.outputs, inv)) % Defs.prime
    oldrec.did_mul(len(inv))
    oldrec.did_add(len(inv) - 1)

    # "new": the verifier's evaluation
    newcomp = VerifierIOMLExt(z, newrec)
    newval = newcomp.compute(inv)

    # "nw2": bind one coordinate of z per round
    nw2comp = LayerComputeV(nbits, nw2rec)
    nw2comp.other_factors = []
    nw2comp.set_inputs(inv)
    for zz in z:
        nw2comp.next_round(zz)
    nw2val = nw2comp.prevPassValue

    assert oldval == newval, "error for inputs (new) %s : %s" % (str(z),
                                                                 str(inv))
    assert oldval == nw2val, "error for inputs (nw2) %s : %s" % (str(z),
                                                                 str(inv))

    if squawk:
        # single-argument parenthesised form prints the same text under
        # both the print statement and the print function
        print("")
        print("nbits: %d" % nbits)
        print("OLD: %s" % str(oldrec))
        print("NEW: %s" % str(newrec))
        print("NW2: %s" % str(nw2rec))

    betacomp = VerifierIOMLExt.compute_beta(z)
    beta_lo = random.randint(0, 2**nbits - 1)
    beta_hi = random.randint(beta_lo, 2**nbits - 1)
    betacomp2 = VerifierIOMLExt.compute_beta(z, None, 1, beta_lo, beta_hi)
    # make sure that the right range was generated, and correctly
    assert len(betacomp) == len(betacomp2)
    assert all([b is None for b in betacomp2[:beta_lo]])
    assert all([b is not None for b in betacomp2[beta_lo:beta_hi + 1]])
    assert all([b is None for b in betacomp2[beta_hi + 1:]])
    assert all([
        b2 == b if b2 is not None else True
        for (b, b2) in zip(betacomp, betacomp2)
    ])

    return newrec.get_counts()
Esempio n. 18
0
    def run(self, inputs, muxbits=None):
        """Run the whole protocol against ``self.prover`` and check every step.

        The prover is (re)built, given ``inputs``, and then queried layer by
        layer.  For each layer a sumcheck is run with fresh random coins, the
        last sumcheck value is compared with ``self.eval_tV``, and the claim
        is carried over to the next layer.  Finally the claim is compared
        with the mlext of the inputs.

        Args:
            inputs: one list per copy.  Each list is zero-padded to
                ``2 ** self.nInBits`` entries before being stored, flattened,
                in ``self.inputs``.
            muxbits: stored in ``self.muxbits``; forwarded to
                ``self.prover.set_muxbits`` when not None.

        Raises:
            AssertionError: if any check fails (the message names the layer
                and, for sumcheck failures, the round), or if the input or
                output sizes are inconsistent with the circuit.
        """
        self.build_prover()
        self.prover_fresh = False

        ############
        # 0. Setup #
        ############
        assert self.prover is not None

        # set muxbits
        self.muxbits = muxbits
        if muxbits is not None:
            self.prover.set_muxbits(muxbits)

        # set inputs and outputs
        self.prover.set_inputs(inputs)
        self.inputs = []
        for ins in inputs:
            self.inputs.extend(ins + [0] * (2**self.nInBits - len(ins)))

        # pad to power-of-2 number of copies
        assert util.clog2(len(self.inputs)) == self.nInBits + self.nCopyBits
        self.inputs += [0] * (2 ** (self.nInBits + self.nCopyBits) - len(self.inputs))

        ###############################################
        # 1. Compute multilinear extension of outputs #
        ###############################################
        self.outputs = util.flatten(self.prover.ckt_outputs)
        # the last entry of in0vv describes the output layer
        nOutBits = util.clog2(len(self.in0vv[-1]))
        assert util.clog2(len(self.outputs)) == nOutBits + self.nCopyBits

        # pad out to power-of-2 number of copies
        self.outputs += [0] * (2 ** (nOutBits + self.nCopyBits) - len(self.outputs))

        # generate random point in (z1, z2) \in F^{nOutBits + nCopyBits}
        z1 = [ Defs.gen_random() for _ in xrange(0, nOutBits) ]
        z1_2 = None
        z2 = [ Defs.gen_random() for _ in xrange(0, self.nCopyBits) ]
        if Defs.track_fArith:
            self.sc_a.did_rng(nOutBits + self.nCopyBits)

        # if the AC has only one layer, tell P to give us H(.)
        muls = None
        project_line = len(self.in0vv) == 1
        self.prover.set_z(z1, z2, None, None, project_line)

        # eval mlext of output at (z1,z2)
        expectNext = VerifierIOMLExt(z1 + z2, self.out_a).compute(self.outputs)

        ##########################################
        # 2. Interact with prover for each layer #
        ##########################################
        for lay in xrange(0, len(self.in0vv)):
            # layInBits/layOutBits are indexed in the same order as this loop
            nInBits = self.layInBits[lay]
            nOutBits = self.layOutBits[lay]

            # random coins for this round
            w3 = [ Defs.gen_random() for _ in xrange(0, self.nCopyBits) ]
            w1 = [ Defs.gen_random() for _ in xrange(0, nInBits) ]
            w2 = [ Defs.gen_random() for _ in xrange(0, nInBits) ]
            if Defs.track_fArith:
                self.sc_a.did_rng(2*nInBits + self.nCopyBits)

            # convenience
            # (the coins are consumed in the order w3, w1, w2)
            ws = w3 + w1 + w2

            ###################
            ### A. Sumcheck ###
            ###################
            for rd in xrange(0, 2 * nInBits + self.nCopyBits):
                # get output from prv and check against expected value
                outs = self.prover.get_outputs()
                # outs[0] is counted twice: once directly, once inside sum(outs)
                # NOTE(review): presumably outs are polynomial coefficients,
                # making this p(0) + p(1) - confirm against util.horner_eval
                gotVal = (outs[0] + sum(outs)) % Defs.prime
                if Defs.track_fArith:
                    self.sc_a.did_add(len(outs))

                assert expectNext == gotVal, "Verification failed in round %d of layer %d" % (rd, lay)

                # go to next round
                self.prover.next_round(ws[rd])
                expectNext = util.horner_eval(outs, ws[rd], self.sc_a)

            # one more query after the last round yields the values v1, v2
            outs = self.prover.get_outputs()

            if project_line:
                v1 = outs[0] % Defs.prime
                v2 = sum(outs) % Defs.prime
                if Defs.track_fArith:
                    self.tV_a.did_add(len(outs)-1)
            else:
                assert len(outs) == 2
                v1 = outs[0]
                v2 = outs[1]

            ############################################
            ### B. Evaluate mlext of wiring predicates #
            ############################################
            tV_eval = self.eval_tV(lay, z1, z2, w1, w2, w3, v1, v2, z1_2 is not None, muls)

            # check that we got the correct value from the last round of the sumcheck
            assert expectNext == tV_eval, "Verification failed computing tV for layer %d" % lay

            ###############################
            ### C. Extend to next layer ###
            ###############################
            # true while processing the second-to-last layer, so that the
            # last layer is handled with project_line set
            project_next = lay == len(self.in0vv) - 2

            if project_line:
                tau = Defs.gen_random()
                muls = None
                expectNext = util.horner_eval(outs, tau, self.nlay_a)
                # z1 = w1 + ( w2 - w1 ) * tau
                z1 = [ (elm1 + (elm2 - elm1) * tau) % Defs.prime for (elm1, elm2) in izip(w1, w2) ]
                z1_2 = None
                if Defs.track_fArith:
                    self.nlay_a.did_sub(len(w1))
                    self.nlay_a.did_mul(len(w1))
                    self.nlay_a.did_add(len(w1))
                    self.sc_a.did_rng()
            else:
                # combine the two claims with a random linear combination
                muls = (Defs.gen_random(), Defs.gen_random())
                tau = None
                expectNext = ( muls[0] * v1 + muls[1] * v2 ) % Defs.prime
                self.prover.next_layer(muls, project_next)
                z1 = w1
                z1_2 = w2
                if Defs.track_fArith:
                    self.nlay_a.did_add()
                    self.nlay_a.did_mul(2)
                    self.sc_a.did_rng(2)

            project_line = project_next
            z2 = w3

        ##############################################
        # 3. Compute multilinear extension of inputs #
        ##############################################
        # Finally, evaluate mlext of input at z1, z2
        input_mlext = VerifierIOMLExt(z1 + z2, self.in_a)
        input_mlext_eval = input_mlext.compute(self.inputs)

        assert input_mlext_eval == expectNext, "Verification failed checking input mlext"