def get_compute_and_return_union(communication, clients, round_num, ps_computation_time_path):
    """Get ids from live clients, compute union, and then return union"""
    # Collect the real item ids uploaded by every client that is still online.
    union_id_dict = {}
    for client in list(clients):  # iterate a snapshot so removal is safe
        try:
            received_message = communication.get_np_array(client.connection_socket)
            assert client.ID == received_message['client_ID']
            union_id_dict[client.ID] = received_message
            print('Received real item ids from client ' + str(client.ID))
            sys.stdout.flush()
        except Exception:
            # Unreachable client: drop it from the live set and close its socket.
            clients.remove(client)
            print('Fallen client: ' + str(client.ID) + ' at ' + client.address[0] + ':' +
                  str(client.address[1]) + ' in the real item ids uploading stage.')
            sys.stdout.flush()
            client.connection_socket.close()
    # Build the union message and log how long the plain union computation took.
    union_start = time.time()
    union_message = compute_union(union_id_dict)
    write_csv(ps_computation_time_path,
              [round_num, "plain set union of real item ids", time.time() - union_start])
    # Broadcast the union back to the clients that are still alive.
    send_back_union(communication, clients, union_message)
def client_side_sfsa_round3(communication, client_socket, FEDSUBAVG_SELF_STORAGE, FEDSUBAVG_OTHERS_STORAGE):
    """
    Upload the shares the server needs for unmasking: shares of the self-mask
    seed b for clients still live in U2, and shares of ssk for clients that
    were live in U1 but dropped before U2.
    """
    t_begin = time.time()
    me = FEDSUBAVG_SELF_STORAGE['my_index']
    # Live clients of round 2, myself excluded
    live_peers = [peer for peer in FEDSUBAVG_SELF_STORAGE['U2'] if peer != me]
    dropped_peers = FEDSUBAVG_SELF_STORAGE['U1\U2']
    # b-shares for each live peer, plus my own share of my own seed b
    fedsubavg_live_b_shares = {peer: FEDSUBAVG_OTHERS_STORAGE[peer]['share_b'] for peer in live_peers}
    fedsubavg_live_b_shares[me] = FEDSUBAVG_SELF_STORAGE['my_share_b']
    # ssk-shares for each dropped peer
    fedsubavg_drop_s_shares = {peer: FEDSUBAVG_OTHERS_STORAGE[peer]['share_ssk'] for peer in dropped_peers}
    write_csv(FEDSUBAVG_SELF_STORAGE['client_computation_time_path'],
              [FEDSUBAVG_SELF_STORAGE['communication_round_number'], "sfsa_U3", time.time() - t_begin])
    # Ship both share dictionaries to the server in a single message.
    fedsubavg_shares = {'client_ID': me,
                        'live_b_shares': fedsubavg_live_b_shares,
                        'drop_s_shares': fedsubavg_drop_s_shares}
    communication.send_np_array(fedsubavg_shares, client_socket)
    print('Client %d sent secret shares of live and dropped clients in round 2 to server in secure federated submodel averaging'
          % me)
    sys.stdout.flush()
    del fedsubavg_live_b_shares
    del fedsubavg_drop_s_shares
def server_side_psu_round0(communication, clients, UNION_SERVER_STORAGE, UNION_ROUND_STORAGE):
    """
    Receive public keys from all live clients
    Then, Broadcast them.
    """
    # Gather (spk, cpk) from every client that is still reachable.
    for client in list(clients):  # snapshot: we may remove while iterating
        try:
            received_message = communication.get_np_array(client.connection_socket)
            assert client.ID == received_message['client_ID']
            entry = UNION_SERVER_STORAGE.setdefault(client.ID, {})
            entry['spk'] = received_message['spk']
            entry['cpk'] = received_message['cpk']
            print('Received public keys from client ' + str(client.ID) + ' in private set union round 0')
            sys.stdout.flush()
        except Exception:
            clients.remove(client)
            print('Fallen client: ' + str(client.ID) + ' at ' + client.address[0] + ':' + str(
                client.address[1]) + ' in private set union round 0 (receive)')
            sys.stdout.flush()
            client.connection_socket.close()
    # Record up-to-date set of client indices in round 0
    UNION_ROUND_STORAGE['U0'] = [client.ID for client in clients]
    UNION_ROUND_STORAGE['n'] = len(UNION_ROUND_STORAGE['U0'])
    # Reconstruction threshold: a strict majority of the round-0 clients.
    UNION_ROUND_STORAGE['t'] = int(UNION_ROUND_STORAGE['n'] / 2) + 1
    # At least 2 clients to participate (n = 2, t = 2).
    # Did not receive public keys from enough clients. Abort!
    assert UNION_ROUND_STORAGE['n'] >= 2
    start_time = time.time()
    # Bundle the public keys of every live client in U0.
    union_pubkeys_dict = {}
    for client_index in UNION_ROUND_STORAGE['U0']:
        union_pubkeys_dict[client_index] = {'spk': UNION_SERVER_STORAGE[client_index]['spk'],
                                            'cpk': UNION_SERVER_STORAGE[client_index]['cpk']}
    write_csv(UNION_ROUND_STORAGE['ps_computation_time_path'],
              [UNION_ROUND_STORAGE['communication_round_num'], "psu_U0", time.time() - start_time])
    # Broadcast the gathered keys to each live client.
    for client in list(clients):
        try:
            communication.send_np_array(union_pubkeys_dict, client.connection_socket)
            print('Returned all public keys to client ' + str(client.ID) + ' in private set union round 0')
            sys.stdout.flush()
        except Exception:
            clients.remove(client)
            print('Fallen client: ' + str(client.ID) + ' at ' + client.address[0] + ':' + str(
                client.address[1]) + ' in private set union round 0 (return)')
            sys.stdout.flush()
            client.connection_socket.close()
    del union_pubkeys_dict
def client_side_sfsa_round0(communication, client_socket, FEDSUBAVG_SELF_STORAGE, FEDSUBAVG_OTHERS_STORAGE, FEDSUBAVG_DHKE):
    """
    Generate secret and public Keys, and Send public keys to server.
    Also, receive public keys of other clients from server.
    This part can be merged with that in private set union, but for clarify we separate them.
    """
    phase1_start = time.time()
    # Two Diffie-Hellman key pairs: the "s" pair seeds the mutual mask, the
    # "c" pair derives the pairwise symmetric encryption key.
    # my_csk, my_cpk can actually use those in private set union!!!
    my_ssk, my_spk = FEDSUBAVG_DHKE.generate_keys()
    my_csk, my_cpk = FEDSUBAVG_DHKE.generate_keys()
    # Keep all four keys for later rounds.
    FEDSUBAVG_SELF_STORAGE['my_ssk'] = my_ssk
    FEDSUBAVG_SELF_STORAGE['my_spk'] = my_spk
    FEDSUBAVG_SELF_STORAGE['my_csk'] = my_csk
    FEDSUBAVG_SELF_STORAGE['my_cpk'] = my_cpk
    my_index = FEDSUBAVG_SELF_STORAGE['my_index']
    fedsubavg_client_pubkeys = {'client_ID': my_index, 'spk': my_spk, 'cpk': my_cpk}
    phase1_end = time.time()
    communication.send_np_array(fedsubavg_client_pubkeys, client_socket)
    print('Client %d sent public keys to server in secure federated submodel averaging'
          % my_index)
    sys.stdout.flush()
    fedsubavg_pubkeys_dict = communication.get_np_array(client_socket)
    print('Received public keys of all clients from server.')
    phase2_start = time.time()
    for client_index, pubkeys in fedsubavg_pubkeys_dict.items():
        if client_index == my_index:
            continue  # my own keys already live in FEDSUBAVG_SELF_STORAGE
        peer = FEDSUBAVG_OTHERS_STORAGE.setdefault(client_index, {})
        peer['spk'] = pubkeys['spk']
        peer['cpk'] = pubkeys['cpk']
    # Live-client count (myself included) and the reconstruction threshold.
    FEDSUBAVG_SELF_STORAGE['n'] = len(fedsubavg_pubkeys_dict)
    FEDSUBAVG_SELF_STORAGE['t'] = int(FEDSUBAVG_SELF_STORAGE['n'] / 2) + 1
    phase2_end = time.time()
    write_csv(FEDSUBAVG_SELF_STORAGE['client_computation_time_path'],
              [FEDSUBAVG_SELF_STORAGE['communication_round_number'], "sfsa_U0",
               phase1_end - phase1_start + phase2_end - phase2_start])
def client_side_psu_round0(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE, UNION_DHKE):
    """
    Generate secret and public Keys, and Send public keys to server.
    Also, receive public keys of other clients from server.
    """
    keygen_start = time.time()
    # Two Diffie-Hellman key pairs: the "s" pair seeds the shared mask, the
    # "c" pair derives the shared symmetric encryption key.
    my_ssk, my_spk = UNION_DHKE.generate_keys()
    my_csk, my_cpk = UNION_DHKE.generate_keys()
    for key_name, key_value in (('my_ssk', my_ssk), ('my_spk', my_spk),
                                ('my_csk', my_csk), ('my_cpk', my_cpk)):
        UNION_SELF_STORAGE[key_name] = key_value
    union_client_pubkeys = {'client_ID': UNION_SELF_STORAGE['my_index'],
                            'spk': my_spk,
                            'cpk': my_cpk}
    keygen_end = time.time()
    communication.send_np_array(union_client_pubkeys, client_socket)
    print('Client %d sent public keys to server in private set union' % UNION_SELF_STORAGE['my_index'])
    sys.stdout.flush()
    union_pubkeys_dict = communication.get_np_array(client_socket)
    print('Received public keys of all clients from server.')
    store_start = time.time()
    for client_index, pubkeys in union_pubkeys_dict.items():
        if client_index == UNION_SELF_STORAGE['my_index']:
            continue  # my own keys are already stored in UNION_SELF_STORAGE
        peer = UNION_OTHERS_STORAGE.setdefault(client_index, {})
        peer['spk'] = pubkeys['spk']
        peer['cpk'] = pubkeys['cpk']
    # Live-client count (myself included) and the reconstruction threshold.
    UNION_SELF_STORAGE['n'] = len(union_pubkeys_dict)
    UNION_SELF_STORAGE['t'] = int(UNION_SELF_STORAGE['n'] / 2) + 1
    store_end = time.time()
    write_csv(UNION_SELF_STORAGE['client_computation_time_path'],
              [UNION_SELF_STORAGE['communication_round_number'], "psu_U0",
               keygen_end - keygen_start + store_end - store_start])
def client_side_psu_round3(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE):
    """
    Upload the shares the server needs for unmasking: shares of the self-mask
    seed b for clients still live in U2, and shares of ssk for clients that
    were live in U1 but dropped before U2.
    """
    t_begin = time.time()
    me = UNION_SELF_STORAGE['my_index']
    # Live clients of round 2, myself excluded
    live_peers = [peer for peer in UNION_SELF_STORAGE['U2'] if peer != me]
    dropped_peers = UNION_SELF_STORAGE['U1\U2']
    # b-shares for each live peer, plus my own share of my own seed b
    union_live_b_shares = {peer: UNION_OTHERS_STORAGE[peer]['share_b'] for peer in live_peers}
    union_live_b_shares[me] = UNION_SELF_STORAGE['my_share_b']
    # ssk-shares for each dropped peer
    union_drop_s_shares = {peer: UNION_OTHERS_STORAGE[peer]['share_ssk'] for peer in dropped_peers}
    write_csv(UNION_SELF_STORAGE['client_computation_time_path'],
              [UNION_SELF_STORAGE['communication_round_number'], "psu_U3", time.time() - t_begin])
    # Ship both share dictionaries to the server in a single message.
    union_shares = {'client_ID': me,
                    'live_b_shares': union_live_b_shares,
                    'drop_s_shares': union_drop_s_shares}
    communication.send_np_array(union_shares, client_socket)
    print('Client %d sent secret shares of live and dropped clients in round 2 to server in private set union'
          % me)
    sys.stdout.flush()
def server_side_psu_round1(communication, clients, UNION_SERVER_STORAGE, UNION_ROUND_STORAGE): """ Receive encrypted secret shares from clients and forward them """ temp_clients = [client for client in clients] for client in temp_clients: try: received_message = communication.get_np_array( client.connection_socket) assert client.ID == received_message['client_ID'] UNION_SERVER_STORAGE[ client.ID]['ss_ciphers_dict'] = received_message['ss_ciphers'] print('Received encrypted secret shares from client ' + str(client.ID) + ' in private set union round 1') sys.stdout.flush() except Exception: clients.remove(client) print('Fallen client: ' + str(client.ID) + ' at ' + client.address[0] + ':' + str(client.address[1]) + ' in private set union round 1 (receive)') sys.stdout.flush() client.connection_socket.close() # Record update-to-date set of client indices in round 1 UNION_ROUND_STORAGE['U1'] = [client.ID for client in clients] # Record set of client indices live in round 0 but drop in round 1 # I.e., Those clients who submit public keys but do not submit encrypted secret shares UNION_ROUND_STORAGE['U0\U1'] = list( set(UNION_ROUND_STORAGE['U1']) - set(UNION_ROUND_STORAGE['U0'])) # Did not receive encrypted secret shares from enough clients. Abort! assert len(UNION_ROUND_STORAGE['U1']) >= UNION_ROUND_STORAGE['t'] start_time = time.time() # Instead of having a dictionary of messages FROM a given client, we want to construct # a dictionary of messages TO a given client. 
ss_ciphers_dicts_FROM = {} for client_index in UNION_ROUND_STORAGE['U1']: ss_ciphers_dicts_FROM[client_index] = UNION_SERVER_STORAGE[ client_index]['ss_ciphers_dict'] # This is here that we reverse the "FROM key TO value" dict to a "FROM value TO key" dict # e.g.: {1: {2:a, 3:b, 4:c}, 3: {1:d,2:e,4:f}, 4: {1:g,2:h,3:i}} --> {1: {3:d, 4:g}, 3:{1:b, 4:i}, 4: {1:c,3:f}} ss_ciphers_dicts_TO = {} # forward message "enc_msg_to_client_index" from "from_client_index" to "to_client_index" for from_client_index, enc_msg_from_client_index in ss_ciphers_dicts_FROM.items( ): for to_client_index, enc_msg_to_client_index in enc_msg_from_client_index.items( ): ss_ciphers_dicts_TO.setdefault( to_client_index, {})[from_client_index] = enc_msg_to_client_index write_csv(UNION_ROUND_STORAGE['ps_computation_time_path'], [ UNION_ROUND_STORAGE['communication_round_num'], "psu_U1", time.time() - start_time ]) # Forward encrypted secret shares to each live client temp_clients = [client for client in clients] for client in temp_clients: try: communication.send_np_array(ss_ciphers_dicts_TO[client.ID], client.connection_socket) print('Forwarded encrypted secret shares to client ' + str(client.ID) + ' in private set union round 1') sys.stdout.flush() time.sleep( 5 ) #create asynchronous environments to avoid clients competing for CPU resources except Exception: clients.remove(client) print('Fallen client: ' + str(client.ID) + ' at ' + client.address[0] + ':' + str(client.address[1]) + ' in private set union round 1 (return)') sys.stdout.flush() client.connection_socket.close() del temp_clients del ss_ciphers_dicts_FROM del ss_ciphers_dicts_TO
def server_side_psu_round3(communication, clients, UNION_SERVER_STORAGE, UNION_ROUND_STORAGE, union_security_para_dict, UNION_DHKE):
    """
    Receive mask related shares from live clients, and perform unmasking.

    Sums the masked vectors 'y' of all round-2 clients, then strips the self
    masks of the clients live in U2 and the mutual masks involving clients
    that dropped between rounds 1 and 2, both via multiprocessing helpers.
    Returns the unmasked aggregate z (numpy int64 array, reduced mod r).
    """
    # Collect {live_b_shares, drop_s_shares} from every still-reachable client.
    temp_clients = [client for client in clients]
    for client in temp_clients:
        try:
            received_message = communication.get_np_array(client.connection_socket)
            assert client.ID == received_message['client_ID']
            UNION_SERVER_STORAGE[client.ID]['live_b_shares'] = received_message['live_b_shares']
            UNION_SERVER_STORAGE[client.ID]['drop_s_shares'] = received_message['drop_s_shares']
            print('Received mask related shares from client ' + str(client.ID) + ' in private set union round 3')
            sys.stdout.flush()
        except Exception:
            # Unreachable client: remove from the live set and release its socket.
            clients.remove(client)
            print('Fallen client: ' + str(client.ID) + ' at ' + client.address[0] + ':' +
                  str(client.address[1]) + ' in private set union round 3 (receive)')
            sys.stdout.flush()
            client.connection_socket.close()
    # Record up-to-date set of client indices in round 3.
    UNION_ROUND_STORAGE['U3'] = [client.ID for client in clients]
    # Clients live in round 2 but dropped in round 3: they submitted public
    # keys, encrypted secret shares and masked input, but no mask shares.
    UNION_ROUND_STORAGE['U2\U3'] = list(
        set(UNION_ROUND_STORAGE['U2']) - set(UNION_ROUND_STORAGE['U3']))
    # Did not receive mask related shares from enough clients. Abort!
    assert len(UNION_ROUND_STORAGE['U3']) >= UNION_ROUND_STORAGE['t']
    start_time_1 = time.time()
    # Compute final output z by removing self masks of live clients in U2 and
    # mutual masks with dropped clients in U1\U2.
    item_count = union_security_para_dict['item_count']
    modulo_r = union_security_para_dict['modulo_r']
    z = np.zeros(item_count, dtype='int64')
    # First, add up the masked vector 'y' of every round-2 client.
    for client_index in UNION_ROUND_STORAGE['U2']:
        z += UNION_SERVER_STORAGE[client_index]['y']
    end_time_1 = time.time()
    write_csv(UNION_ROUND_STORAGE['ps_computation_time_path'], [
        UNION_ROUND_STORAGE['communication_round_num'], "psu_U3_add_y",
        end_time_1 - start_time_1
    ])
    start_time_2 = time.time()
    # Second, reconstruct (from the b-shares) and remove the self masks of U2,
    # using asynchronous multiprocessing.
    union_b_mask_sum = server_side_psu_reconstruct_self_mask_parallel(UNION_SERVER_STORAGE, UNION_ROUND_STORAGE,\
          union_security_para_dict)
    z -= union_b_mask_sum
    end_time_2 = time.time()
    write_csv(UNION_ROUND_STORAGE['ps_computation_time_path'], [
        UNION_ROUND_STORAGE['communication_round_num'], "psu_U3_sub_b_parallel",
        end_time_2 - start_time_2
    ])
    start_time_3 = time.time()
    # Third, reconstruct (from the ssk-shares) and remove the mutual masks of
    # the dropped clients in U1\U2, using asynchronous multiprocessing.
    union_s_mask_sum = server_side_psu_reconstruct_mutual_mask_parallel(UNION_SERVER_STORAGE, UNION_ROUND_STORAGE, \
          union_security_para_dict, UNION_DHKE)
    z -= union_s_mask_sum
    z %= modulo_r
    end_time_3 = time.time()
    write_csv(UNION_ROUND_STORAGE['ps_computation_time_path'], [
        UNION_ROUND_STORAGE['communication_round_num'], "psu_U3_sub_s_parallel",
        end_time_3 - start_time_3
    ])
    write_csv(UNION_ROUND_STORAGE['ps_computation_time_path'], [UNION_ROUND_STORAGE['communication_round_num'], "psu_U3_total", \
         end_time_1 - start_time_1 + end_time_2 - start_time_2 + end_time_3 - start_time_3])
    print("Private set union round 3 in parallel costs %f s at ps" %
          (end_time_1 - start_time_1 + end_time_2 - start_time_2 + end_time_3 - start_time_3))
    sys.stdout.flush()
    return z
def client_side_psu_round1(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE, \
        union_security_para_dict, UNION_DHKE):
    """
    Generate and send encrypted secret shares for PRNG seed and ssk.

    Splits the self-mask seed b and this client's ssk into t-out-of-n Shamir
    shares, encrypts one (b-share, ssk-share) pair per peer under a pairwise
    DH-agreed AES key, sends the ciphertexts via the server, and decrypts the
    shares other clients sent to me.
    """
    start_time_1 = time.time()
    # Generate seed for PRNG.
    seed_len = union_security_para_dict['seed_len']
    # FIX: use floor division so the byte count stays an int under Python 3
    # (os.urandom rejects a float); identical behavior under Python 2.
    union_b_entropy = os.urandom(seed_len // 8)  # bytes
    union_b = bytes2int(union_b_entropy)
    t = UNION_SELF_STORAGE['t']
    n = UNION_SELF_STORAGE['n']
    # Generate t-out-of-n shares for PRNG's seed b.
    union_shares_b = SecretSharer.split_secret(union_b, t, n)
    # Generate t-out-of-n shares for client's ssk.
    union_shares_my_ssk = SecretSharer.split_secret(UNION_SELF_STORAGE['my_ssk'], t, n)
    # Store the raw entropy; round 2 reseeds the DRBG from it.
    UNION_SELF_STORAGE['b_entropy'] = union_b_entropy
    # Keep my own share of b in isolation. No need to store my share of my
    # ssk, since I am alive to myself!
    union_my_share_b = union_shares_b[0]
    union_shares_b = list(set(union_shares_b) - set([union_my_share_b]))
    UNION_SELF_STORAGE['my_share_b'] = union_my_share_b
    union_ss_ciphers_dict = {}
    for idx, client_index in enumerate(UNION_OTHERS_STORAGE.keys()):
        # Derive symmetric encryption key "agreed" with the other client via
        # Diffie-Hellman agreement (my csk x their cpk).
        sym_enc_key = UNION_DHKE.agree(UNION_SELF_STORAGE['my_csk'],
                                       UNION_OTHERS_STORAGE[client_index]['cpk'])
        # Ciphertext for the other client; the PS acts as a mediator.
        msg = str(UNION_SELF_STORAGE['my_index']) + ' || ' + str(client_index) + ' || ' + str(union_shares_b[idx]) \
              + ' || ' + str(union_shares_my_ssk[idx])
        # Encrypt with AES_CBC.
        enc_msg = AESCipher(str(sym_enc_key)).encrypt(msg)
        union_ss_ciphers_dict[client_index] = enc_msg
        UNION_OTHERS_STORAGE[client_index]['sym_enc_key'] = sym_enc_key
    end_time_1 = time.time()
    # Send encrypted shares to the server.
    union_ss_ciphers_send_message = {
        'client_ID': UNION_SELF_STORAGE['my_index'],
        'ss_ciphers': union_ss_ciphers_dict
    }
    communication.send_np_array(union_ss_ciphers_send_message, client_socket)
    print('Client %d sent encrypted secret shares to server in private set union'
          % UNION_SELF_STORAGE['my_index'])
    sys.stdout.flush()
    # Receive other clients' encrypted shares to me from the server.
    ss_ciphers_dict_received = communication.get_np_array(client_socket)
    print("Received other clients' encrypted secret shares from server.")
    sys.stdout.flush()
    start_time_2 = time.time()
    for client_index, enc_msg in ss_ciphers_dict_received.items():
        # Decrypt the encrypted message and parse it.
        sym_enc_key = UNION_OTHERS_STORAGE[client_index]['sym_enc_key']
        msg = AESCipher(str(sym_enc_key)).decrypt(enc_msg)
        msg_parts = msg.split(' || ')
        # Sanity check: sender and recipient ids must match expectations.
        from_client_index = int(msg_parts[0])
        my_index = int(msg_parts[1])
        assert from_client_index == client_index and my_index == UNION_SELF_STORAGE['my_index']
        # Store the secret shares this peer entrusted to me.
        UNION_OTHERS_STORAGE[client_index]['share_b'] = msg_parts[2]
        UNION_OTHERS_STORAGE[client_index]['share_ssk'] = msg_parts[3]
    # Clients in U1 (except myself): partners for the mutual masks in round 2.
    UNION_SELF_STORAGE['mutual_mask_client_indices'] = ss_ciphers_dict_received.keys()
    end_time_2 = time.time()
    write_csv(UNION_SELF_STORAGE['client_computation_time_path'], [UNION_SELF_STORAGE['communication_round_number'], \
        "psu_U1", end_time_1 - start_time_1 + end_time_2 - start_time_2])
def client_side_private_set_union(communication, client_socket, client_index, real_itemIDs, union_security_para_dict, union_u2_drop_flag, union_u3_drop_flag, round_num, client_computation_time_path):
    """
    Main function for client to join Private Set Union (PSU) through perturbed
    Bloom filter and Secure Aggregation.

    Runs PSU rounds 0-3 in sequence. The union_u2_drop_flag / union_u3_drop_flag
    arguments simulate this client dropping out before round 2 / round 3
    respectively, in which case an empty list is returned. Otherwise returns
    the union of real item ids received from the server.
    """
    # This dictionary will contain all the values generated by this client herself.
    UNION_SELF_STORAGE = {}
    # This dictionary will contain all the values about the OTHER clients. It is keyed by client_index.
    UNION_OTHERS_STORAGE = {}
    UNION_SELF_STORAGE['my_index'] = client_index
    UNION_SELF_STORAGE['communication_round_number'] = round_num
    UNION_SELF_STORAGE['client_computation_time_path'] = client_computation_time_path
    # ID 14 - 2048-bit MODP group for Diffie-Hellman Key Exchange.
    UNION_DHKE = DHKE(groupID=14)
    # Round 0: exchange public keys. Round 1: exchange encrypted secret shares.
    client_side_psu_round0(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE, UNION_DHKE)
    client_side_psu_round1(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE, \
        union_security_para_dict, UNION_DHKE)
    if union_u2_drop_flag:  # simulated drop before round 2
        client_socket.close()
        print("Client %d drops out from Private Set Union U2 in this communication round \n" % client_index)
        print('-----------------------------------------------------------------')
        print('')
        print('')
        sys.stdout.flush()
        return []
    # Represent real_itemIDs as a perturbed Bloom filter (the masked input "x").
    start_time = time.time()
    real_itemIDs_pbf = represent_set_as_perturbed_bloom_filter(real_itemIDs, union_security_para_dict)
    write_csv(client_computation_time_path,
              [round_num, "psu_generated_perturbed_bf", time.time() - start_time])
    UNION_SELF_STORAGE['x'] = real_itemIDs_pbf
    # Round 2: upload the doubly masked Bloom filter.
    client_side_psu_round2(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE, \
        union_security_para_dict, UNION_DHKE)
    if union_u3_drop_flag:  # simulated drop before round 3
        client_socket.close()
        print("Client %d drops out from Private Set Union U3 in this communication round \n" % client_index)
        print('-----------------------------------------------------------------')
        print('')
        print('')
        # NOTE(review): unlike the U2 branch, this branch does not flush
        # stdout before returning — confirm whether that is intentional.
        return []
    # Round 3: upload the unmasking shares.
    client_side_psu_round3(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE)
    # Receive the final union computed by the server.
    real_itemIDs_union = communication.get_np_array(client_socket)
    print('Received union of real item ids (via Private Set Union).')
    sys.stdout.flush()
    del UNION_SELF_STORAGE
    del UNION_OTHERS_STORAGE
    return real_itemIDs_union
def client_side_psu_round2(communication, client_socket, UNION_SELF_STORAGE, UNION_OTHERS_STORAGE, \ union_security_para_dict, UNION_DHKE): """ Doubly mask the input vector, i.e., perturbed Bloom filter, and send it to server. """ start_time = time.time() # Load parameters for PRNG seed_len = union_security_para_dict['seed_len'] security_strength = union_security_para_dict['security_strength'] modulo_r_len = union_security_para_dict['modulo_r_len'] modulo_r = union_security_para_dict['modulo_r'] item_count = union_security_para_dict[ 'item_count'] # length of perturbed Bloom filter (x) # Generate self mask union_b_entropy = UNION_SELF_STORAGE['b_entropy'] union_DRBG_b = HMAC_DRBG(union_b_entropy, security_strength) union_b_mask = prng(union_DRBG_b, modulo_r_len, security_strength, item_count) ''' UNION_SELF_STORAGE['b_mask'] = union_b_mask ''' # Generate mutual mask union_mutual_mask = np.zeros(item_count, dtype='int64') for client_index in UNION_SELF_STORAGE[ 'mutual_mask_client_indices']: #U1 except myself # Derive seed for mutual mask, i.e., agreed key, with other client (u, v via Diffie-Hellman Agreement) s_uv = UNION_DHKE.agree(UNION_SELF_STORAGE['my_ssk'], UNION_OTHERS_STORAGE[client_index]['spk']) s_uv_modulo = s_uv % (2**seed_len) s_uv_entropy = int2bytes(s_uv_modulo, seed_len / 8) union_DRGB_s = HMAC_DRBG(s_uv_entropy, security_strength) union_s_mask = prng(union_DRGB_s, modulo_r_len, security_strength, item_count) # Minus the mask when other client is with larger index, # or add the mask when other client is with smaller index sgn = np.sign(UNION_SELF_STORAGE['my_index'] - client_index) union_mutual_mask += sgn * union_s_mask ''' # Store mutual mask related info UNION_OTHERS_STORAGE[client_index]['s'] = s_uv UNION_OTHERS_STORAGE[client_index]['s_mask'] = union_s_mask ''' # Add self and mutual masks # Here is the final output "y" to send to server y = (UNION_SELF_STORAGE['x'] + union_b_mask + union_mutual_mask) % modulo_r data_type = 
determine_data_type(modulo_r_len) y = y.astype(data_type) ''' UNION_SELF_STORAGE['y'] = y ''' write_csv(UNION_SELF_STORAGE['client_computation_time_path'], [UNION_SELF_STORAGE['communication_round_number'], \ "psu_U2_y", time.time() - start_time]) # Send masked input to the server union_client_y = {'client_ID': UNION_SELF_STORAGE['my_index'], 'y': y} communication.send_np_array(union_client_y, client_socket) print( 'Client %d sent masked perturbed Bloom filter to server in private set union' % UNION_SELF_STORAGE['my_index']) sys.stdout.flush() # Receive Online and Offline sets of clients in round 2 round2_clients_status = communication.get_np_array(client_socket) print("Received clients' status in round 2 from server") sys.stdout.flush() UNION_SELF_STORAGE['U2'] = round2_clients_status['U2'] UNION_SELF_STORAGE['U1\U2'] = round2_clients_status['U1\U2']
def update_global_model(round_num, communication, clients, dataset_info, batches_info_dict, hyperparameters, \
        fedsubavg_security_para_dict, placeholders, update_local_vars_op, variables_pack_for_eval_and_save, \
        global_model_auc_path, ps_computation_time_path, g1, sess):
    """
    Aggregate all submodel updates in this round and apply to the global model.

    Runs both a plaintext aggregation (for debugging) and the secure federated
    submodel averaging protocol, asserts they agree, applies the averaged
    update to the global TF model, decays the learning rate, and evaluates.
    """
    # Prepare some parameters in clear.
    with g1.as_default():
        # Shapes of every trainable tensor; used to size the aggregation buffers.
        global_model_shape = [para.shape for para in sess.run(tf.trainable_variables())]
    userIDs_dict = dict()  # client_index -> userID
    global_perturbed_itemIDs = []
    global_perturbed_cateIDs = []
    # All clients share the same perturbed id lists, so copy them from the
    # first client only — presumably identical across clients; TODO confirm.
    any_client_flag = True
    for client_index, client_ids_info in batches_info_dict.items():
        userIDs_dict[client_index] = client_ids_info['userID'][0]
        if any_client_flag:
            global_perturbed_itemIDs = client_ids_info['perturbed_itemIDs']
            global_perturbed_cateIDs = client_ids_info['perturbed_cateIDs']
            any_client_flag = False
    # (Delete finally) Use plaintext protocol to aggregate submodel parameters
    # and count numbers for debugging.
    gathered_weights_dict = {}
    get_submodel_update(communication, clients, batches_info_dict, gathered_weights_dict)
    start_time = time.time()
    client_indices = [client.ID for client in clients]
    gathered_weighted_delta_submodel_plain, gathered_userIDs_count_plain, gathered_itemIDs_count_plain,\
        gathered_cateIDs_count_plain, gathered_other_count_plain = \
        gather_weighted_submodel_updates(client_indices, dataset_info, gathered_weights_dict, batches_info_dict, global_model_shape, fedsubavg_security_para_dict)
    write_csv(ps_computation_time_path, [round_num, "plain_sfsa", time.time() - start_time])
    # Secure federated submodel averaging works here!!!
    print("PS side secure federated submodel averaging starts")
    sys.stdout.flush()
    gathered_weighted_delta_submodel, gathered_userIDs_count, gathered_itemIDs_count, gathered_cateIDs_count, \
        gathered_other_count = ps_sfsa.server_side_secure_federated_submodel_averaging(communication, clients, \
        dataset_info, userIDs_dict, global_perturbed_itemIDs, global_perturbed_cateIDs, global_model_shape, \
        fedsubavg_security_para_dict, round_num, ps_computation_time_path)
    # Cross-check the secure aggregation against the plaintext results.
    assert (gathered_userIDs_count_plain == gathered_userIDs_count).all()
    assert (gathered_itemIDs_count_plain == gathered_itemIDs_count).all()
    assert (gathered_cateIDs_count_plain == gathered_cateIDs_count).all()
    assert gathered_other_count_plain == gathered_other_count
    assert_model_flag = False
    for layer, para_shape in enumerate(global_model_shape):
        if not (gathered_weighted_delta_submodel_plain[layer] == gathered_weighted_delta_submodel[layer]).all():
            print(layer, para_shape)  # report which layer disagrees
            assert_model_flag = True
    if assert_model_flag:
        print("Oh, no! Secure federated submodel averaging fails!")
        exit(-1)
    else:
        print("Yes! Secure federated submodel averaging successfully finishes!")
    sys.stdout.flush()
    with g1.as_default():
        start_time = time.time()
        # Average the updates (plain FL averaging vs. submodel averaging).
        if hyperparameters['fl_flag']:
            new_global_model = fl_average_submodel_updates_and_apply_global_update(gathered_weighted_delta_submodel, \
                gathered_userIDs_count, gathered_itemIDs_count, gathered_cateIDs_count, gathered_other_count, \
                hyperparameters, sess)
        else:
            new_global_model = average_submodel_updates_and_apply_global_update(gathered_weighted_delta_submodel, \
                gathered_userIDs_count, gathered_itemIDs_count, gathered_cateIDs_count, gathered_other_count,\
                hyperparameters, sess)
        # Update global model at parameter server.
        do_update_weights(new_global_model, placeholders, update_local_vars_op, sess)
        write_csv(ps_computation_time_path, [
            round_num, "ps avg and then update global_model",
            time.time() - start_time
        ])
        print('Round {}: Weights received, average applied '.format(round_num) +
              'among {} clients'.format(len(clients)) + ', model updated! Evaluating...')
        sys.stdout.flush()
        # Update learning rate (exponential decay per communication round).
        hyperparameters['learning_rate'] *= hyperparameters['decay_rate']
        # Evaluate global model and log AUC/loss/accuracy.
        test_auc, loss_sum, accuracy_sum = my_eval.eval_and_save(variables_pack_for_eval_and_save, round_num, sess)
        write_csv(global_model_auc_path, [round_num, test_auc, loss_sum, accuracy_sum])
        print('Global Model performance: test_auc: %.4f ---- loss: %f ---- accuracy: %f'
              % (test_auc, loss_sum, accuracy_sum))
        print('Best round: ' + str(variables_pack_for_eval_and_save['best_round']) +
              ' Best test_auc: ' + str(variables_pack_for_eval_and_save['best_auc']))
        print('')
        sys.stdout.flush()
print("P(j in S'' | j in S) = %.3f"%prob5) print("P(j in S'' | j not in S) = %.3f"%prob6) ''' ###################################################### # Phase 1: Prepare real index set # Then, represent as perturbed Bloom Filter # Finally, Participate in the union of real index sets (Through Secure Aggregation) ###################################################### temp_start_time = time.time() userID, real_itemIDs, real_train_set_size = cl_pl_fn.extract_real_index_set( client_index) gn_fn.write_csv( client_computation_time_path, [round_num, "extract real itemIDs", time.time() - temp_start_time]) print("Real train set size: %d" % real_train_set_size) ''' # (Delete finally) Use plaintext protocol for debugging. Do not affect later procedures. real_itemIDs_union_plain = cl_pl_fn.client_side_set_union(communication, client_socket, client_index, real_itemIDs,\ union_u2_drop_flag) # reply to checking "online" state cl_pl_fn.response_to_check_connection(client_socket, 1.1) # Client Side Private Set Union print("Client %d side private set union starts!"%client_index) sys.stdout.flush() real_itemIDs_union = cl_psu.client_side_private_set_union(communication, client_socket, client_index, real_itemIDs,\
def client_side_sfsa_round1(communication, client_socket, FEDSUBAVG_SELF_STORAGE, FEDSUBAVG_OTHERS_STORAGE, \
                            fedsubavg_security_para_dict, FEDSUBAVG_DHKE):
    """
    Round 1 (U1) of secure federated submodel averaging, client side.

    Generate and send encrypted secret shares for the PRNG seed b (self mask)
    and the client's secret key ssk (mutual mask). One share of b is kept
    locally; each remaining (share_b, share_ssk) pair is AES-CBC encrypted
    under a pairwise Diffie-Hellman key and relayed to one peer through the
    parameter server. The client then receives the peers' encrypted shares
    addressed to it, decrypts and stores them, and records the indices of the
    clients it must mutual-mask with.

    This could be merged with the analogous routine in private set union, but
    for clarity it is kept separate. The difference here is that the client
    also receives the indices of other clients for the mutual mask; mutual
    masks for the embedding layers of item ids and cate ids are handled in a
    ``submodel'' way.

    Side effects (no return value):
        * FEDSUBAVG_SELF_STORAGE gains 'b_entropy', 'my_share_b' and
          'mutual_mask_general_client_indices'.
        * FEDSUBAVG_OTHERS_STORAGE[idx] gains 'sym_enc_key', 'share_b' and
          'share_ssk' for every other client idx.
        * One message is sent to and one received from the server; the local
          computation time is appended to the client's CSV log.
    """
    start_time_1 = time.time()
    # Generate seed for PRNG. seed_len is in bits but os.urandom() takes a
    # byte count, so use floor division: '//' stays an int even under true
    # division (Python 3 / __future__), where '/' would yield a float and
    # make os.urandom() raise TypeError.
    seed_len = fedsubavg_security_para_dict['seed_len']
    fedsubavg_b_entropy = os.urandom(seed_len // 8)  # bytes
    fedsubavg_b = bytes2int(fedsubavg_b_entropy)
    t = FEDSUBAVG_SELF_STORAGE['t']
    n = FEDSUBAVG_SELF_STORAGE['n']
    # Generate t-out-of-n shares for PRNG's seed b
    fedsubavg_shares_b = SecretSharer.split_secret(fedsubavg_b, t, n)
    # Generate t-out-of-n shares for client's ssk
    fedsubavg_shares_my_ssk = SecretSharer.split_secret(
        FEDSUBAVG_SELF_STORAGE['my_ssk'], t, n)
    # Keep the seed's entropy so round 2 can regenerate the same self mask
    FEDSUBAVG_SELF_STORAGE['b_entropy'] = fedsubavg_b_entropy
    # Keep my own share of b in isolation. Slicing off the first share is
    # deterministic, unlike the former set-difference, which scrambled order
    # and could collapse equal-valued shares.
    # No need to store my share of my ssk, since I am alive to myself!
    FEDSUBAVG_SELF_STORAGE['my_share_b'] = fedsubavg_shares_b[0]
    fedsubavg_shares_b = fedsubavg_shares_b[1:]

    # Encrypt one (share_b, share_ssk) pair per peer under the pairwise key
    fedsubavg_ss_ciphers_dict = {}
    for idx, client_index in enumerate(
            FEDSUBAVG_OTHERS_STORAGE.keys()):  # Already except myself
        # Derive symmetric encryption key "agreed" with other client (with
        # client_index) (via Diffie-Hellman Agreement)
        sym_enc_key = FEDSUBAVG_DHKE.agree(
            FEDSUBAVG_SELF_STORAGE['my_csk'],
            FEDSUBAVG_OTHERS_STORAGE[client_index]['cpk'])
        # Send ciphertext to other client (with client_index), where PS works
        # as a mediation
        msg = str(FEDSUBAVG_SELF_STORAGE['my_index']) + ' || ' + str(client_index) + ' || ' + str(fedsubavg_shares_b[idx]) \
              + ' || ' + str(fedsubavg_shares_my_ssk[idx])
        # Encrypt with AES_CBC
        enc_msg = AESCipher(str(sym_enc_key)).encrypt(msg)
        fedsubavg_ss_ciphers_dict[client_index] = enc_msg
        # Cache the pairwise key: needed below to decrypt this peer's reply
        FEDSUBAVG_OTHERS_STORAGE[client_index]['sym_enc_key'] = sym_enc_key
    end_time_1 = time.time()

    # send encrypted shares to the server
    fedsubavg_ss_ciphers_send_message = {
        'client_ID': FEDSUBAVG_SELF_STORAGE['my_index'],
        'ss_ciphers': fedsubavg_ss_ciphers_dict
    }
    communication.send_np_array(fedsubavg_ss_ciphers_send_message, client_socket)
    print(
        'Client %d sent encrypted secret shares to server in secure federated submodel averaging'
        % FEDSUBAVG_SELF_STORAGE['my_index'])
    sys.stdout.flush()

    # Receive other clients' encrypted shares and indices for mutual mask to
    # me from the server
    round1_returned_message = communication.get_np_array(client_socket)
    print(
        "Received other clients' encrypted secret shares and indices for mutual mask from server"
    )
    sys.stdout.flush()

    start_time_2 = time.time()
    # Decrypt the secret shares and store them
    ss_ciphers_dict_received = round1_returned_message['ss_ciphers_dict']
    for client_index, enc_msg in ss_ciphers_dict_received.items():
        # Decrypt the encrypted message and parse it
        sym_enc_key = FEDSUBAVG_OTHERS_STORAGE[client_index]['sym_enc_key']
        msg = AESCipher(str(sym_enc_key)).decrypt(enc_msg)
        msg_parts = msg.split(' || ')
        # Sanity check: the pair really is from client_index and meant for me
        from_client_index = int(msg_parts[0])
        my_index = int(msg_parts[1])
        assert from_client_index == client_index and my_index == FEDSUBAVG_SELF_STORAGE[
            'my_index']
        # Store secret shares of other clients
        FEDSUBAVG_OTHERS_STORAGE[client_index]['share_b'] = msg_parts[2]
        FEDSUBAVG_OTHERS_STORAGE[client_index]['share_ssk'] = msg_parts[3]
    # Indices of other clients (except myself) for mutual mask U1\Client Self
    FEDSUBAVG_SELF_STORAGE[
        'mutual_mask_general_client_indices'] = round1_returned_message[
            'mutual_mask_general_client_indices']
    end_time_2 = time.time()
    # Log only local computation time (network waits excluded)
    write_csv(FEDSUBAVG_SELF_STORAGE['client_computation_time_path'],
              [FEDSUBAVG_SELF_STORAGE['communication_round_number'],
               "sfsa_U1", end_time_1 - start_time_1 + end_time_2 - start_time_2])
def client_side_sfsa_round2(communication, client_socket, FEDSUBAVG_SELF_STORAGE, FEDSUBAVG_OTHERS_STORAGE, \
        fedsubavg_x_dict, fedsubavg_security_para_dict, FEDSUBAVG_DHKE):
    """
    Round 2 (U2) of secure federated submodel averaging, client side.

    Doubly mask the input, including weighted_delta_submodel,
    perturbed_itemIDs_count, perturbed_cateIDs_count, and send it to server.
    Here no mask for the embedding layer for one userID and the
    train_set_size.

    Steps: (1) regenerate the self mask from the seed entropy stored in
    round 1; (2) generate the pairwise mutual mask; (3) add both masks to the
    original inputs modulo the configured rings and cast to a compact dtype;
    (4) send the masked output to the server and receive the live/dropped
    client sets U2 and U1\\U2 for round 3.

    Side effects (no return value):
        * Sends the masked y_dict to the server and receives the round-2
          client status; FEDSUBAVG_SELF_STORAGE gains 'U2' and 'U1\\U2'.
        * Appends per-phase computation times to the client's CSV log.
    """
    start_time_1 = time.time()
    # Load original "input" x_dict
    # NOTE(review): values are presumably numpy integer arrays (they support
    # %, += and .astype below) — confirm against the caller that builds x_dict.
    weighted_delta_submodel = fedsubavg_x_dict['weighted_delta_submodel']
    perturbed_userID_count = fedsubavg_x_dict['perturbed_userID_count']
    perturbed_itemIDs_count = fedsubavg_x_dict['perturbed_itemIDs_count']
    perturbed_cateIDs_count = fedsubavg_x_dict['perturbed_cateIDs_count']
    perturbed_other_count = fedsubavg_x_dict['perturbed_other_count']
    # Prepare the shape and sizes for facilitate generating masks
    fedsubavg_shapes_dict = dict()
    fedsubavg_shapes_dict['submodel_shape'] = [
        para.shape for para in weighted_delta_submodel
    ]  # list of tuples
    fedsubavg_shapes_dict['perturbed_itemIDs_size'] = len(
        perturbed_itemIDs_count)
    fedsubavg_shapes_dict['perturbed_cateIDs_size'] = len(
        perturbed_cateIDs_count)
    # First, generate self mask (PRNG re-seeded with the entropy kept from
    # round 1, so the server can cancel it with the shared secret shares)
    fedsubavg_b_entropy = FEDSUBAVG_SELF_STORAGE['b_entropy']
    fedsubavg_b_mask_dict = client_side_sfsa_generate_self_mask(
        fedsubavg_b_entropy, fedsubavg_shapes_dict,
        fedsubavg_security_para_dict)
    end_time_1 = time.time()
    write_csv(FEDSUBAVG_SELF_STORAGE['client_computation_time_path'], [FEDSUBAVG_SELF_STORAGE['communication_round_number'], \
        "sfsa_U2_b_dict", end_time_1 - start_time_1])
    sys.stdout.flush()
    print("Client %d side secure federated submodel learning self mask generation costs %f s" \
        %(FEDSUBAVG_SELF_STORAGE['my_index'], end_time_1 - start_time_1))
    start_time_2 = time.time()
    # Second, generate mutual mask (pairwise, derived via Diffie-Hellman with
    # each other live client)
    fedsubavg_s_mask_dict = client_side_sfsa_generate_mutual_mask(FEDSUBAVG_SELF_STORAGE, FEDSUBAVG_OTHERS_STORAGE, \
        fedsubavg_shapes_dict, fedsubavg_security_para_dict, FEDSUBAVG_DHKE)
    end_time_2 = time.time()
    write_csv(FEDSUBAVG_SELF_STORAGE['client_computation_time_path'], [FEDSUBAVG_SELF_STORAGE['communication_round_number'], \
        "sfsa_U2_s_dict", end_time_2 - start_time_2])
    sys.stdout.flush()
    print("Client %d side secure federated submodel learning mutual mask generation costs %f s" \
        % (FEDSUBAVG_SELF_STORAGE['my_index'], end_time_2 - start_time_2))
    start_time_3 = time.time()
    # Third, Add self and mutual masks to the original input x_dict, and
    # derive the final output "y_dict" to send to server
    # For submodel parameters: accumulate in int64, reduce mod r, then cast
    # down to the smallest dtype that holds values < 2**modulo_model_r_len
    weighted_delta_submodel_masked = [
        np.zeros(para_shape, dtype='int64')
        for para_shape in fedsubavg_shapes_dict['submodel_shape']
    ]
    modulo_model_r = fedsubavg_security_para_dict['modulo_model_r']
    modulo_model_r_len = fedsubavg_security_para_dict['modulo_model_r_len']
    model_data_type = determine_data_type(modulo_model_r_len)
    for layer, para_shape in enumerate(
            fedsubavg_shapes_dict['submodel_shape']):
        if layer == 0:
            # embedding for user id: per the contract above, no mask is
            # applied — only the modular reduction
            weighted_delta_submodel_masked[
                layer] = weighted_delta_submodel[layer] % modulo_model_r
        else:
            # Attention: Please do not forget to add original x !!!
            # y = (x + self_mask + mutual_mask) mod r, in this exact order
            weighted_delta_submodel_masked[layer] += weighted_delta_submodel[
                layer]
            weighted_delta_submodel_masked[layer] += fedsubavg_b_mask_dict[
                'weighted_delta_submodel'][layer]
            weighted_delta_submodel_masked[layer] += fedsubavg_s_mask_dict[
                'weighted_delta_submodel'][layer]
            weighted_delta_submodel_masked[layer] %= modulo_model_r
    weighted_delta_submodel_masked = [
        weights.astype(model_data_type)
        for weights in weighted_delta_submodel_masked
    ]
    # For count numbers: same y = (x + b + s) mod r scheme over a (possibly
    # different) count ring
    modulo_count_r = fedsubavg_security_para_dict['modulo_count_r']
    modulo_count_r_len = fedsubavg_security_para_dict['modulo_count_r_len']
    count_data_type = determine_data_type(modulo_count_r_len)
    perturbed_itemIDs_count_masked = perturbed_itemIDs_count + fedsubavg_b_mask_dict['perturbed_itemIDs_count'] + \
        fedsubavg_s_mask_dict['perturbed_itemIDs_count']
    perturbed_itemIDs_count_masked %= modulo_count_r
    perturbed_itemIDs_count_masked = perturbed_itemIDs_count_masked.astype(
        count_data_type)
    perturbed_cateIDs_count_masked = perturbed_cateIDs_count + fedsubavg_b_mask_dict['perturbed_cateIDs_count'] + \
        fedsubavg_s_mask_dict['perturbed_cateIDs_count']
    perturbed_cateIDs_count_masked %= modulo_count_r
    perturbed_cateIDs_count_masked = perturbed_cateIDs_count_masked.astype(
        count_data_type)
    # All outputs (userID and other counts go out unmasked, per the contract)
    y_dict = {
        'weighted_delta_submodel_masked': weighted_delta_submodel_masked,
        'perturbed_itemIDs_count_masked': perturbed_itemIDs_count_masked,
        'perturbed_cateIDs_count_masked': perturbed_cateIDs_count_masked,
        'perturbed_userID_count': perturbed_userID_count,
        'perturbed_other_count': perturbed_other_count
    }
    end_time_3 = time.time()
    write_csv(FEDSUBAVG_SELF_STORAGE['client_computation_time_path'], [FEDSUBAVG_SELF_STORAGE['communication_round_number'], \
        "sfsa_U2_add_b_s_to_y_dict", end_time_3 - start_time_3])
    write_csv(FEDSUBAVG_SELF_STORAGE['client_computation_time_path'], [FEDSUBAVG_SELF_STORAGE['communication_round_number'], \
        "sfsa_U2_y_dict_total", end_time_1 - start_time_1 + end_time_2 - start_time_2 + end_time_3 - start_time_3])
    # Send masked input to the server
    fedsubavg_client_y = {
        'client_ID': FEDSUBAVG_SELF_STORAGE['my_index'],
        'y_dict': y_dict
    }
    communication.send_np_array(fedsubavg_client_y, client_socket)
    print('Client %d sent masked submodel parameters and count numbers to server in secure federated submodel learning'\
        % FEDSUBAVG_SELF_STORAGE['my_index'])
    sys.stdout.flush()
    # Receive Online and Offline sets of clients in round 2 (consumed by
    # round 3 to decide which shares to reveal)
    round2_clients_status = communication.get_np_array(client_socket)
    print("Received clients' status in round 2 from server")
    sys.stdout.flush()
    FEDSUBAVG_SELF_STORAGE['U2'] = round2_clients_status['U2']
    FEDSUBAVG_SELF_STORAGE['U1\U2'] = round2_clients_status['U1\U2']
################################################################ ## Main Starts Here! ################################################################ round_num = 0 # test initialized model's auc test_auc, loss_sum, accuracy_sum = my_eval.eval_and_save( variables_pack_for_eval_and_save, round_num, sess) print( '-----------------------------------------------------------------------') print('Initialized model: test_auc: %.4f ---- loss: %f ---- accuracy: %f' % (test_auc, loss_sum, accuracy_sum)) print( '-----------------------------------------------------------------------') sys.stdout.flush() gn_fn.write_csv(global_model_auc_path, [round_num, test_auc, loss_sum, accuracy_sum]) for round_num in range(1, communication_rounds + 1): print('Round %d starts!' % round_num) time_start = time.time() clients = [] ps_pl_fn.send_hyperparameters(communication, clients, total_users_num, chosen_clients_num, hyperparameters, \ union_security_para_dict, fedsubavg_security_para_dict) # Update the communication socket set of online clients # In fact, for synchronization usage before your intended stage ps_pl_fn.check_connection(clients) # (Delete finally) Use plaintext protocol to compute and return union for debugging ps_pl_fn.get_compute_and_return_union(communication, clients, round_num, ps_computation_time_path) ps_pl_fn.check_connection(clients)