class EcdhChaskeyCrypto(EcdhBase):
    """Crypto backend 'ecdh-chaskey'.

    The key is derived via ``_ecdh_sha256``; the cipher itself is not computed in Python.
    Both directions run the ``zkay.ChaskeyLtsCbc`` class from ``circuit_builder_jar`` in a
    ``java`` subprocess and parse the last line of its output as a hexadecimal number.
    """

    params = CryptoParams('ecdh-chaskey')

    @staticmethod
    def _chaskey_cli(mode: str, key: bytes, iv: bytes, data: bytes) -> int:
        """Invoke ChaskeyLtsCbc with the given mode ('enc' or 'dec') and return its result as int."""
        command = [
            'java', '-Xms4096m', '-Xmx16384m', '-cp', f'{circuit_builder_jar}',
            'zkay.ChaskeyLtsCbc', mode, key.hex(), iv.hex(), data.hex(),
        ]
        output, _ = run_command(command)
        # Only the final output line is interpreted as the hex-encoded result.
        return int(output.splitlines()[-1], 16)

    def _enc(self, plain: int, my_sk: int, target_pk: int) -> Tuple[List[int], None]:
        """
        Encrypt plain (as 32 big-endian bytes) under the key shared between my_sk and target_pk.

        :return: (iv followed by the 32 cipher bytes, packed into chunks; None)
        """
        shared_key = self._ecdh_sha256(target_pk, my_sk)
        message = plain.to_bytes(32, byteorder='big')

        # Fresh 16-byte iv per encryption; it is stored in front of the cipher bytes.
        iv = secrets.token_bytes(16)
        encrypted = self._chaskey_cli('enc', shared_key, iv, message)
        iv_and_cipher = iv + encrypted.to_bytes(32, byteorder='big')
        return self.pack_byte_array(iv_and_cipher, self.params.cipher_chunk_size), None

    def _dec(self, cipher: Tuple[int, ...], my_sk: Any) -> Tuple[int, None]:
        """
        Decrypt a cipher whose last element is the sender's public key.

        :return: (plaintext as int, None)
        """
        # The trailing element is the sender public key, everything before it is the payload.
        sender_pk, payload = cipher[-1], cipher[:-1]
        assert len(payload) == self.params.cipher_payload_len

        shared_key = self._ecdh_sha256(sender_pk, my_sk)

        raw = self.unpack_to_byte_array(payload, self.params.cipher_chunk_size,
                                        self.params.cipher_bytes_payload)
        iv, cipher_bytes = raw[:16], raw[16:]
        return self._chaskey_cli('dec', shared_key, iv, cipher_bytes), None
def __get_decrypted_retval(self, raw_value, is_cipher, crypto_params_name, constructor):
    """
    Convert a raw return value into its wrapped form.

    :param raw_value: the value to convert
    :param is_cipher: if true, raw_value is wrapped in a CipherValue and decrypted via self.dec
    :param crypto_params_name: backend name used for the CryptoParams and passed as crypto_backend
    :param constructor: applied directly to raw_value when it is not a cipher, otherwise handed to self.dec
    :return: constructor(raw_value), or the first element of the result of self.dec
    """
    if not is_cipher:
        return constructor(raw_value)

    wrapped = CipherValue(raw_value, params=CryptoParams(crypto_params_name))
    decrypted = self.dec(wrapped, constructor, crypto_backend=crypto_params_name)
    return decrypted[0]
class RSAOAEPCrypto(RSACrypto):
    """Crypto backend 'rsa-oaep': RSA with PKCS1_OAEP padding and SHA256 as hash algorithm.

    The library calls are wrapped in PersistentLocals so that a value can be read back from
    the wrapper's ``locals`` mapping after the call and returned as the randomness.
    """

    params = CryptoParams('rsa-oaep')

    def _enc(self, plain: int, _: int, target_pk: int) -> Tuple[List[int], List[int]]:
        """
        Encrypt plain (as 32 big-endian bytes) for the public modulus target_pk.

        :return: (packed cipher chunks, packed randomness chunks)
        """
        public_key = RSA.construct((target_pk, self.default_exponent))
        oaep = PKCS1_OAEP.new(public_key, hashAlgo=SHA256)
        traced_encrypt = PersistentLocals(oaep.encrypt)

        encrypted = traced_encrypt(plain.to_bytes(32, byteorder='big'))
        cipher_chunks = self.pack_byte_array(encrypted, self.params.cipher_chunk_size)

        # 'ros' is read from the traced call (presumably the padding seed; verify against the library)
        rnd_chunks = self.pack_byte_array(traced_encrypt.locals['ros'], self.params.rnd_chunk_size)
        return cipher_chunks, rnd_chunks

    def _dec(self, cipher: Tuple[int, ...], sk: RSA.RsaKey) -> Tuple[int, List[int]]:
        """
        Decrypt cipher with the private key sk.

        :return: (plaintext as int, packed randomness chunks read from the traced decrypt call)
        """
        oaep = PKCS1_OAEP.new(sk, hashAlgo=SHA256)
        traced_decrypt = PersistentLocals(oaep.decrypt)

        encrypted = self.unpack_to_byte_array(cipher, self.params.cipher_chunk_size,
                                              self.params.cipher_bytes_payload)
        plain = int.from_bytes(traced_decrypt(encrypted), byteorder='big')

        # 'seed' is read from the traced call after decryption finished
        rnd_chunks = self.pack_byte_array(traced_decrypt.locals['seed'], self.params.rnd_chunk_size)
        return plain, rnd_chunks
class RSAPKCS15Crypto(RSACrypto):
    """Crypto backend 'rsa-pkcs1.5': RSA with PKCS1_v1_5 padding.

    The library calls are wrapped in PersistentLocals so that the padding bytes can be read
    back from the wrapper's ``locals`` mapping and returned as the randomness.
    """

    params = CryptoParams('rsa-pkcs1.5')

    def _enc(self, plain: int, _: int, target_pk: int) -> Tuple[List[int], List[int]]:
        """
        Encrypt plain (as 32 big-endian bytes) for the public modulus target_pk.

        :return: (packed cipher chunks, packed randomness chunks)
        """
        public_key = RSA.construct((target_pk, self.default_exponent))
        traced_encrypt = PersistentLocals(PKCS1_v1_5.new(public_key).encrypt)

        encrypted = traced_encrypt(plain.to_bytes(32, byteorder='big'))
        cipher_chunks = self.pack_byte_array(encrypted, self.params.cipher_chunk_size)

        # 'ps' is read from the traced call and must have exactly rnd_bytes bytes
        padding = traced_encrypt.locals['ps']
        assert len(padding) == self.params.rnd_bytes
        return cipher_chunks, self.pack_byte_array(padding, self.params.rnd_chunk_size)

    def _dec(self, cipher: Tuple[int, ...], sk: RSA.RsaKey) -> Tuple[int, List[int]]:
        """
        Decrypt cipher with the private key sk.

        :raise RuntimeError: if the library returns the sentinel (None) instead of a plaintext
        :return: (plaintext as int, packed randomness chunks)
        """
        traced_decrypt = PersistentLocals(PKCS1_v1_5.new(sk).decrypt)
        encrypted = self.unpack_to_byte_array(cipher, self.params.cipher_chunk_size,
                                              self.params.cipher_bytes_payload)

        # None is passed as sentinel, so a None result signals a failed decryption
        decrypted = traced_decrypt(encrypted, None)
        if decrypted is None:
            raise RuntimeError("Tried to decrypt invalid cipher text")
        plain = int.from_bytes(decrypted, byteorder='big')

        # The bytes between offset 2 and the separator position of 'em' are used as randomness
        traced = traced_decrypt.locals
        padding = traced['em'][2:traced['sep']]
        assert len(padding) == self.params.rnd_bytes
        return plain, self.pack_byte_array(padding, self.params.rnd_chunk_size)
def all_crypto_params(self) -> List[CryptoParams]:
    """
    Return CryptoParams for every distinct backend configured for any Homomorphism member.

    Backends are listed in the order in which they are first encountered while iterating
    over Homomorphism; homomorphisms for which _get_crypto_backend yields None are skipped.
    """
    # A dict is used as an insertion-ordered set to drop duplicate backend names.
    seen_backends = {}
    for hom in Homomorphism:
        seen_backends.setdefault(self._get_crypto_backend(hom), None)

    return [CryptoParams(name) for name in seen_backends if name is not None]
def get_params(params: CryptoParams = None, crypto_backend: str = None) -> CryptoParams:
    """
    Select the CryptoParams to use.

    :param params: returned unchanged when given
    :param crypto_backend: backend name used to construct CryptoParams when params is None
    :return: params, else CryptoParams(crypto_backend), else the configured parameters
             for Homomorphism.NON_HOMOMORPHIC
    """
    # The config import is deliberately kept local to the function, as in the original code.
    from zkay.config import cfg

    if params is not None:
        return params
    if crypto_backend is not None:
        return CryptoParams(crypto_backend)
    return cfg.get_crypto_params(Homomorphism.NON_HOMOMORPHIC)
def decl(self, name, constructor: Callable = lambda x: x, *, cipher: bool = False,
         crypto_backend: str = None):
    """
    Define the wrapper constructor for a state variable.

    :param name: name of the state variable; must not have been declared before
    :param constructor: callable stored for this variable (defaults to the identity function)
    :param cipher: flag stored together with the constructor
    :param crypto_backend: backend name used to build the stored CryptoParams;
                           if omitted, cfg.main_crypto_backend is looked up at call time
    """
    if crypto_backend is None:
        # A default written in the signature is evaluated only once, when the function is
        # defined, and would therefore ignore later changes to the configuration.
        crypto_backend = cfg.main_crypto_backend

    assert name not in self.__constructors
    self.__constructors[name] = (cipher, CryptoParams(crypto_backend), constructor)
class DummyHomCrypto(ZkayHomomorphicCryptoInterface):
    """Crypto backend 'dummy-hom'.

    A ciphertext is ``plain * key + 1`` modulo bn128_scalar_field; the same prime is used as
    public and private key. Homomorphic operations work on ``cipher - 1``.
    """

    params = CryptoParams('dummy-hom')

    def _generate_or_load_key_pair(self, address: str) -> KeyPair:
        """Derive the key pair deterministically from address (interpreted as hex number)."""
        rng = Random(int(address, 16))

        def rand_bytes(n: int) -> bytes:
            return bytes(rng.randrange(256) for _ in range(n))

        prime = generate_probable_prime(exact_bits=self.params.key_bits, randfunc=rand_bytes)
        key = int(prime)
        public = PublicKeyValue(self.serialize_pk(key, self.params.key_bytes), params=self.params)
        return KeyPair(public, PrivateKeyValue(key))

    def _enc(self, plain: int, _: int, target_pk: int):
        """Return ([cipher], default randomness) with cipher = plain * target_pk + 1 (mod field)."""
        reduced = plain % bn128_scalar_field  # handle negative values
        cipher = (reduced * target_pk + 1) % bn128_scalar_field
        return [cipher], list(RandomnessValue(params=self.params)[:])

    def _dec(self, cipher: Tuple[int, ...], sk: int) -> Tuple[int, List[int]]:
        """Invert _enc; results above half the field size are mapped to negative numbers."""
        inverse_key = pow(sk, -1, bn128_scalar_field)
        plain = ((cipher[0] - 1) * inverse_key) % bn128_scalar_field
        if plain > bn128_scalar_field // 2:
            plain -= bn128_scalar_field
        return plain, list(RandomnessValue(params=self.params)[:])

    def do_op(self, op: str, public_key: Union[List[int], int],
              *args: Union[CipherValue, int]) -> List[int]:
        """
        Apply op ('sign-', '+', '-' or '*') to the operands.

        :raise ValueError: if op is not one of the supported operations
        :return: single-element list containing the resulting cipher
        """
        def strip_offset(operand: Union[CipherValue, int]) -> int:
            # Plain ints are used as-is; for ciphers the +1 offset is removed (0 stays 0)
            if not isinstance(operand, CipherValue):
                return operand
            first = operand[0]
            if first != 0:
                return first - 1
            return 0

        def combine(values: List[int]) -> int:
            if op == 'sign-':
                return -values[0]
            if op == '+':
                return values[0] + values[1]
            if op == '-':
                return values[0] - values[1]
            if op == '*':
                return values[0] * values[1]
            raise ValueError(f'Unsupported operation {op}')

        operands = [strip_offset(arg) for arg in args]
        return [(combine(operands) + 1) % bn128_scalar_field]

    def do_rerand(self, arg: CipherValue, public_key: List[int]) -> Tuple[List[int], List[int]]:
        """Return arg unchanged together with the randomness [0]."""
        return arg, [0]
def do_rerand(self, arg: CipherValue, crypto_backend: str, target_addr: AddressValue, data: Dict,
              rnd_key: str):
    """
    Re-randomizes arg using fresh randomness, which is stored in data[rnd_key] (side-effect!)

    :param arg: the cipher value to re-randomize
    :param crypto_backend: name of the (homomorphic) crypto backend to use
    :param target_addr: address whose public key is fetched from the keystore
    :param data [SIDE EFFECT]: dictionary which receives the randomness
    :param rnd_key: key under which the randomness is stored in data
    :return: the re-randomized cipher value
    """
    params = CryptoParams(crypto_backend)
    backend_name = params.crypto_name

    pk = self.__keystore[backend_name].getPk(target_addr)
    crypto_inst = self.__crypto[backend_name]
    assert isinstance(crypto_inst, ZkayHomomorphicCryptoInterface)

    rerandomized, randomness = crypto_inst.do_rerand(arg, pk[:])
    data[rnd_key] = RandomnessValue(randomness, params=params)  # store randomness
    return CipherValue(rerandomized, params=params)
def _ensure_encryption(self, stmt: Statement, plain: HybridArgumentIdf, new_privacy: PrivacyLabelExpr,
                       crypto_params: CryptoParams, cipher: HybridArgumentIdf, is_param: bool,
                       is_dec: bool):
    """
    Make sure that cipher = enc(plain, getPk(new_privacy), priv_user_provided_rnd).

    This automatically requests necessary keys and adds a circuit input for the randomness.

    Note: This function adds pre-statements to stmt

    :param stmt [SIDE EFFECT]: the statement which contains the expression which requires this encryption
    :param plain: circuit variable referencing the plaintext value
    :param new_privacy: privacy label corresponding to the destination key address
    :param crypto_params: parameters of the crypto backend; decides between the symmetric and the asymmetric path
    :param cipher: circuit variable referencing the encrypted value
    :param is_param: whether cipher is a function parameter
    :param is_dec: whether this is a decryption operation (user supplied plain) as opposed to an encryption operation (user supplied cipher)
    """
    if not crypto_params.is_symmetric_cipher():
        # Asymmetric backend: the randomness is named after plain (parameters) or cipher (otherwise)
        rnd_owner = plain if is_param else cipher
        rnd = self._secret_input_name_factory.add_idf(f'{rnd_owner.name}_R',
                                                      TypeName.rnd_type(crypto_params))
        pk = self._require_public_key_for_label_at(stmt, new_privacy, crypto_params)
        if not is_dec:
            self.phi.append(CircComment(f'{cipher.name} = enc({plain.name}, {pk.name})'))
        self._phi.append(CircEncConstraint(plain, rnd, pk, cipher, is_dec))
        return

    # Need a different set of keys for hybrid-encryption (ecdh-based) backends
    self._require_secret_key(crypto_params)
    my_pk = self._require_public_key_for_label_at(stmt, Expression.me_expr(), crypto_params)

    if is_dec:
        other_pk = self._get_public_key_in_sender_field(stmt, cipher, crypto_params)
    else:
        if new_privacy == Expression.me_expr():
            other_pk = my_pk
        else:
            other_pk = self._require_public_key_for_label_at(stmt, new_privacy, crypto_params)
        self.phi.append(
            CircComment(f'{cipher.name} = enc({plain.name}, ecdh({other_pk.name}, my_sk))'))

    self._phi.append(CircSymmEncConstraint(plain, other_pk, cipher, is_dec))
def do_homomorphic_op(self, op: str, crypto_backend: str, target_addr: AddressValue,
                      *args: Union[CipherValue, int]):
    """
    Evaluate the homomorphic operation op on args using the given crypto backend.

    :param op: operation, forwarded to the backend's do_op
    :param crypto_backend: name of the (homomorphic) crypto backend to use
    :param target_addr: address whose public key is fetched from the keystore
    :param args: operands; every CipherValue among them must belong to crypto_backend
    :raise ValueError: if a CipherValue operand belongs to a different crypto backend
    :return: the resulting cipher value
    """
    params = CryptoParams(crypto_backend)
    backend_name = params.crypto_name
    pk = self.__keystore[backend_name].getPk(target_addr)

    has_foreign_cipher = any(
        isinstance(arg, CipherValue) and backend_name != arg.params.crypto_name
        for arg in args
    )
    if has_foreign_cipher:
        raise ValueError('CipherValues from different crypto backends used in homomorphic operation')

    crypto_inst = self.__crypto[backend_name]
    assert isinstance(crypto_inst, ZkayHomomorphicCryptoInterface)
    return CipherValue(crypto_inst.do_op(op, pk[:], *args), params=params)
class DummyCrypto(ZkayCryptoInterface):
    """Crypto backend 'dummy'.

    The address (as integer) serves as both public and private key, and a ciphertext is
    ``plain + key`` modulo bn128_scalar_field.
    """

    params = CryptoParams('dummy')

    def _generate_or_load_key_pair(self, address: str) -> KeyPair:
        """Build the key pair directly from address (interpreted as hex number)."""
        key = int(address, 16)
        public = PublicKeyValue(self.serialize_pk(key, self.params.key_bytes), params=self.params)
        return KeyPair(public, PrivateKeyValue(key))

    def _enc(self, plain: int, _: int, target_pk: int):
        """Return (cipher repeated cipher_payload_len times, default randomness)."""
        masked = (plain + target_pk) % bn128_scalar_field
        payload = [masked] * self.params.cipher_payload_len
        return payload, list(RandomnessValue(params=self.params)[:])

    def _dec(self, cipher: Tuple[int, ...], sk: int) -> Tuple[int, List[int]]:
        """Recover the plaintext from the first cipher element only."""
        unmasked = (cipher[0] - sk) % bn128_scalar_field
        return unmasked, list(RandomnessValue(params=self.params)[:])
class EcdhAesCrypto(EcdhBase):
    """Crypto backend 'ecdh-aes': AES in CBC mode with a key derived via ``_ecdh_sha256``.

    The serialized cipher consists of the 16-byte iv followed by the encrypted bytes.
    """

    params = CryptoParams('ecdh-aes')

    def _enc(self, plain: int, my_sk: int, target_pk: int) -> Tuple[List[int], None]:
        """
        Encrypt plain (as 32 big-endian bytes) under the key shared between my_sk and target_pk.

        :return: (iv followed by cipher bytes, packed into chunks; None)
        """
        shared_key = self._ecdh_sha256(target_pk, my_sk)
        message = plain.to_bytes(32, byteorder='big')

        # No iv is passed here; the one exposed by the cipher object is stored with the result
        aes = AES.new(shared_key, AES.MODE_CBC)
        encrypted = aes.encrypt(message)
        return self.pack_byte_array(aes.iv + encrypted, self.params.cipher_chunk_size), None

    def _dec(self, cipher: Tuple[int, ...], my_sk: Any) -> Tuple[int, None]:
        """
        Decrypt a cipher whose last element is the sender's public key.

        :return: (plaintext as int, None)
        """
        # The trailing element is the sender public key, everything before it is the payload
        sender_pk, payload = cipher[-1], cipher[:-1]
        assert len(payload) == self.params.cipher_payload_len

        shared_key = self._ecdh_sha256(sender_pk, my_sk)

        raw = self.unpack_to_byte_array(payload, self.params.cipher_chunk_size,
                                        self.params.cipher_bytes_payload)
        iv, encrypted = raw[:16], raw[16:]

        aes = AES.new(shared_key, AES.MODE_CBC, iv=iv)
        return int.from_bytes(aes.decrypt(encrypted), byteorder='big'), None
class PaillierCrypto(ZkayHomomorphicCryptoInterface):
    """Crypto backend 'paillier' (additively homomorphic).

    The public key is the modulus n = p * q, the private key is the pair (p, q); both are
    handled as lists of integer chunks. Key pairs are cached in files below ``cfg.data_dir``.
    """

    params = CryptoParams('paillier')

    def _generate_or_load_key_pair(self, address: str) -> KeyPair:
        """Load the key pair stored for address, generating and storing a new one if none exists."""
        key_file = os.path.join(cfg.data_dir, 'keys', f'paillier_{self.params.key_bits}_{address}.bin')
        os.makedirs(os.path.dirname(key_file), exist_ok=True)

        if os.path.exists(key_file):
            # Restore saved key pair
            zk_print(f'Paillier secret found, loading from file {key_file}')
            pk, sk = self._read_key_pair(key_file)
        else:
            zk_print('Key pair not found, generating new Paillier secret...')
            pk, sk = self._generate_key_pair()
            self._write_key_pair(key_file, pk, sk)
            zk_print('Done')

        return KeyPair(PublicKeyValue(pk, params=self.params), PrivateKeyValue(sk))

    def _write_key_pair(self, key_file: str, pk: List[int], sk: List[int]):
        """Write pk and sk to key_file, each as a 4-byte length followed by its fixed-size chunks."""
        with open(key_file, 'wb') as f:
            for chunks in (pk, sk):
                f.write(len(chunks).to_bytes(4, byteorder='big'))
                for chunk in chunks:
                    f.write(chunk.to_bytes(self.params.cipher_chunk_size, byteorder='big'))

    def _read_key_pair(self, key_file: str) -> Tuple[List[int], List[int]]:
        """Read back a key pair in the format produced by _write_key_pair."""
        with open(key_file, 'rb') as f:
            def read_int(width: int) -> int:
                return int.from_bytes(f.read(width), byteorder='big')

            def read_chunk_list() -> List[int]:
                count = read_int(4)
                return [read_int(self.params.cipher_chunk_size) for _ in range(count)]

            pk = read_chunk_list()
            sk = read_chunk_list()
        return pk, sk

    def _generate_key_pair(self) -> Tuple[List[int], List[int]]:
        """Generate two distinct primes whose product has exactly key_bits bits.

        :return: (chunks of n, chunks of p followed by chunks of q)
        """
        n_bits = self.params.key_bits
        prime_bits = (n_bits + 1) // 2

        while True:
            p = int(generate_probable_prime(exact_bits=prime_bits))
            q = int(generate_probable_prime(exact_bits=prime_bits))
            n = p * q
            if p != q and n.bit_length() == n_bits:
                break

        key_bytes = self.params.key_bytes
        public_chunks = self.serialize_pk(n, key_bytes)
        private_chunks = self.serialize_pk(p, key_bytes) + self.serialize_pk(q, key_bytes)
        return public_chunks, private_chunks

    @staticmethod
    def sample_below(n: int, co_prime: bool = False):
        """Return a random number below n; with co_prime set, resample until gcd(result, n) == 1."""
        candidate = randrange(n)
        while co_prime and gcd(candidate, n) != 1:
            candidate = randrange(n)
        return candidate

    def _enc_with_rand(self, plain: int, random: int, n: int) -> List[int]:
        """Return the chunks of (n * plain + 1) * random^n mod n^2."""
        n_sqr = n * n
        # With generator g = n + 1, g^plain equals n * plain + 1 modulo n^2
        cipher = ((n * plain + 1) * pow(random, n, n_sqr)) % n_sqr
        return self.serialize_pk(cipher, self.params.cipher_bytes_payload)

    def _enc(self, plain: int, _: int, target_pk: int) -> Tuple[List[int], List[int]]:
        """Encrypt plain for the modulus target_pk using freshly sampled randomness.

        :return: (cipher chunks, randomness chunks)
        """
        n = target_pk
        reduced = plain % n  # handle negative numbers
        rand = self.sample_below(n, co_prime=True)
        cipher_chunks = self._enc_with_rand(reduced, rand, n)
        return cipher_chunks, self.serialize_pk(rand, self.params.rnd_bytes)

    def _dec(self, cipher: Tuple[int, ...], sk: Any) -> Tuple[int, List[int]]:
        """Decrypt cipher with sk (chunks of p followed by chunks of q).

        :return: (plaintext, mapped to a negative number if above n // 2; chunks of the recovered randomness)
        """
        key_len = self.params.key_len
        p = self.deserialize_pk(sk[:key_len])
        q = self.deserialize_pk(sk[key_len:])
        n = p * q
        n_sqr = n * n
        lambda_ = (p - 1) * (q - 1)
        lambda_inv = pow(lambda_, -1, n)
        c = self.deserialize_pk(cipher)

        # Compute the plaintext: plain = L(cipher^lambda mod n^2) / lambda mod n
        l_value = (pow(c, lambda_, n_sqr) - 1) // n
        plain = (l_value * lambda_inv) % n

        # Compute the randomness that was used
        # Fortunately, this has been asked and answered on stackexchange: https://math.stackexchange.com/a/114142
        rand_pow_n = (c * pow(n + 1, -plain, n_sqr)) % n_sqr
        p_inv = pow(p, -1, q - 1)  # Inverse of p modulo q-1
        q_inv = pow(q, -1, p - 1)  # Inverse of q modulo p-1
        rand_mod_q = pow(rand_pow_n, p_inv, q)
        rand_mod_p = pow(rand_pow_n, q_inv, p)

        # Combine both residues using the Chinese Remainder Theorem
        weight_p = (pow(q, -1, p) * q) % n
        weight_q = (pow(p, -1, q) * p) % n
        recovered = (rand_mod_p * weight_p + rand_mod_q * weight_q) % n
        random_chunks = self.serialize_pk(recovered, self.params.rnd_bytes)

        # Handle possible negative plaintexts
        if plain > n // 2:
            plain -= n
        return plain, random_chunks

    def do_op(self, op: str, public_key: Union[List[int], int],
              *args: Union[CipherValue, int]) -> List[int]:
        """
        Apply op ('sign-', '+', '-' or '*' with one plain int factor) to the operands.

        :raise ValueError: if the operation is not supported for the given operands
        :return: chunks of the resulting cipher
        """
        n = self.deserialize_pk(public_key)
        n_sqr = n * n

        def to_int(operand: Union[CipherValue, int]) -> int:
            if not isinstance(operand, CipherValue):
                return operand  # Return plaintext arguments as-is
            value = self.deserialize_pk(operand[:])
            return value if value != 0 else 1  # If ciphertext is 0, return 1 == Enc(0, 0)

        operands = [to_int(arg) for arg in args]

        if op == 'sign-':
            assert isinstance(args[0], CipherValue)
            result = pow(operands[0], -1, n_sqr)
        elif op == '+':
            assert isinstance(args[0], CipherValue) and isinstance(args[1], CipherValue)
            result = (operands[0] * operands[1]) % n_sqr
        elif op == '-':
            assert isinstance(args[0], CipherValue) and isinstance(args[1], CipherValue)
            result = (operands[0] * pow(operands[1], -1, n_sqr)) % n_sqr
        elif op == '*':
            # Exactly one factor has to be a plain int, which becomes the exponent
            if isinstance(args[1], int):
                assert isinstance(args[0], CipherValue)
                result = pow(operands[0], operands[1], n_sqr)
            elif isinstance(args[0], int):
                assert isinstance(args[1], CipherValue)
                result = pow(operands[1], operands[0], n_sqr)
            else:
                raise ValueError(f'Unsupported operation {op}')
        else:
            raise ValueError(f'Unsupported operation {op}')

        return self.serialize_pk(result, self.params.cipher_bytes_payload)

    def do_rerand(self, arg: CipherValue, public_key: List[int]) -> Tuple[List[int], List[int]]:
        """Not supported by this backend.

        :raise NotImplementedError: always
        """
        raise NotImplementedError("Rerandomization not implemented for Paillier backend")
class ElgamalCrypto(ZkayHomomorphicCryptoInterface):
    """Crypto backend 'elgamal' (additively homomorphic) over the babyjubjub curve.

    A ciphertext consists of two curve points (c1, c2), serialized as the four coordinates
    [c1.u, c1.v, c2.u, c2.v]. Key pairs are cached in files below ``cfg.data_dir``.
    """

    params = CryptoParams('elgamal')

    @staticmethod
    def _point(u: int, v: int) -> babyjubjub.Point:
        """Build a curve point from two integer coordinates."""
        return babyjubjub.Point(babyjubjub.Fq(u), babyjubjub.Fq(v))

    def _generate_or_load_key_pair(self, address: str) -> KeyPair:
        """Load the key pair stored for address, generating and storing a new one if none exists."""
        key_file = os.path.join(cfg.data_dir, 'keys', f'elgamal_{self.params.key_bits}_{address}.bin')
        os.makedirs(os.path.dirname(key_file), exist_ok=True)

        if os.path.exists(key_file):
            # Restore saved key pair
            zk_print(f'ElGamal secret found, loading from file {key_file}')
            pk, sk = self._read_key_pair(key_file)
        else:
            zk_print('Key pair not found, generating new ElGamal secret...')
            pk, sk = self._generate_key_pair()
            self._write_key_pair(key_file, pk, sk)
            zk_print('Done')

        return KeyPair(PublicKeyValue(pk, params=self.params), PrivateKeyValue(sk))

    def _write_key_pair(self, key_file: str, pk: List[int], sk: int):
        """Write the public key coordinates followed by sk to key_file as fixed-size integers."""
        with open(key_file, 'wb') as f:
            for value in (*pk, sk):
                f.write(value.to_bytes(self.params.cipher_chunk_size, byteorder='big'))

    def _read_key_pair(self, key_file: str) -> Tuple[List[int], int]:
        """Read back a key pair in the format produced by _write_key_pair."""
        with open(key_file, 'rb') as f:
            pkx, pky, sk = (
                int.from_bytes(f.read(self.params.cipher_chunk_size), byteorder='big')
                for _ in range(3)
            )
        return [pkx, pky], sk

    def _generate_key_pair(self) -> Tuple[List[int], int]:
        """Sample a secret scalar and return ([pk.u, pk.v], sk) with pk = GENERATOR * sk."""
        sk = randrange(babyjubjub.CURVE_ORDER)
        pk = babyjubjub.Point.GENERATOR * babyjubjub.Fr(sk)
        return [pk.u.s, pk.v.s], sk

    def _enc(self, plain: int, _: int, target_pk: int) -> Tuple[List[int], List[int]]:
        """Encrypt plain for target_pk using a freshly sampled scalar.

        :return: (cipher coordinates, [randomness])
        """
        pk_chunks = self.serialize_pk(target_pk, self.params.key_bytes)
        r = randrange(babyjubjub.CURVE_ORDER)
        return self._enc_with_rand(plain, r, pk_chunks), [r]

    def _dec(self, cipher: Tuple[int, ...], sk: Any) -> Tuple[int, List[int]]:
        """Decrypt cipher with the secret scalar sk.

        :return: (plaintext, [sk])
        """
        with time_measure("elgamal_decrypt"):
            c1 = self._point(cipher[0], cipher[1])
            c2 = self._point(cipher[2], cipher[3])
            shared_secret = c1 * babyjubjub.Fr(sk)
            plain = self._de_embed(c2 + shared_secret.negate())

        # TODO randomness misused for the secret key, which is an extremely ugly hack...
        return plain, [sk]

    def _de_embed(self, plain_embedded: babyjubjub.Point) -> int:
        """Map an embedded plaintext point back to the corresponding integer."""
        # handle basic special cases without expensive discrete log computation
        for known_point, value in ((babyjubjub.Point.ZERO, 0), (babyjubjub.Point.GENERATOR, 1)):
            if plain_embedded == known_point:
                return value
        return get_dlog(plain_embedded.u.s, plain_embedded.v.s)

    def do_op(self, op: str, public_key: List[int], *args: Union[CipherValue, int]) -> List[int]:
        """
        Apply op ('+', '-' or '*' with one plain int factor) to the operands.

        :raise ValueError: if the operation is not supported for the given operands
        :return: coordinates of the resulting cipher
        """
        def to_operand(arg: Union[CipherValue, int]) -> Union[Tuple[babyjubjub.Point, babyjubjub.Point], int]:
            if not isinstance(arg, CipherValue):
                return arg
            # if ciphertext is 0, return (Point.ZERO, Point.ZERO) == Enc(0, 0)
            if arg == CipherValue([0] * 4, params=arg.params):
                return babyjubjub.Point.ZERO, babyjubjub.Point.ZERO
            first = self._point(arg[0], arg[1])
            second = self._point(arg[2], arg[3])
            return first, second

        def scale(pair, factor: int):
            # Multiply both points of the cipher by the plain factor
            return pair[0] * babyjubjub.Fr(factor), pair[1] * babyjubjub.Fr(factor)

        operands = [to_operand(arg) for arg in args]

        if op == '+':
            e1 = operands[0][0] + operands[1][0]
            e2 = operands[0][1] + operands[1][1]
        elif op == '-':
            e1 = operands[0][0] + operands[1][0].negate()
            e2 = operands[0][1] + operands[1][1].negate()
        elif op == '*' and isinstance(operands[1], int):
            e1, e2 = scale(operands[0], operands[1])
        elif op == '*' and isinstance(operands[0], int):
            e1, e2 = scale(operands[1], operands[0])
        else:
            raise ValueError(f'Unsupported operation {op}')

        return [e1.u.s, e1.v.s, e2.u.s, e2.v.s]

    def do_rerand(self, arg: CipherValue, public_key: List[int]) -> Tuple[List[int], List[int]]:
        """Re-randomize arg.

        :return: (coordinates of the re-randomized cipher, [randomness])
        """
        # homomorphically add encryption of zero to re-randomize
        r = randrange(babyjubjub.CURVE_ORDER)
        enc_zero = CipherValue(self._enc_with_rand(0, r, public_key), params=arg.params)
        return self.do_op('+', public_key, arg, enc_zero), [r]

    def _enc_with_rand(self, plain: int, random: int, pk: List[int]) -> List[int]:
        """Return the coordinates of (GENERATOR * random, GENERATOR * plain + pk * random)."""
        plain_embedded = babyjubjub.Point.GENERATOR * babyjubjub.Fr(plain)
        shared_secret = self._point(pk[0], pk[1]) * babyjubjub.Fr(random)
        c1 = babyjubjub.Point.GENERATOR * babyjubjub.Fr(random)
        c2 = plain_embedded + shared_secret
        return [c1.u.s, c1.v.s, c2.u.s, c2.v.s]
def get_crypto_params(self, hom: Homomorphism) -> CryptoParams:
    """
    Return the CryptoParams of the backend configured for the given homomorphism.

    :param hom: homomorphism whose backend is looked up via _get_crypto_backend
    :raise ValueError: if no backend is set for hom
    :return: CryptoParams constructed from the backend name
    """
    backend_name = self._get_crypto_backend(hom)
    if backend_name is not None:
        return CryptoParams(backend_name)
    raise ValueError(f'No crypto backend set for homomorphism {hom.name}')