Exemplo n.º 1
0
    def create(a, M, A=None, N=None, m=None,
            a_packed=None, M_packed=None, A_packed=None, N_packed=None):
        """Create a DHTProof for the scalar ``a`` and the point ``M``.

        With ``A = B_times(a)`` and ``N = M * a``, the proof consists of the
        commitments ``R_M = M * r`` and ``R_B = B_times(r)`` together with the
        response ``s = (r + h * a) % ed25519.l``.  The nonce ``r`` is derived
        deterministically from ``a_packed`` and ``M_packed``; the challenge
        ``h`` is the hash of ``A_packed + M_packed + N_packed + R_M_packed +
        R_B_packed``.

        Every optional argument is a value the caller may already have at
        hand; anything left as ``None`` is computed here.

        Args:
            a: the secret scalar.
            M: the point being multiplied by ``a``.
            A: optional precomputed ``B_times(a)``.
            N: optional precomputed ``M * a``.
            m: optional scalar; when given, ``R_M`` is computed as
                ``B_times(m * r)`` instead of ``M * r``
                (presumably requires ``M == B_times(m)`` -- the caller must
                guarantee this, it is not checked here).
            a_packed, M_packed, A_packed, N_packed: optional packed forms of
                the corresponding values.

        Returns:
            The ``DHTProof`` built from ``R_M``, ``R_B`` and ``s``.
        """
        # Identity checks (``is None``) are used on purpose: ``== None`` would
        # dispatch to the ``__eq__`` of the point/scalar type.
        if A is None:
            A = ed25519.Point.B_times(a)
        if N is None:
            N = M * a

        if A_packed is None:
            A_packed = A.pack()
        if N_packed is None:
            N_packed = N.pack()
        if M_packed is None:
            M_packed = M.pack()
        if a_packed is None:
            a_packed = ed25519.scalar_pack(a)

        # Deterministic nonce: depends only on the secret scalar and on M.
        r = ed25519.scalar_unpack(common.sha256(
            b"DHTProof" + a_packed + M_packed))
        R_B = ed25519.Point.B_times(r)

        if m is None:
            R_M = M * r
        else:
            R_M = ed25519.Point.B_times(m*r)

        R_M_packed = R_M.pack()
        R_B_packed = R_B.pack()

        # Challenge binding the public values to both commitments.
        h = ed25519.scalar_unpack(common.sha256(
            A_packed + M_packed + N_packed + R_M_packed + R_B_packed))

        s = (r + h * a) % ed25519.l

        return DHTProof(R_M, R_B, s,
                R_M_packed=R_M_packed, R_B_packed=R_B_packed)
Exemplo n.º 2
0
    def test_dht_proof_create(self):
        """Compare the C ``dht_proof_create`` with ``schnorr.DHTProof.create``.

        Runs ten rounds with fresh random scalars.  Even rounds hand the C
        function the scalar ``m`` (and NULL for the point ``M``), odd rounds
        hand it the point ``M`` (and NULL for ``m``); in both cases the 96
        bytes it writes must equal the packed Python proof.
        """
        ffi = ristretto.ffi
        proof_out = ffi.new("unsigned char[]", 96)
        proof_view = ffi.buffer(proof_out)

        for round_no in range(10):
            a = ed25519.scalar_random()
            m = ed25519.scalar_random()
            A = ed25519.Point.B_times(a)
            M = ed25519.Point.B_times(m)
            N = M * a

            pass_scalar = (round_no % 2 == 0)
            if pass_scalar:
                c_m, c_M = scalar_to_c(m), ffi.NULL
            else:
                c_m, c_M = ffi.NULL, point_to_c(M)

            ristretto.lib.dht_proof_create(
                proof_out,
                scalar_to_c(a), ed25519.scalar_pack(a), A.pack(),
                c_m, c_M, M.pack(),
                point_to_c(N), N.pack())

            python_proof = schnorr.DHTProof.create(a, M).pack()
            self.assertEqual(python_proof, proof_view[:])
Exemplo n.º 3
0
    def decrypt(self, pseudonyms, key):
        """Decrypt the given pseudonyms in place using the C library.

        The ``data`` of every pseudonym is unpacked as an ElGamal triple,
        decrypted with ``key``, and replaced by the 32-byte packed form of
        the resulting group element.

        Args:
            pseudonyms: sequence of objects with a ``data`` attribute holding
                a packed ElGamal triple (presumably 96 bytes each -- verify
                against the caller).
            key: the private key, as a scalar accepted by
                ``ed25519.scalar_pack``.

        Raises:
            InvalidArgument: if one of the triples can't be unpacked.
            AssertionError: if ``key`` can't be unpacked.
        """
        ffi, lib = ristretto.ffi, ristretto.lib
        n = len(pseudonyms)

        ckey = ffi.new("group_scalar*")
        cpoints = ffi.new("group_ge[]", n)
        ctriples = ffi.new("elgamal_triple[]", n)

        # The unpack call must not live inside the ``assert`` expression:
        # under ``python -O`` asserts are removed, and the call that fills
        # ``ckey`` would be removed along with it.
        key_status = lib.group_scalar_unpack(ckey, ed25519.scalar_pack(key))
        assert key_status == 0

        cerror_codes = ffi.new("int[]", n)

        lib.elgamal_triples_unpack(
            ctriples, b''.join([pseudonym.data for pseudonym in pseudonyms]),
            cerror_codes, n)

        for i in range(n):
            if cerror_codes[i] != 0:
                raise InvalidArgument(f"couldn't unpack {i}th triple")

        lib.elgamal_triples_decrypt(cpoints, ctriples, ckey, n)

        cbuf = ffi.new("unsigned char[]", 32 * n)
        lib.group_ges_pack(cbuf, cpoints, n)
        buf = ffi.buffer(cbuf)

        # Write the packed plaintext points back, 32 bytes per pseudonym.
        for i, pseudonym in enumerate(pseudonyms):
            pseudonym.data = buf[i * 32:(i + 1) * 32]
Exemplo n.º 4
0
    def encrypt(self, pseudonyms, target, rs):
        """Encrypt the given pseudonyms in place using the C library.

        The ``data`` of every pseudonym is unpacked as a group element,
        ElGamal-encrypted for ``target`` with the matching random scalar
        from ``rs``, and replaced by the 96-byte packed triple.

        Args:
            pseudonyms: sequence of objects with a ``data`` attribute holding
                a packed group element.
            target: the public key to encrypt for; must provide ``pack()``.
            rs: one random scalar per pseudonym.

        Raises:
            InvalidArgument: if the data of a pseudonym can't be unpacked.
            AssertionError: if ``rs`` has the wrong length, or ``target`` or
                one of the scalars can't be unpacked.
        """
        ffi, lib = ristretto.ffi, ristretto.lib

        n = len(pseudonyms)
        assert (n == len(rs))

        ctarget = ffi.new("group_ge*")
        crs = ffi.new("group_scalar[]", n)
        cpoints = ffi.new("group_ge[]", n)
        ctriples = ffi.new("elgamal_triple[]", n)

        # Calls with side effects are kept out of the ``assert`` expression:
        # under ``python -O`` the whole assert statement (call included)
        # would otherwise disappear.
        target_status = lib.group_ge_unpack(ctarget, target.pack())
        assert target_status == 0

        for i in range(n):
            scalar_status = lib.group_scalar_unpack(
                ffi.addressof(crs, i), ed25519.scalar_pack(rs[i]))
            assert scalar_status == 0

        # A pseudonym that doesn't decode to a group element is reported
        # instead of being silently encrypted as whatever is in ``cpoints``.
        for i in range(n):
            point_status = lib.group_ge_unpack(
                ffi.addressof(cpoints, i), pseudonyms[i].data)
            if point_status != 0:
                raise InvalidArgument(f"couldn't unpack {i}th pseudonym")

        lib.elgamal_triples_encrypt(ctriples, cpoints, ctarget, crs, n)

        cbuf = ffi.new("unsigned char[]", 96 * n)
        lib.elgamal_triples_pack(cbuf, ctriples, n)
        buf = ffi.buffer(cbuf)

        # Write the packed triples back, 96 bytes per pseudonym.
        for i, pseudonym in enumerate(pseudonyms):
            pseudonym.data = buf[i * 96:(i + 1) * 96]
Exemplo n.º 5
0
    def certified_component_create(self, cc_protobuf, base_powers, base_scalar,
                                   exponent):
        """Create a certified component with the C library and store it.

        Calls ``ristretto.lib.certified_component_create`` and copies the
        result into ``cc_protobuf``: the component itself, the DHT proofs
        (96 bytes each) and the partial products (32 bytes each) of its
        product proof.

        Args:
            cc_protobuf: message to fill; ``component`` is set, and
                ``product_proof.dht_proofs`` / ``.partial_products`` are
                appended to.
            base_powers: 253 packed points of 32 bytes each.
            base_scalar: scalar handed to the C function next to the powers.
            exponent: scalar exponent; the number of its ones (as yielded by
                ``schnorr.ones_of``) determines the size of the proof.
        """
        ffi, lib = ristretto.ffi, ristretto.lib

        ccc = ffi.new("certified_component*")

        # The proof sizes are derived from the exponent rather than read back
        # from ccc.product_proof.number_of_factors after the call.
        N = sum(1 for _ in schnorr.ones_of(exponent))
        dht_proof_count = max(N - 1, 0)
        partial_product_count = max(N - 2, 0)

        # The arrays are bound to local names so that the memory the struct
        # points at stays alive for as long as ``ccc`` is used below.
        cdht_proofs = ffi.new("unsigned char[]", 96 * dht_proof_count)
        cpartial_products = ffi.new("unsigned char[]",
                                    32 * partial_product_count)
        ccc.product_proof.dht_proofs = cdht_proofs
        ccc.product_proof.partial_products = cpartial_products

        cbase_powers = ffi.new("unsigned char[]", 32 * 253)
        for i in range(253):
            cbase_powers[i * 32:(i + 1) * 32] = base_powers[i]

        cbase_scalar = ffi.new("group_scalar*")
        lib.group_scalar_unpack(cbase_scalar,
                                ed25519.scalar_pack(base_scalar))

        cexponent = ffi.new("group_scalar*")
        lib.group_scalar_unpack(cexponent, ed25519.scalar_pack(exponent))

        lib.certified_component_create(ccc, cbase_powers, cbase_scalar,
                                       cexponent)

        cc_protobuf.component = ffi.buffer(ccc.component)[:]

        dht_proofs_buffer = ffi.buffer(ccc.product_proof.dht_proofs,
                                       96 * dht_proof_count)
        partial_products_buffer = ffi.buffer(
            ccc.product_proof.partial_products, 32 * partial_product_count)

        for i in range(dht_proof_count):
            cc_protobuf.product_proof.dht_proofs.append(
                dht_proofs_buffer[i * 96:(i + 1) * 96])

        for i in range(partial_product_count):
            cc_protobuf.product_proof.partial_products.append(
                partial_products_buffer[i * 32:(i + 1) * 32])
Exemplo n.º 6
0
    def Enroll(self, request, context):
        """Handle an enrollment request and return the EnrollmentResponse.

        The caller is authenticated via ``common.authenticate(context)``.
        For every shard and domain the private local key is derived from the
        stored secrets and the exponent ``e`` (the hash of the caller's
        common name).  Callers named "PEP3 investigator" or "PEP3 researcher"
        additionally receive certified key and pseudonym components.
        """
        common_name = common.authenticate(context)

        # TODO: have a smarter check
        return_components = common_name in (b"PEP3 investigator",
                                            b"PEP3 researcher")

        response = pep3_pb2.EnrollmentResponse()

        # 1. [ REMOVED ]

        # 2. set response.by_shard[...].private_local_keys
        #        and response.components[...].keys

        # "e" is the exponent used to compute the key/pseudonym components
        e = ed25519.scalar_unpack(common.sha256(common_name))

        for shard, shard_secrets in self.pep.secrets.by_shard.items():
            for domain, domain_secrets in shard_secrets.by_domain.items():
                x = ed25519.scalar_unpack(domain_secrets.private_master_key)
                k = ed25519.scalar_unpack(domain_secrets.key_component_secret)

                k_local = pow(k, e, ed25519.l)

                if return_components:
                    self.pep._cryptopu.certified_component_create(
                        response.components[shard].keys[domain],
                        (self.pep.global_config
                             .components[shard].keys[domain]
                             .base_times_two_to_the_power_of),
                        k, e)

                x_local = (k_local * x) % ed25519.l
                local_keys = response.by_shard[shard].private_local_keys
                local_keys[domain] = ed25519.scalar_pack(x_local)

        # 3. set response.components[...].pseudonym

        if not return_components:
            return response

        for shard in self.pep.config.shards:
            shard_secrets = self.pep.secrets.by_shard[shard]
            s = ed25519.scalar_unpack(
                shard_secrets.pseudonym_component_secret)

            self.pep._cryptopu.certified_component_create(
                response.components[shard].pseudonym,
                (self.pep.global_config
                     .components[shard].pseudonym
                     .base_times_two_to_the_power_of),
                s, e)

        return response
Exemplo n.º 7
0
    def test_triple_rsk(self):
        """Compare the C ``elgamal_triple_rsk`` with ``elgamal.Triple.rsk``.

        A random triple is run through the C implementation (in place) and
        through the Python implementation with the same scalars ``k``, ``s``
        and ``r``; all three components of the results must be equal.
        """
        ffi, lib = ristretto.ffi, ristretto.lib

        m = ed25519.Point.random()
        y = ed25519.Point.random()
        r = ed25519.scalar_random()
        s = ed25519.scalar_random()
        k = ed25519.scalar_random()
        r2 = ed25519.scalar_random()

        triple = elgamal.encrypt(m, y, r2)

        # hand the triple to C
        ctriple = ffi.new("elgamal_triple*")
        packed_triple = (triple.blinding.pack()
                         + triple.core.pack()
                         + triple.target.pack())
        lib.elgamal_triple_unpack(ctriple, packed_triple)

        # hand the scalars to C
        c_scalars = []
        for scalar in (k, s, r):
            c_scalar = ffi.new("group_scalar*")
            lib.group_scalar_unpack(c_scalar, ed25519.scalar_pack(scalar))
            c_scalars.append(c_scalar)
        ck, cs, cr = c_scalars

        # input and output are the same triple
        lib.elgamal_triple_rsk(ctriple, ctriple, ck, cs, cr)

        cbuf = ffi.new("unsigned char []", 96)
        lib.elgamal_triple_pack(cbuf, ctriple)
        buf = ffi.buffer(cbuf)

        blinding, core, target = (
            ed25519.Point.unpack(buf[offset:offset + 32])
            for offset in (0, 32, 64))
        triple2 = elgamal.Triple(blinding, core, target)

        triple = triple.rsk(k, s, r)

        self.assertEqual(triple.blinding, triple2.blinding)
        self.assertEqual(triple.core, triple2.core)
        self.assertEqual(triple.target, triple2.target)
Exemplo n.º 8
0
    def component_public_part(self, scalar):
        """Return the public part of a component secret as 253 packed points.

        Calls ``ristretto.lib.component_public_part`` for ``scalar`` and
        packs the 253 resulting group elements.

        Args:
            scalar: the secret scalar, as accepted by ``ed25519.scalar_pack``.

        Returns:
            A list of 253 ``bytes`` objects of 32 bytes each.

        Raises:
            AssertionError: if ``scalar`` can't be unpacked by the C library.
        """
        ffi, lib = ristretto.ffi, ristretto.lib
        count = 253

        cy = ffi.new("group_ge[]", count)
        cx = ffi.new("group_scalar*")

        # The unpack call is kept outside the ``assert`` expression: under
        # ``python -O`` asserts are removed, and ``cx`` would never be filled.
        status = lib.group_scalar_unpack(cx, ed25519.scalar_pack(scalar))
        assert status == 0

        lib.component_public_part(cy, cx)

        cbuf = ffi.new("unsigned char[]", 32 * count)
        lib.group_ges_pack(cbuf, cy, count)

        return [bytes(cbuf[i * 32:(i + 1) * 32]) for i in range(count)]
Exemplo n.º 9
0
    def certified_component_is_valid_for(self, cc_protobuf, base_powers,
                                         exponent):
        """Check a certified component against base powers and an exponent.

        Copies ``cc_protobuf`` into the C ``certified_component`` structure
        and asks ``ristretto.lib.certified_component_is_valid_for`` for the
        verdict.

        Args:
            cc_protobuf: message holding ``component`` and a
                ``product_proof`` with ``dht_proofs`` (96 bytes each) and
                ``partial_products`` (32 bytes each).
            base_powers: 253 packed points of 32 bytes each.
            exponent: scalar exponent; the number of its ones (as yielded by
                ``schnorr.ones_of``) fixes the expected size of the proof.

        Returns:
            ``True`` if the C library accepts the component.  ``False`` if it
            doesn't, or if the proof has the wrong number of items or an item
            of the wrong size.
        """
        ffi, lib = ristretto.ffi, ristretto.lib

        ccc = ffi.new("certified_component*")
        ccc.component = cc_protobuf.component

        N = sum(1 for _ in schnorr.ones_of(exponent))
        dht_proof_count = max(N - 1, 0)
        partial_product_count = max(N - 2, 0)

        # The arrays are bound to local names so that the memory the struct
        # points at stays alive until the C function has been called.
        cdht_proofs = ffi.new("unsigned char[]", 96 * dht_proof_count)
        cpartial_products = ffi.new("unsigned char[]",
                                    32 * partial_product_count)

        ccc.product_proof.dht_proofs = cdht_proofs
        ccc.product_proof.partial_products = cpartial_products
        ccc.product_proof.number_of_factors = N

        dht_proofs_buffer = ffi.buffer(ccc.product_proof.dht_proofs,
                                       96 * dht_proof_count)
        partial_products_buffer = ffi.buffer(
            ccc.product_proof.partial_products, 32 * partial_product_count)

        proof = cc_protobuf.product_proof
        if len(proof.dht_proofs) != dht_proof_count:
            return False
        if len(proof.partial_products) != partial_product_count:
            return False

        # A wrongly sized item makes the proof invalid; without these checks
        # the slice assignment below would raise instead.
        for i, data in enumerate(proof.dht_proofs):
            if len(data) != 96:
                return False
            dht_proofs_buffer[96 * i:96 * (i + 1)] = data

        for i, data in enumerate(proof.partial_products):
            if len(data) != 32:
                return False
            partial_products_buffer[32 * i:32 * (i + 1)] = data

        cbase_powers = ffi.new("unsigned char[]", 32 * 253)
        for i in range(253):
            cbase_powers[i * 32:(i + 1) * 32] = base_powers[i]

        cexponent = ffi.new("group_scalar*")
        lib.group_scalar_unpack(cexponent, ed25519.scalar_pack(exponent))

        return lib.certified_component_is_valid_for(
            ccc, cbase_powers, cexponent) != 0
Exemplo n.º 10
0
def _generate_certificate(common_name, issuer=None, extensions=()):
    """Generate a 1024-bit RSA key and an X509 certificate for it.

    Args:
        common_name: value for the CN of the certificate's subject.
        issuer: ``(certificate, private_key)`` of the signer, or ``None``
            for a self-signed certificate.
        extensions: X509 extensions to add to the certificate.

    Returns:
        The tuple ``(private_key, certificate)``.
    """
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 1024)

    crt = crypto.X509()
    crt.get_subject().CN = common_name
    crt.set_serial_number(1)
    crt.gmtime_adj_notBefore(0)
    # NOTE(review): this is 356 days; 365 may have been intended -- confirm.
    crt.gmtime_adj_notAfter(356*24*60*60)

    if issuer is None:
        issuer_crt, issuer_key = crt, key
    else:
        issuer_crt, issuer_key = issuer

    crt.set_issuer(issuer_crt.get_subject())
    if extensions:
        crt.add_extensions(list(extensions))
    crt.set_pubkey(key)
    crt.sign(issuer_key, 'sha256')

    return key, crt


def _fill_warrant(warrant, signing_key, target, encrypt_for, source, actor):
    """Set the four fields of ``warrant.act`` and sign the serialized act."""
    act = warrant.act
    act.target = target
    act.encrypt_for = encrypt_for
    act.source = source
    act.actor = actor

    warrant.signature = crypto.sign(
            signing_key, act.SerializeToString(), 'sha256')


def fill_local_config_messages(config, secrets):
    """Fill ``config`` and ``secrets`` with a complete localhost setup.

    Generates, in this order: five peers and one shard per combination of
    three peers (with pseudonym and key component secrets and their public
    parts), a TLS root certificate and one certificate per server (listening
    on localhost, ports counting up from 1234), the settings for an
    in-memory sqlite database, a warrant root certificate with the signed
    warrants of the collector, researcher and investigator, and the
    description of the 'peped_flows' table.

    Args:
        config: configuration message; modified in place.
        secrets: secrets message; modified in place.
    """
    config.domains.append("pseudonym")
    pu = cryptopu.CryptoPU()

    # configure peers
    peer_names = ('A', 'B', 'C', 'D', 'E')

    for letter in peer_names:
        config.peers.get_or_create(letter)
        secrets.peers.get_or_create(letter)
        secrets.peers[letter].reminders_hmac_secret = os.urandom(32)

    # NOTE(review): the products accumulated here are not read anywhere in
    # this function.
    private_keys = {domain: 1 for domain in config.domains}

    for members in itertools.combinations(peer_names, 3):
        shard = "".join(members)  # ABC, ADE, ...
        config.shards.append(shard)

        for member in members:
            config.peers[member].shards.append(shard)

        # the three peers of a shard share the same secrets
        s_packed = os.urandom(32)
        for member in members:
            secrets.peers[member].by_shard[shard]\
                    .pseudonym_component_secret = s_packed
        s = ed25519.scalar_unpack(s_packed)

        config.components[shard].pseudonym.base_times_two_to_the_power_of\
                .extend(pu.component_public_part(s))

        for domain in config.domains:
            private_key_part = ed25519.scalar_random()
            private_key_part_packed = ed25519.scalar_pack(private_key_part)

            for member in members:
                secrets.peers[member].by_shard[shard].by_domain[domain]\
                        .private_master_key = private_key_part_packed

            private_keys[domain] *= private_key_part
            private_keys[domain] %= ed25519.l

            k_packed = os.urandom(32)
            for member in members:
                secrets.peers[member].by_shard[shard].by_domain[domain]\
                        .key_component_secret = k_packed
            k = ed25519.scalar_unpack(k_packed)

            config.components[shard].keys[domain]\
                    .base_times_two_to_the_power_of\
                    .extend(pu.component_public_part(k))

    # generate certificates and port numbers for the servers
    root_key, root_crt = _generate_certificate("PEP3 TLS Root")
    secrets.root_certificate_keys.tls = crypto.dump_privatekey(
            crypto.FILETYPE_PEM, root_key)
    config.root_certificates.tls = crypto.dump_certificate(
            crypto.FILETYPE_PEM, root_crt)

    port = 1234
    number_of_cpus = multiprocessing.cpu_count()

    for server_type_name, server_type in SERVER_TYPES.items():
        # a singleton server is stored under its own name, the others in a
        # map named after the plural of the server type
        if server_type.is_singleton:
            server_configs = {None: getattr(config, server_type_name)}
        else:
            server_configs = getattr(config, server_type_name + "s")

        for name, server_config in server_configs.items():
            server_config.number_of_threads = number_of_cpus

            # set address and port
            location = f"localhost:{port}"
            server_config.location.address = location
            server_config.location.listen_address = location
            port += 1

            # set tls certificate
            server_key, server_crt = _generate_certificate(
                    "PEP3 " + server_type_name,
                    issuer=(root_crt, root_key),
                    extensions=[crypto.X509Extension(
                        b"subjectAltName", False, b"DNS:localhost")])

            server_config.location.tls_certificate = crypto.dump_certificate(
                    crypto.FILETYPE_PEM, server_crt)

            if name is None:
                server_secrets = getattr(secrets, server_type_name)
            else:
                server_secrets = getattr(secrets, server_type_name + "s")[name]

            server_secrets.tls_certificate_key = crypto.dump_privatekey(
                    crypto.FILETYPE_PEM, server_key)

    # The following uri makes sqlalchemy.create_engine use sqlite's :memory:
    # in-memory database.
    config.database.engine.uri = "sqlite://"
    config.database.engine.connect_args['check_same_thread'] = False
    config.database.engine.poolclass = 'StaticPool'
    config.database.engine.create_tables = True
    config.database.number_of_threads = 1

    # generate keys for warrants
    warrant_key, warrant_crt = _generate_certificate("PEP3 Warrant Root")
    secrets.root_certificate_keys.warrants = crypto.dump_privatekey(
            crypto.FILETYPE_PEM, warrant_key)
    config.root_certificates.warrants = crypto.dump_certificate(
            crypto.FILETYPE_PEM, warrant_crt)

    # for Collector.Store
    _fill_warrant(config.collector.warrants.to_sf, warrant_key,
            target=b"PEP3 storage_facility",
            encrypt_for=b"PEP3 storage_facility",
            source=b"plaintext",
            actor=b"PEP3 collector")

    # for Researcher.Query
    _fill_warrant(config.researcher.warrants.from_me_to_sf, warrant_key,
            target=b"PEP3 storage_facility",
            encrypt_for=b"PEP3 storage_facility",
            source=b"PEP3 researcher",
            actor=b"PEP3 researcher")

    _fill_warrant(config.researcher.warrants.from_sf_to_me, warrant_key,
            target=b"PEP3 researcher",
            encrypt_for=b"PEP3 researcher",
            source=b"PEP3 storage_facility",
            actor=b"PEP3 researcher")

    # for Investigator.Query
    _fill_warrant(config.investigator.warrants.from_me_to_sf, warrant_key,
            target=b"PEP3 storage_facility",
            encrypt_for=b"PEP3 storage_facility",
            source=b"PEP3 investigator",
            actor=b"PEP3 investigator")

    _fill_warrant(config.investigator.warrants.from_sf_to_me, warrant_key,
            target=b"PEP3 investigator",
            encrypt_for=b"PEP3 investigator",
            source=b"PEP3 storage_facility",
            actor=b"PEP3 investigator")

    # describe tables used by the database
    columns = config.db_desc['peped_flows'].columns
    columns['p_src_ip'] = 'pseudonymized'
    columns['p_dst_ip'] = 'pseudonymized'

    for name in ('start_time', 'end_time', 'src_port', 'dst_port',
            'protocol', 'packets', 'bytes'):
        columns[name] = 'plain'

    config.batchsize = 1024
Exemplo n.º 11
0
 def pack(self):
     """Return the packed R_M, the packed R_B and the packed scalar s,
     concatenated in that order."""
     packed_points = self._R_M_packed + self._R_B_packed
     return packed_points + ed25519.scalar_pack(self._s)
Exemplo n.º 12
0
    def test_triples_rsk(self):
        """Compare the C ``elgamal_triples_rsk`` with ``elgamal.Triple.rsk``.

        Five random triples sharing one target are run through the C batch
        implementation (in place) and one by one through the Python
        implementation; the packed blinding, core and target must match.
        """
        ffi, lib = ristretto.ffi, ristretto.lib
        n = 5

        target = ed25519.Point.random()

        triples = []
        for _ in range(n):
            triples.append(elgamal.encrypt(
                ed25519.Point.random(),
                target,
                ed25519.scalar_random()))

        # hand the triples to C: blindings and cores as arrays, one target
        cblindings = ffi.new("group_ge[]", n)
        ccores = ffi.new("group_ge[]", n)
        ctarget = ffi.new("group_ge*")

        for i, original in enumerate(triples):
            self.assertEqual(0, lib.group_ge_unpack(
                ffi.addressof(cblindings, i), original.blinding.pack()))
            self.assertEqual(0, lib.group_ge_unpack(
                ffi.addressof(ccores, i), original.core.pack()))

        self.assertEqual(0, lib.group_ge_unpack(ctarget, target.pack()))

        s = ed25519.scalar_random()
        k = ed25519.scalar_random()
        rs = [ed25519.scalar_random() for _ in range(n)]

        cs = ffi.new("group_scalar*")
        ck = ffi.new("group_scalar*")
        crs = ffi.new("group_scalar[]", n)

        self.assertEqual(
            0, lib.group_scalar_unpack(cs, ed25519.scalar_pack(s)))
        self.assertEqual(
            0, lib.group_scalar_unpack(ck, ed25519.scalar_pack(k)))

        for i, scalar in enumerate(rs):
            self.assertEqual(0, lib.group_scalar_unpack(
                ffi.addressof(crs, i), ed25519.scalar_pack(scalar)))

        # inputs and outputs are the same arrays
        lib.elgamal_triples_rsk(
            cblindings, ccores, cblindings, ccores,
            ctarget, ctarget, ck, cs, crs, n)

        cbuf = ffi.new("unsigned char[]", 32)
        buf = ffi.buffer(cbuf)

        for i in range(n):
            expected = triples[i].rsk(k, s, rs[i])

            lib.group_ge_pack(cbuf, ffi.addressof(cblindings, i))
            self.assertEqual(buf[:], expected.blinding.pack())

            lib.group_ge_pack(cbuf, ffi.addressof(ccores, i))
            self.assertEqual(buf[:], expected.core.pack())

            lib.group_ge_pack(cbuf, ctarget)
            self.assertEqual(buf[:], expected.target.pack())
Exemplo n.º 13
0
def scalar_to_c(scalar):
    """Return a newly allocated C ``group_scalar*`` holding ``scalar``.

    Raises:
        AssertionError: if the C library can't unpack the packed scalar.
    """
    result = ristretto.ffi.new("group_scalar*")
    # The unpack call is kept outside the ``assert`` expression: under
    # ``python -O`` asserts are removed, and ``result`` would never be filled.
    status = ristretto.lib.group_scalar_unpack(
            result, ed25519.scalar_pack(scalar))
    assert status == 0
    return result
Exemplo n.º 14
0
    def test_product_proof_create(self):
        """Compare the C product proof with ``schnorr.ProductProof``.

        Creates a product proof for N random factors with the C library,
        checks that its DHT proofs and partial products equal those of the
        Python implementation, and that the C verifier accepts it.  The
        verifier must then reject a wrong product, a wrong factor, a
        corrupted partial product and a corrupted DHT proof; every
        corruption is undone before the next one is applied.
        """
        N = 10

        factors_scalar = [ ed25519.scalar_random() for i in range(N) ]
        product_scalar = 1

        # The factors are handed to C in three forms: packed scalars,
        # packed points and (further below) unpacked points and scalars.
        cfactors_scalar_packed = ristretto.ffi.new("unsigned char[]", 32*N)
        cfactors_packed = ristretto.ffi.new("unsigned char[]", 32*N)
        cfactors = ristretto.ffi.new("group_ge[]", N)
        cfactors_scalar_packed_buf \
                = ristretto.ffi.buffer(cfactors_scalar_packed)
        cfactors_packed_buf  = ristretto.ffi.buffer(cfactors_packed)

        for i in range(N):
            factor_scalar = factors_scalar[i]
            # product of all factors as a plain integer (not reduced here)
            product_scalar *= factor_scalar

            factor_scalar_packed = ed25519.scalar_pack(factor_scalar)
            cfactors_scalar_packed_buf[32*i:32*(i+1)] = factor_scalar_packed

            factor = ed25519.Point.B_times(factor_scalar)
            cfactors_packed_buf[32*i:32*(i+1)] = factor.pack()

        # NOTE(review): the error codes written by the two batch unpack
        # calls below are never inspected.
        cerror_codes = ristretto.ffi.new("int[]", N)
        ristretto.lib.group_ges_unpack(cfactors, cfactors_packed, 
                cerror_codes, N)

        cfactors_scalar = ristretto.ffi.new("group_scalar[]", N)

        ristretto.lib.group_scalars_unpack(cfactors_scalar,
                cfactors_scalar_packed, cerror_codes, N)

        product = ed25519.Point.B_times(product_scalar)
        cproduct = point_to_c(product)
        
        cpp = ristretto.ffi.new("product_proof*")
        cpp.number_of_factors = N

        cpartial_products = ristretto.ffi.new("unsigned char[]", 32*max(N-2,0))
        cpp.partial_products = cpartial_products
        # If we don't use the variable "cpartial_products" to keep
        # the python object returned by ffi.new(...) alive, the memory
        # will be freed immediately (and possibly overwritten in the next
        # lines, leading to strange bugs).

        cdht_proofs = ristretto.ffi.new("unsigned char[]", 96*max(N-1,0))
        cpp.dht_proofs = cdht_proofs

        ristretto.lib.product_proof_create(cpp, cfactors_scalar,
                cfactors_scalar_packed, cfactors_packed)

        # reference proof from the Python implementation
        pp, _, _ = schnorr.ProductProof.create(factors_scalar)
        
        dht_proofs_buf = ristretto.ffi.buffer(cpp.dht_proofs, 96*max(N-1,0))
        partial_products_buf = ristretto.ffi.buffer(
                cpp.partial_products, 32*max(N-2,0))

        # N-1 DHT proofs of 96 bytes, N-2 partial products of 32 bytes
        for i in range(max(N-1,0)):
            self.assertEqual(pp._dht_proofs[i].pack(),
                    dht_proofs_buf[96*i:96*(i+1)])
            if i<N-2:
                self.assertEqual(pp._partial_products[i].pack(),
                        partial_products_buf[32*i:32*(i+1)])

        self.assertTrue(ristretto.lib.product_proof_is_valid_for(cpp,
                cfactors, cfactors_packed, cproduct, product.pack()))


        # wrong product
        R = ed25519.Point.random()
        R_packed = R.pack()
        cR = point_to_c(R)

        self.assertFalse(ristretto.lib.product_proof_is_valid_for(cpp,
                cfactors, cfactors_packed, cR, R_packed))


        # wrong factor
        i = random.randint(0,N-1)

        # save the i-th factor (unpacked and packed) so it can be restored
        cT = ristretto.ffi.new("group_ge*")
        cT[0] = cfactors[i]
        T_packed = cfactors_packed_buf[i*32:(i+1)*32]

        cfactors[i] = cR[0]
        cfactors_packed_buf[i*32:(i+1)*32] = R_packed

        self.assertFalse(ristretto.lib.product_proof_is_valid_for(cpp,
                cfactors, cfactors_packed, cproduct, product.pack()))

        # restore the original factor
        cfactors[i] = cT[0]
        cfactors_packed_buf[i*32:(i+1)*32] = T_packed
        
        # wrong partial product
        i = random.randint(0,N-3)

        T_packed = partial_products_buf[i*32:(i+1)*32]
        partial_products_buf[i*32:(i+1)*32] = os.urandom(32)
        self.assertFalse(ristretto.lib.product_proof_is_valid_for(cpp,
                cfactors, cfactors_packed, cproduct, product.pack()))
        partial_products_buf[i*32:(i+1)*32] = T_packed

        # wrong dht proof
        i = random.randint(0,N-2)

        T_packed = dht_proofs_buf[i*96:(i+1)*96]
        dht_proofs_buf[i*96:(i+1)*96] = os.urandom(96)
        self.assertFalse(ristretto.lib.product_proof_is_valid_for(cpp,
                cfactors, cfactors_packed, cproduct, product.pack()))
        dht_proofs_buf[i*96:(i+1)*96] = T_packed
Exemplo n.º 15
0
    def rsk(self, pseudonyms, k, s, rs):
        """Apply the C ``elgamal_triples_rsk`` to the pseudonyms in place.

        The ``data`` of every pseudonym is read as three packed group
        elements of 32 bytes each: blinding, core and target.  All triples
        must share the same, non-zero target.  After the batch operation
        ``data`` is replaced by the packed result.

        Args:
            pseudonyms: sequence of objects with a ``data`` attribute holding
                a packed triple.
            k: scalar passed to the C function as ``k``.
            s: scalar passed to the C function as ``s``.
            rs: one scalar per pseudonym.

        Raises:
            InvalidArgument: if a blinding, a core or the target can't be
                unpacked, if the targets are not all the same, or if the
                target is zero.
            AssertionError: if ``rs`` has the wrong length or one of the
                scalars can't be unpacked.
        """
        ffi, lib = ristretto.ffi, ristretto.lib

        n = len(pseudonyms)
        assert (n == len(rs))

        if n == 0:
            return

        # fill cs and ck
        cs = ffi.new("group_scalar*")
        ck = ffi.new("group_scalar*")

        # The unpack calls are kept outside the ``assert`` expressions:
        # under ``python -O`` asserts are removed, and the calls that fill
        # ``cs`` and ``ck`` would be removed along with them.
        s_status = lib.group_scalar_unpack(cs, ed25519.scalar_pack(s))
        assert s_status == 0

        k_status = lib.group_scalar_unpack(ck, ed25519.scalar_pack(k))
        assert k_status == 0

        # fill crs
        crs = ffi.new("group_scalar[]", n)
        cerror_codes = ffi.new("int[]", n)

        lib.group_scalars_unpack(
            crs, b''.join([ed25519.scalar_pack(rs[i]) for i in range(n)]),
            cerror_codes, n)

        for i in range(n):
            assert (cerror_codes[i] == 0)

        ccores = ffi.new("group_ge[]", n)
        cblindings = ffi.new("group_ge[]", n)
        ctarget = ffi.new("group_ge*")

        # scratch buffer, used first for the blindings and then for the cores
        cbuf = ffi.new("unsigned char[]", 32 * n)

        for i in range(n):
            ffi.memmove(cbuf + 32 * i, pseudonyms[i].data[0:32], 32)

        lib.group_ges_unpack(cblindings, cbuf, cerror_codes, n)

        for i in range(n):
            if cerror_codes[i] != 0:
                raise InvalidArgument(f"couldn't unpack the {i}th triples' "
                                      "blinding")

        for i in range(n):
            ffi.memmove(cbuf + 32 * i, pseudonyms[i].data[32:64], 32)

        lib.group_ges_unpack(ccores, cbuf, cerror_codes, n)

        for i in range(n):
            if cerror_codes[i] != 0:
                raise InvalidArgument(f"couldn't unpack the {i}th triples' "
                                      "core")

        # the target is unpacked only once, so all triples must agree on it
        packed_target = pseudonyms[0].data[64:96]
        for i in range(1, n):
            if pseudonyms[i].data[64:96] != packed_target:
                raise InvalidArgument(
                    "the triples' targets are not all "
                    f"the same; the {i}th triple's target differs "
                    "from the target of the first triple")

        if lib.group_ge_unpack(ctarget, packed_target) != 0:
            raise InvalidArgument("couldn't unpack the target of the first "
                                  "triple")

        if lib.group_ge_isneutral(ctarget) != 0:
            # ctarget==0
            raise InvalidArgument("target can't be zero")

        # inputs and outputs are the same arrays
        lib.elgamal_triples_rsk(cblindings, ccores, cblindings, ccores,
                                ctarget, ctarget, ck, cs, crs, n)

        cblindingbuf = ffi.new("unsigned char[]", 32 * n)
        blindingbuf = ffi.buffer(cblindingbuf)
        ccorebuf = ffi.new("unsigned char[]", 32 * n)
        corebuf = ffi.buffer(ccorebuf)
        ctargetbuf = ffi.new("unsigned char[]", 32)
        targetbuf = ffi.buffer(ctargetbuf)

        lib.group_ge_pack(ctargetbuf, ctarget)

        lib.group_ges_pack(cblindingbuf, cblindings, n)
        lib.group_ges_pack(ccorebuf, ccores, n)

        for i in range(n):
            pseudonyms[i].data = b''.join(
                (blindingbuf[32 * i:32 * (i + 1)],
                 corebuf[32 * i:32 * (i + 1)], targetbuf[:]))