def align(input_set, party_id, endpoints, is_receiver=True):
    """
    Align the data owned by each data party.

    Exactly one party acts as the receiver: it runs PSI against every other
    party in turn, then sends the final intersection back to all of them.
    Every other party runs the PSI send side on its own port and waits for
    the receiver to deliver the result.

    :param input_set: set. The id set of input data owned by this party.
    :param party_id: int. The id of this data party, which is natural number named from 0.
    :param endpoints: str. The info of all data parties, e.g., id1:ip1:port1,id2:ip2:port2
    :param is_receiver: bool. True if this data party is a receiver role among all parties.
        Note that there is only one receiver who can obtain the result of aligning and then
        send it to other parties.
    :return: set. The intersection of data id.
    :raises RuntimeError: if ``party_id`` has no entry in ``endpoints``, or if
        the PSI send library reports a non-zero error code.
    """
    parties = endpoints.split(",")
    own_idx = _find_party_idx(party_id, parties)
    if own_idx < 0:
        raise RuntimeError(
            "Could not find endpoint with id: {}".format(party_id))

    if not is_receiver:
        # Sender role: serve PSI on our own port, then wait for the receiver
        # to push the final intersection.
        own_endpoint = parties[own_idx]
        ret_code = mdu.send_psi(int(own_endpoint.split(":")[2]), input_set)
        if ret_code != 0:
            raise RuntimeError("Errors occurred in PSI send lib, "
                               "error code = {}".format(ret_code))
        return _recv_align_result(own_endpoint)

    # Receiver role: everyone except ourselves is a PSI sender.
    parties.pop(own_idx)
    aligned = input_set
    for endpoint in parties:
        fields = endpoint.split(":")
        # Each round narrows the running intersection against one sender.
        aligned = mdu.recv_psi(fields[1], int(fields[2]), aligned)
    aligned = set(aligned)

    # Only the receiver learns the result, so it distributes it to the senders.
    _send_align_result(aligned, parties)
    return aligned
def one_hot_encoding_map(input_set, host_addr, is_client=True):
    """
    A protocol to get agreement between 2 parties for encoding one
    discrete feature to one hot vector via OT-PSI.

    Args:
        input_set (set:str): The set of possible feature value owned by this
            party. Element of set is str, convert before pass in.
        host_addr (str): The info of host_addr, e.g., ip:port
        is_client (bool): True if this party plays as socket client (and
            PSI receiver), otherwise plays as socket server (and PSI sender).

    Return Val:
        dict, int.
        dict key: feature values in input_set,
        dict value: corresponding idx in one hot vector.
        int: length of one hot vector for this feature.

        Both parties use the same index layout:
        [shared values | server-only values | client-only values].
        Shared values are numbered in sorted order. Each party numbers its
        own local-only values in sorted order, so the result does not depend
        on set iteration order and is reproducible across runs.

    Raises:
        RuntimeError: the PSI send lib returned a non-zero error code
            (server side).
        ConnectionRefusedError: (client side) the server did not start
            listening within the connect timeout.

    Examples:
        .. code-block:: python

            import paddle_fl.mpc.data_utils
            import sys
            is_client = sys.argv[1] == "1"

            a = set([str(x) for x in range(7)])
            b = set([str(x) for x in range(5, 10)])

            addr = "127.0.0.1:33784"

            ins = a if is_client else b

            x, y = paddle_fl.mpc.data_utils.one_hot_encoding_map(ins, addr, is_client)
            # y = 10
            # x['5'] = 0, x['6'] = 1
            # for those feature val owned only by one party, dict val shall not be conflicting.
            print(x, y)
    """
    import time

    # The server starts listening only after its PSI send returns, so the
    # client may reach the connect step first. It retries "connection refused"
    # for up to this many seconds instead of failing immediately.
    connect_timeout = 30.0
    retry_interval = 0.1

    addr_fields = host_addr.split(":")
    ip = addr_fields[0]
    port = int(addr_fields[1])

    # Step 1: OT-PSI. Only the client (PSI receiver) learns the intersection.
    if is_client:
        # Sorted so both parties number the shared values identically once
        # the client hands this list to the server below.
        intersection = sorted(mdu.recv_psi(ip, port, input_set))
    else:
        ret_code = mdu.send_psi(port, input_set)
        if ret_code != 0:
            raise RuntimeError("Errors occurred in PSI send lib, "
                               "error code = {}".format(ret_code))

    # Step 2: the client shares the intersection. Both sides exchange how
    # many values they hold that the other party lacks. try/finally releases
    # the socket(s) even if the exchange fails midway.
    server = None
    conn = None
    try:
        if is_client:
            deadline = time.monotonic() + connect_timeout
            while True:
                try:
                    conn = Client((ip, port))
                    break
                except ConnectionRefusedError:
                    if time.monotonic() >= deadline:
                        raise
                    time.sleep(retry_interval)
            conn.send(intersection)
            diff_size_local = len(input_set) - len(intersection)
            conn.send(diff_size_local)
            diff_size_remote = conn.recv()
        else:
            server = Listener((ip, port))
            conn = server.accept()
            # NOTE(review): Listener/Client are presumably
            # multiprocessing.connection, whose recv() unpickles peer data.
            # Only use with a trusted counterpart.
            intersection = conn.recv()
            diff_size_local = len(input_set) - len(intersection)
            diff_size_remote = conn.recv()
            conn.send(diff_size_local)
    finally:
        if conn is not None:
            conn.close()
        if server is not None:
            server.close()

    # Step 3: assign indices.
    # Shared values take slots 0..len(intersection)-1 in the agreed order.
    # The server numbers its local-only values directly after them. The
    # client first skips the server's block (diff_size_remote), so no index
    # is given to two different values.
    ret = {value: idx for idx, value in enumerate(intersection)}
    cnt = len(intersection)
    if is_client:
        cnt += diff_size_remote
    shared = set(intersection)  # O(1) membership instead of list scans
    # Sorted rather than raw set order (which depends on the hash seed), so
    # the mapping is stable from run to run.
    for value in sorted(v for v in input_set if v not in shared):
        ret[value] = cnt
        cnt += 1

    return ret, len(intersection) + diff_size_local + diff_size_remote