def osswu2_help(t): assert isinstance(t, Fq2) # first, compute X0(t), detecting and handling exceptional case num_den_common = xi_2**2 * t**4 + xi_2 * t**2 x0_num = Ell2p_b * (num_den_common + Fq(q, 1)) x0_den = -Ell2p_a * num_den_common x0_den = Ell2p_a * xi_2 if x0_den == 0 else x0_den # compute num and den of g(X0(t)) gx0_den = pow(x0_den, 3) gx0_num = Ell2p_b * gx0_den gx0_num += Ell2p_a * x0_num * pow(x0_den, 2) gx0_num += pow(x0_num, 3) # try taking sqrt of g(X0(t)) # this uses the trick for combining division and sqrt from Section 5 of # Bernstein, Duif, Lange, Schwabe, and Yang, "High-speed high-security signatures." # J Crypt Eng 2(2):77--89, Sept. 2012. http://ed25519.cr.yp.to/ed25519-20110926.pdf tmp1 = pow(gx0_den, 7) # v^7 tmp2 = gx0_num * tmp1 # u v^7 tmp1 = tmp1 * tmp2 * gx0_den # u v^15 sqrt_candidate = tmp2 * pow(tmp1, (q**2 - 9) // 16) # check if g(X0(t)) is square and return the sqrt if so for root in roots_of_unity: y0 = sqrt_candidate * root if y0**2 * gx0_den == gx0_num: # found sqrt(g(X0(t))). force sign of y to equal sign of t if sgn0(y0) != sgn0(t): y0 = -y0 assert sgn0(y0) == sgn0(t) return JacobianPoint(x0_num * x0_den, y0 * pow(x0_den, 3), x0_den, False, default_ec_twist) # if we've gotten here, then g(X0(t)) is not square. convert srqt_candidate to sqrt(g(X1(t))) (x1_num, x1_den) = (xi_2 * t**2 * x0_num, x0_den) (gx1_num, gx1_den) = (xi_2**3 * t**6 * gx0_num, gx0_den) sqrt_candidate *= t**3 for eta in etas: y1 = eta * sqrt_candidate if y1**2 * gx1_den == gx1_num: # found sqrt(g(X1(t))). force sign of y to equal sign of t if sgn0(y1) != sgn0(t): y1 = -y1 assert sgn0(y1) == sgn0(t) return JacobianPoint(x1_num * x1_den, y1 * pow(x1_den, 3), x1_den, False, default_ec_twist) # if we got here, something is wrong raise RuntimeError("osswu2_help failed for unknown reasons")
def aggregate_sigs_secure(signatures, public_keys, message_hashes):
    """
    Aggregate signatures using the secure method, which calculates
    exponents based on public keys, and raises each signature to an
    exponent before multiplying them together. This is secure against
    rogue public key attack, but is slower than simple aggregation.

    The triples (message hash, public key, signature) are sorted by
    (message hash, public key). The exponents come from ``BLS.hash_pks``
    over the public keys in that same sorted order, so exponent ``i`` always
    belongs to the ``i``-th sorted signature, whatever order the inputs
    were given in.

    :param signatures: one signature per public key; each is multiplied by a
        scalar directly. NOTE(review): this block uses ``signature`` and
        ``public_keys[0].ec`` without ``.value``, unlike sibling code;
        presumably callers pass raw curve points here. Confirm.
    :param public_keys: the signers' public keys.
    :param message_hashes: the message hash each signature covers.
    :returns: ``Signature.from_g2`` of the weighted sum.
    :raises Exception: "Invalid number of keys" if the three lists differ in
        length or are empty.
    """
    if (not signatures or len(signatures) != len(public_keys) or
            len(public_keys) != len(message_hashes)):
        raise Exception("Invalid number of keys")

    # Sort by message hash + pk only. Ties must not fall through to
    # comparing signatures, which need not support ordering.
    mh_pub_sigs = sorted(zip(message_hashes, public_keys, signatures),
                         key=lambda entry: (entry[0], entry[1]))

    # Hash the public keys in the same order the signatures are processed.
    sorted_public_keys = [pk for (_, pk, _) in mh_pub_sigs]
    computed_Ts = BLS.hash_pks(len(sorted_public_keys), sorted_public_keys)

    # Raise each sig to a power of each t,
    # and multiply all together into agg_sig
    ec = public_keys[0].ec
    # Start from the point at infinity.
    agg_sig = JacobianPoint(Fq2.one(ec.q), Fq2.one(ec.q),
                            Fq2.zero(ec.q), True, ec)
    for i, (_, _, signature) in enumerate(mh_pub_sigs):
        agg_sig += signature * computed_Ts[i]

    return Signature.from_g2(agg_sig)
def divide_by(self, divisor_signatures):
    """
    Signature division (elliptic curve subtraction).

    This is useful if you have already verified parts of the tree: the
    resulting quotient signature is faster to verify, because fewer
    pairings have to be performed.

    Divides this aggregate signature by other signatures in its aggregate
    tree. A signature can only be divided out if it is part of the subset
    and all message/public key pairs in its aggregation info are unique.
    The pairs are unique when every pair gives the same quotient
    (dividend exponent / divisor exponent, mod ``default_ec.n``). For
    example, you cannot divide s1 / s2 if s2 is an aggregate signature
    containing (m1, pk1) which is also present somewhere else in s1's tree.
    Note that s2 itself does not have to be unique.

    :param divisor_signatures: iterable of signatures to divide out.
    :returns: a new Signature (``self`` is left untouched; value and info
        are deep-copied). Its value is ``self.value`` minus each divisor's
        value scaled by its quotient. Its aggregation info no longer
        contains the divided-out pairs, and its message hash / public key
        lists are rebuilt in sorted order.
    :raises Exception: if a divisor's aggregation info is empty or its
        key/hash lists differ in length ("Invalid aggregation info"); if a
        pair is missing from this signature's tree ("Signature is not a
        subset"); or if quotients differ between pairs ("... msg/pk pairs
        are not unique").
    """
    # First pass: validate every divisor and work out its quotient. No
    # curve arithmetic happens until all divisors have been accepted.
    scaled_divisors = []  # (divisor point, quotient) pairs
    keys_to_remove = []  # (message hash, public key) pairs to drop
    for divisor_sig in divisor_signatures:
        divisor_info = divisor_sig.aggregation_info
        pks = divisor_info.public_keys
        message_hashes = divisor_info.message_hashes
        # An empty info yields no quotient for this divisor, so reject it
        # instead of subtracting with an undefined or stale scalar.
        if not pks or len(pks) != len(message_hashes):
            raise Exception("Invalid aggregation info")

        quotient = None
        for key in zip(message_hashes, pks):
            divisor = divisor_info.tree[key]
            try:
                dividend = self.aggregation_info.tree[key]
            except KeyError:
                raise Exception("Signature is not a subset") from None
            new_quotient = (Fq(default_ec.n, dividend)
                            / Fq(default_ec.n, divisor))
            if quotient is None:
                quotient = new_quotient
            elif quotient != new_quotient:
                # Makes sure the quotient is identical for each public
                # key, which means message/pk pair is unique.
                raise Exception("Cannot divide by aggregate signature, "
                                + "msg/pk pairs are not unique")
            keys_to_remove.append(key)
        scaled_divisors.append((divisor_sig.value, quotient))

    # Start from the point at infinity. NOTE(review): it is tagged with
    # default_ec although its coordinates are Fq2; presumably harmless
    # because it is the identity. Confirm.
    prod = JacobianPoint(Fq2.one(default_ec.q), Fq2.one(default_ec.q),
                         Fq2.zero(default_ec.q), True, default_ec)
    for value, quotient in scaled_divisors:
        prod += (value * -quotient)
    copy = Signature(deepcopy(self.value + prod),
                     deepcopy(self.aggregation_info))

    tree = copy.aggregation_info.tree
    for key in keys_to_remove:
        if key in tree:
            del tree[key]
    sorted_keys = sorted(tree.keys())
    copy.aggregation_info.message_hashes = [mh for (mh, _) in sorted_keys]
    copy.aggregation_info.public_keys = [pk for (_, pk) in sorted_keys]
    return copy
def verify(self):
    """
    Check this signature against its aggregation info using one
    multi-pairing.

    Public keys are grouped by the message hash they signed, keeping the
    order in which hashes first appear. Within a group, each distinct public
    key's point is multiplied by its exponent from
    ``self.aggregation_info.tree`` and the products are summed. Each message
    hash therefore ends up paired with a single G1 point. The points
    ``[-generator] + summed_keys`` are then passed to ``ate_pairing_multi``
    together with ``[signature] + hashed_messages``. The result is compared
    with ``Fq12.one``.

    :returns: True if the pairing result equals one. False if it does not,
        or if a (message hash, public key) pair has no exponent in the tree.
    :raises AssertionError: if the hash and key lists differ in length.
    :raises IndexError: if the aggregation info has no public keys.
    """
    info = self.aggregation_info
    message_hashes = info.message_hashes
    public_keys = info.public_keys
    assert (len(message_hashes) == len(public_keys))

    # message hash -> every public key listed with it (dict keeps first-seen order)
    keys_by_hash = {}
    for mh, pk in zip(message_hashes, public_keys):
        keys_by_hash.setdefault(mh, []).append(pk)

    ec = public_keys[0].value.ec
    grouped_hashes = []
    grouped_keys = []
    for mh, group in keys_by_hash.items():
        # Point at infinity as the running sum for this message hash.
        running = JacobianPoint(Fq.one(ec.q), Fq.one(ec.q),
                                Fq.zero(ec.q), True, ec)
        for pk in list(set(group)):
            try:
                exponent = info.tree[(mh, pk)]
            except KeyError:
                return False
            running += (pk.value * exponent)
        grouped_hashes.append(mh)
        grouped_keys.append(running.to_affine())

    hashed_points = [hash_to_point_prehashed_Fq2(mh) for mh in grouped_hashes]
    neg_generator = Fq(default_ec.n, -1) * generator_Fq()
    g1_side = [neg_generator] + grouped_keys
    g2_side = [self.value.to_affine()] + hashed_points
    pairing = ate_pairing_multi(g1_side, g2_side, default_ec)
    return pairing == Fq12.one(default_ec.q)
def aggregate(public_keys, secure):
    """
    Aggregate public keys into a single public key.

    The keys are summed in sorted order. With ``secure`` set, each key's
    point is first multiplied by its exponent from ``BLS.hash_pks`` over the
    sorted keys. This is the rogue-public-key-resistant form; otherwise the
    points are simply added.

    :param public_keys: non-empty collection of public keys. It is not
        modified; a sorted copy is used.
    :param secure: whether to weight each key by its hash-derived exponent.
    :returns: ``PublicKey.from_g1`` of the summed point.
    :raises Exception: "Invalid number of keys" if ``public_keys`` is empty.
    """
    if len(public_keys) < 1:
        raise Exception("Invalid number of keys")
    # Sort a copy so the caller's list is not reordered as a side effect.
    ordered_keys = sorted(public_keys)

    # Exponents are only needed for secure aggregation.
    computed_Ts = (BLS.hash_pks(len(ordered_keys), ordered_keys)
                   if secure else None)

    ec = ordered_keys[0].value.ec
    # Start from the point at infinity.
    sum_keys = JacobianPoint(Fq.one(ec.q), Fq.one(ec.q),
                             Fq.zero(ec.q), True, ec)
    for i, public_key in enumerate(ordered_keys):
        addend = public_key.value
        if secure:
            # Plain multiplication (not *=) so the key's stored point can
            # never be modified in place.
            addend = addend * computed_Ts[i]
        sum_keys += addend

    return PublicKey.from_g1(sum_keys)
def aggregate(signatures):
    """
    Aggregate several (possibly already aggregated) signatures into one,
    mixing simple and secure aggregation.

    A message hash "collides" when it already appeared in an earlier
    signature's aggregation info; repeats inside a single signature do not
    count. If nothing collides, every signature goes through
    ``Signature.aggregate_sigs_simple``.

    Otherwise, each signature containing a colliding hash is sorted by
    aggregation info and multiplied by an exponent. The exponents come from
    ``BLS.hash_pks`` over the sorted (message hash, public key) pairs of
    those signatures. The remaining signatures are then added unweighted.

    Either way, the result carries ``AggregationInfo.merge_infos`` of all
    input infos. No verification is performed: an invalid or insecure input
    signature produces an invalid aggregate.

    :param signatures: sequence of signatures, each with a non-empty
        aggregation info.
    :returns: the aggregated Signature.
    :raises Exception: if any signature's aggregation info is empty.
    """
    # Per-signature lists (lists of lists), aligned with `signatures`.
    keys_per_sig = []
    hashes_per_sig = []
    for sig in signatures:
        info = sig.aggregation_info
        if info.empty():
            raise Exception(
                "Each signature must have a valid aggregation " +
                "info")
        keys_per_sig.append(info.public_keys)
        hashes_per_sig.append(info.message_hashes)

    # Hashes already seen in an earlier signature are colliding.
    seen_before = set()
    collisions = set()
    for hashes in hashes_per_sig:
        current = set(hashes)
        collisions |= current & seen_before
        seen_before |= current

    if not collisions:
        # Disjoint message sets: simple aggregation is enough.
        final_sig = Signature.aggregate_sigs_simple(signatures)
        merged_info = AggregationInfo.merge_infos(
            [sig.aggregation_info for sig in signatures])
        final_sig.set_aggregation_info(merged_info)
        return final_sig

    # Split into signatures that touch a colliding hash and those that don't.
    secure_sigs = []
    simple_sigs = []
    secure_groups = []  # (message hashes, public keys) of each secure sig
    for sig, hashes, keys in zip(signatures, hashes_per_sig, keys_per_sig):
        if any(mh in collisions for mh in hashes):
            secure_sigs.append(sig)
            secure_groups.append((hashes, keys))
        else:
            simple_sigs.append(sig)

    # Stable sort of the colliding signatures by their aggregation info.
    secure_sigs.sort(key=lambda s: s.aggregation_info)

    # All (message hash, public key) pairs of the colliding signatures, sorted.
    sorted_pairs = sorted(
        (hashes[j], pk)
        for hashes, keys in secure_groups
        for j, pk in enumerate(keys)
    )
    ordered_pks = [pk for (_, pk) in sorted_pairs]

    computed_Ts = BLS.hash_pks(len(secure_sigs), ordered_pks)

    ec = ordered_pks[0].value.ec
    # Start from the point at infinity.
    agg_sig = JacobianPoint(Fq2.one(ec.q), Fq2.one(ec.q),
                            Fq2.zero(ec.q), True, ec)
    for i, sig in enumerate(secure_sigs):
        agg_sig += sig.value * computed_Ts[i]
    for sig in simple_sigs:
        agg_sig += sig.value

    final_sig = Signature.from_g2(agg_sig)
    merged_info = AggregationInfo.merge_infos(
        [sig.aggregation_info for sig in signatures])
    final_sig.set_aggregation_info(merged_info)
    return final_sig