class Intersection(NonInteractiveSetIntersection):
    """Two-client set intersection scheme.

    Both clients share a PRF key ``msk`` and hold complementary exponent
    shares ``sigma`` and ``1 - sigma``.  Each client encrypts its set with
    ``encrypt``; an evaluator holding both clients' ciphertext dicts for the
    same ``gid`` recovers the common plaintexts with ``eval``.
    """

    client_count = 2

    def __init__(self, curve=prime256v1):
        """Create the scheme over the elliptic-curve group ``curve``.

        A random group element is drawn as the base ``g`` used by ``_phi``.
        """
        super().__init__()
        self.group = ECGroup(curve)
        self.g = self.group.random(G)

    def setup(self, secpar):
        """Generate the two clients' secret keys.

        :param secpar: security parameter in bits.  ``secpar // 8`` random
            bytes become the AES key ``msk``, so it should be 128, 192 or 256.
        :returns: ``((msk, sigma), (msk, 1 - sigma))``, i.e. both clients share
            ``msk`` and hold complementary exponent shares.
        """
        self.secpar = secpar
        sigma = self.group.random(ZR)
        msk = os.urandom(secpar // 8)
        return ((msk, sigma), (msk, 1 - sigma))

    def _phi(self, cipher, pt):
        """PRF mapping the byte string ``pt`` to a group element.

        ``pt`` is padded to a multiple of the AES block size and encrypted
        with ``cipher`` (AES-CBC under ``msk`` with ``gid`` as IV, as built by
        ``encrypt``).  The ciphertext, read as a big-endian integer reduced
        modulo the group order, is used as the exponent of ``self.g``.
        """
        # PKCS#7 padding: 1..16 bytes, each equal to the pad length.  Padding
        # with zero bytes is not injective (b'a' and b'a\x00' would pad to the
        # same block), which made distinct elements collide in the
        # intersection.
        padding_len = 16 - (len(pt) % 16)
        padded = pt + bytes([padding_len]) * padding_len
        encryptor = cipher.encryptor()
        ct = encryptor.update(padded) + encryptor.finalize()
        exponent = int.from_bytes(ct, 'big') % self.group.order()
        return self.g**exponent

    def _prefixed_hash(self, prefix, g):
        """Return SHA-256(``prefix`` || serialize(``g``)) for domain separation."""
        return hashlib.sha256(prefix + self.group.serialize(g)).digest()

    def _H(self, g):
        """Map group element ``g`` to 16 bytes.

        Used to derive the AE (AES-GCM) key from ``g``.
        """
        return self._prefixed_hash(b'\x00', g)[:16]

    def _H_bytes(self, g):
        """Map group element ``g`` to 32 bytes.

        Used to derive the AE nonce from ``g``.
        """
        return self._prefixed_hash(b'\x01', g)

    def _H_tag(self, g):
        """Map group element ``g`` to a public 32-byte lookup tag.

        Uses its own domain-separation prefix, so (treating SHA-256 as a
        random oracle) the tag does not reveal the AE key derived by ``_H``.
        """
        return self._prefixed_hash(b'\x02', g)

    def encrypt(self, usk, gid, pt_set):
        """Encrypt a plaintext set under a group identifier.

        :param usk: a client key ``(msk, sigma)`` as returned by ``setup``.
        :param gid: used directly as the AES-CBC IV, so it must be 16 bytes.
        :param pt_set: iterable of ``bytes`` plaintexts.
        :returns: dict mapping a public lookup tag to ``(ct1, ct2)``, where
            ``ct1`` is the serialized ``k**sigma`` and ``ct2`` is
            ``(nonce, AES-GCM ciphertext of pt)``.  The same plaintext
            encrypted by both clients under the same ``gid`` gets the same
            tag.
        """
        msk, sigma = usk
        iv = gid
        cipher = Cipher(algorithms.AES(msk), modes.CBC(iv),
                        backend=default_backend())
        ct_dict = {}
        for pt in pt_set:
            k = self._phi(cipher, pt)
            ct1 = self.group.serialize(k**sigma)
            # Deterministic authenticated encryption: key and nonce are both
            # derived from k, which only the evaluator holding both clients'
            # shares can reconstruct.
            ae_key = self._H(k)
            ae_nonce = self._H_bytes(k)[:12]
            ct2 = (ae_nonce, AESGCM(ae_key).encrypt(ae_nonce, pt, None))
            # Index by a separate tag, never by ae_key itself: publishing the
            # AE key would let anyone decrypt every element of this set
            # without the other client's share.
            ct_dict[self._H_tag(k)] = (ct1, ct2)
        return ct_dict

    def eval(self, ct_sets):
        """Compute the intersection of two clients' encrypted sets.

        :param ct_sets: sequence of two dicts as returned by ``encrypt``
            (one per client, same ``gid``).
        :returns: set of the plaintexts present in both encrypted sets.
        :raises cryptography.exceptions.InvalidTag: if a matched entry fails
            to decrypt, e.g. the sigma shares are not complementary or a
            ciphertext was tampered with.
        """
        first, second = ct_sets[0], ct_sets[1]
        pt_intersection = set()
        for tag in first.keys() & second.keys():
            g1 = self.group.deserialize(first[tag][0])
            g2 = self.group.deserialize(second[tag][0])
            # k**sigma * k**(1 - sigma) == k, the element both clients derived.
            key = g1 * g2
            # decrypt using the first client's AE ciphertext
            ae_nonce, ct = first[tag][1]
            ae = AESGCM(self._H(key))
            pt_intersection.add(ae.decrypt(ae_nonce, ct, None))
        return pt_intersection