def generate_session_key(idA, idB, Ra, Rb, D, x, master_public, entity, l):
    """Derive an ``l``-length session key for one side of the key exchange.

    Args:
        idA, idB: Identity strings of parties A and B; each is SM3-hashed
            (after ``str2hexbytes``) before entering the KDF.
        Ra, Rb: Ephemeral points exchanged by A and B.
        D: This party's private key, used as the second pairing argument.
        x: This party's ephemeral secret exponent.
        master_public: 4-tuple ``(P1, P2, Ppub, g)``; only ``g`` is read.
        entity: ``'A'`` or ``'B'`` -- which side is calling.
        l: Requested key length, forwarded unchanged to ``sm3_kdf``.

    Returns:
        The value returned by ``sm3_kdf``.

    Raises:
        Exception: If ``entity`` is neither ``'A'`` nor ``'B'``.
    """
    _, _, _, g = master_public

    if entity not in ('A', 'B'):
        raise Exception('Invalid entity')
    # Each side pairs the *peer's* ephemeral point with its own private key.
    peer_R = Rb if entity == 'A' else Ra

    paired = ate.pairing(peer_R, D)
    own = g ** x
    shared = paired ** x

    # B swaps the first two values, presumably so both parties feed the KDF
    # the same ordered tuple.
    # NOTE(review): this ordering (and hashing the IDs) looks different from
    # GM/T 0044-2016 Part 3 -- confirm before interoperating with other
    # SM9 implementations.
    first, second = (own, paired) if entity == 'B' else (paired, own)

    kdf_input = ''.join((
        sm3_hash(str2hexbytes(idA)),
        sm3_hash(str2hexbytes(idB)),
        ec2sp(Ra),
        ec2sp(Rb),
        fe2sp(first),
        fe2sp(second),
        fe2sp(shared),
    ))
    return sm3_kdf(kdf_input.encode('utf-8'), l)
def kem_decap(master_public, identity, D, C1, l):
    """Recover the KEM shared key from the encapsulation point ``C1``.

    Args:
        master_public: Master public parameters; not read here, kept for
            signature symmetry with ``kem_encap``.
        identity: Recipient identity string; its SM3 hash (after
            ``str2hexbytes``) is mixed into the KDF input.
        D: Recipient's private key, used as the second pairing argument.
        C1: Encapsulation point produced by ``kem_encap``.
        l: Requested key length, forwarded unchanged to ``sm3_kdf``.

    Returns:
        The key returned by ``sm3_kdf``, or ``FAILURE`` when ``C1`` does not
        satisfy the curve equation selected by ``ec.b2``.
    """
    # Reject a malformed encapsulation before it is paired with the key.
    if not ec.is_on_curve(C1, ec.b2):
        return FAILURE
    t = ate.pairing(C1, D)
    uid = sm3_hash(str2hexbytes(identity))
    # Same KDF input layout as kem_encap: ec2sp(C1) || fe2sp(t) || uid.
    kdf_input = ec2sp(C1) + fe2sp(t) + uid
    return sm3_kdf(kdf_input.encode('utf-8'), l)
def kem_encap(master_public, identity, l):
    """Encapsulate a fresh key for ``identity``.

    Args:
        master_public: 4-tuple ``(P1, P2, Ppub, g)``; ``g`` is used directly
            and the whole tuple is passed to ``public_key_extract``.
        identity: Recipient identity string; its SM3 hash (after
            ``str2hexbytes``) is mixed into the KDF input.
        l: Requested key length, forwarded unchanged to ``sm3_kdf``.

    Returns:
        ``(k, C1)``: the derived key and the encapsulation point the
        recipient needs for ``kem_decap``.
    """
    _, _, _, g = master_public
    Q = public_key_extract('encrypt', master_public, identity)

    # The ephemeral exponent must lie in [1, curve_order - 1]. With x == 0,
    # C1 = [0]Q is the point at infinity and t = g**0 == 1, so the key would
    # not depend on any secret at all.
    x = SystemRandom().randrange(1, ec.curve_order)

    C1 = ec.multiply(Q, x)
    t = g ** x
    uid = sm3_hash(str2hexbytes(identity))
    kdf_input = ec2sp(C1) + fe2sp(t) + uid
    k = sm3_kdf(kdf_input.encode('utf-8'), l)
    return (k, C1)
def encrypt(self, data, pubkey):
    """SM2-encrypt ``data`` for ``pubkey``.

    Args:
        data: Plaintext bytes.
        pubkey: Recipient public key, passed straight to ``self._kg``.

    Returns:
        Ciphertext bytes laid out as C1 || C3 || C2, or ``None`` when the
        derived KDF mask is all zeros.
    """
    plain_hex = data.hex()
    half = self.para_len
    nonce = int(random_hex(half), 16)

    c1 = self._kg(nonce, self.ecc_table['g'])
    shared = self._kg(nonce, pubkey)
    x2, y2 = shared[0:half], shared[half:2 * half]

    width = len(plain_hex)
    # The KDF length is in bytes: two hex digits per byte.
    mask = sm3.sm3_kdf(shared.encode('utf8'), width / 2)
    if int(mask, 16) == 0:
        return None

    # C2 = M xor mask, zero-padded back to the plaintext's hex width.
    c2 = ('%%0%dx' % width) % (int(plain_hex, 16) ^ int(mask, 16))
    # C3 = SM3(x2 || M || y2), the integrity tag.
    c3 = sm3.sm3_hash(list(bytes.fromhex('%s%s%s' % (x2, plain_hex, y2))))
    return bytes.fromhex('%s%s%s' % (c1, c3, c2))
def decrypt(self, data, privKey):
    """SM2-decrypt a C1 || C3 || C2 ciphertext.

    Args:
        data: Ciphertext bytes in the C1 || C3 || C2 layout produced by
            ``encrypt``.
        privKey: Private scalar, passed straight to ``self._kg``.

    Returns:
        The plaintext bytes, or ``None`` when the derived KDF mask is all
        zeros or the integrity tag C3 does not match (tampered ciphertext or
        wrong key).
    """
    import hmac

    data = data.hex()
    len_2 = 2 * self.para_len
    len_3 = len_2 + 64  # C3 is a 32-byte SM3 digest = 64 hex digits
    C1 = data[0:len_2]
    C3 = data[len_2:len_3]
    C2 = data[len_3:]

    xy = self._kg(privKey, C1)
    x2 = xy[0:self.para_len]
    y2 = xy[self.para_len:len_2]

    cl = len(C2)
    t = sm3.sm3_kdf(xy.encode('utf8'), cl / 2)
    if int(t, 16) == 0:
        return None

    form = '%%0%dx' % cl
    M = form % (int(C2, 16) ^ int(t, 16))

    # Recompute the tag SM3(x2 || M || y2) and reject on mismatch. Without
    # this check any modified ciphertext "decrypts" to garbage silently.
    # Use a constant-time comparison so the check does not leak timing.
    u = sm3.sm3_hash(list(bytes.fromhex('%s%s%s' % (x2, M, y2))))
    if not hmac.compare_digest(u.lower(), C3):
        return None
    return bytes.fromhex(M)
def h2rf(i, z, n):
    """Hash ``z`` to an integer in ``[1, n - 1]``.

    ``i`` selects the hash variant: it is encoded with ``i2sp(i, 1)``
    (presumably a single octet) and prepended to ``z``. This is presumably
    the SM9 H1/H2 construction -- confirm against callers.

    Args:
        i: Domain-separation value prefixed to the input.
        z: Input bytes (concatenated with an encoded bytes prefix).
        n: Modulus bound; the result is ``h % (n - 1) + 1``.
    """
    # NOTE(review): GM/T 0044 defines hlen in *bits*; confirm whether
    # sm3_kdf's length argument here is bits or bytes.
    hlen = 8 * ceil(5 * bitlen(n) / 32)
    prefix = i2sp(i, 1).encode('utf-8')
    digest = int(sm3_kdf(prefix + z, hlen), 16)
    return digest % (n - 1) + 1