def decrypt(
    dk: DecryptionKey,
    ciphertext: Ciphertext,
    dtype: tf.DType = tf.int32,
) -> tf.Tensor:
    """Recover the plaintext held in `ciphertext` using decryption key `dk`.

    Computes ``((c**d1 mod nn) - 1) // n * d2 mod n``, i.e. the Paillier
    decryption formula for ciphertexts of the form ``(1 + n*x) * r**n mod nn``.

    Args:
        dk: Decryption key. Reads `dk.n`, `dk.nn`, `dk.d1` and `dk.d2`
            (presumably nn == n**2, d1 the private exponent and d2 the
            matching inverse mod n — confirm against the key generator).
        ciphertext: Ciphertext whose `raw` big-integer tensor is decrypted.
        dtype: Output dtype. When it is `tf.variant`, the raw tf_big result
            is returned unconverted.

    Returns:
        The plaintext, exported to `dtype` unless `dtype` is `tf.variant`.
    """
    raw = ciphertext.raw
    n = dk.n
    u = tf_big.pow(raw, dk.d1, dk.nn)
    # (u - 1) // n is the Paillier "L" function; scaling by d2 mod n yields x.
    message = ((u - 1) // n) * dk.d2 % n
    return (
        message
        if dtype == tf.variant
        else tf_big.export_tensor(message, dtype=dtype)
    )
def mul(
    ek: EncryptionKey,
    lhs: Ciphertext,
    rhs: tf.Tensor,
    do_refresh: bool = True,
) -> Ciphertext:
    """Multiply the plaintext inside ciphertext `lhs` by plaintext `rhs`.

    Raising a ciphertext to the power k modulo `ek.nn` yields a ciphertext
    of the original plaintext times k (homomorphic scalar multiplication).

    Args:
        ek: Public encryption key; `ek.nn` is the ciphertext modulus.
        lhs: Ciphertext to scale.
        rhs: Plaintext multiplier(s), imported with `tf_big.import_tensor`.
            Presumably non-negative integers — TODO confirm how tf_big
            handles negative exponents.
        do_refresh: If True (the default), re-randomize the result with
            fresh randomness via `refresh` before returning it.

    Returns:
        A new `Ciphertext` under `ek`.
    """
    c = lhs.raw
    k = tf_big.import_tensor(rhs)
    # Modular exponentiation: reduce mod nn throughout instead of first
    # materializing c**k, whose bit length grows linearly with k.
    d = tf_big.pow(c, k, ek.nn)
    res = Ciphertext(ek, d)

    if not do_refresh:
        return res
    return refresh(ek, res)
def encrypt(
    ek: EncryptionKey,
    plaintext: tf.Tensor,
    randomness: Optional[Randomness] = None,
) -> Ciphertext:
    """Encrypt `plaintext` under public key `ek`.

    Computes ``(1 + n*x mod nn) * r**n mod nn`` element-wise, where x is the
    imported plaintext and r the randomness.

    Args:
        ek: Public encryption key; reads `ek.n` and `ek.nn`.
        plaintext: Tensor of plaintext values, imported with
            `tf_big.import_tensor`.
        randomness: Optional randomness with the same shape as `plaintext`.
            When omitted, fresh randomness is drawn with `gen_randomness`.

    Returns:
        The resulting `Ciphertext` under `ek`.

    Raises:
        AssertionError: If the randomness shape differs from the plaintext
            shape (note: asserts are stripped under ``python -O``).
    """
    x = tf_big.import_tensor(plaintext)
    # Explicit None check: `or` would consult the truthiness of a
    # caller-supplied Randomness object instead of the Optional contract.
    if randomness is None:
        randomness = gen_randomness(ek=ek, shape=x.shape)
    r = randomness.raw
    assert r.shape == x.shape, "randomness shape must match plaintext shape"

    # (1 + n)**x mod nn simplifies to 1 + n*x mod nn, so no exponentiation
    # is needed for the plaintext term.
    gx = 1 + (ek.n * x) % ek.nn
    rn = tf_big.pow(r, ek.n, ek.nn)
    c = gx * rn % ek.nn
    return Ciphertext(ek, c)
def refresh(ek: EncryptionKey, ciphertext: Ciphertext) -> Ciphertext:
    """Return a re-randomized copy of `ciphertext`.

    Multiplies the ciphertext by ``s**n mod nn`` for freshly generated
    randomness s. That factor has the same form as the randomness term in
    encryption, so the plaintext is unchanged.

    Args:
        ek: Public encryption key; reads `ek.n` and `ek.nn`.
        ciphertext: Ciphertext to re-randomize.

    Returns:
        A new `Ciphertext` under `ek`.
    """
    raw = ciphertext.raw
    fresh = gen_randomness(ek=ek, shape=raw.shape).raw
    mask = tf_big.pow(fresh, ek.n, ek.nn)
    return Ciphertext(ek, raw * mask % ek.nn)