Beispiel #1
0
def align(input_set, party_id, endpoints, is_receiver=True):
    """
    Align the data owned by each data party via PSI.

    The receiver runs PSI against every other party in turn, narrowing
    its id set each time, and then sends the final result to the others.
    A non-receiver runs the PSI send side on the port of its own endpoint
    and then waits for the receiver to send the result back.

    :param input_set: set. The id set of input data owned by this party.
    :param party_id: int. The id of this data party, a natural number
        counted from 0.
    :param endpoints: str. The info of all data parties as a
        comma-separated list, e.g., id1:ip1:port1,id2:ip2:port2
    :param is_receiver: bool. True if this data party is the receiver.
        There is only one receiver; it obtains the aligned result and
        then sends it to the other parties.
    :return: set. The intersection of data ids.
    :raises RuntimeError: if party_id has no entry in endpoints, or the
        PSI send lib returns a non-zero error code.
    """
    parties = endpoints.split(",")
    own_idx = _find_party_idx(party_id, parties)
    if own_idx < 0:
        raise RuntimeError(
            "Could not find endpoint with id: {}".format(party_id))

    if not is_receiver:
        own_endpoint = parties[own_idx]
        own_port = int(own_endpoint.split(":")[2])
        ret_code = mdu.send_psi(own_port, input_set)
        if ret_code != 0:
            raise RuntimeError("Errors occurred in PSI send lib, "
                               "error code = {}".format(ret_code))
        # Only the receiver learns the result from PSI; it is pushed
        # back to us afterwards.
        return _recv_align_result(own_endpoint)

    # Drop our own endpoint; every remaining party is a PSI sender.
    parties.pop(own_idx)
    senders = parties

    aligned = input_set
    for endpoint in senders:
        fields = endpoint.split(":")
        # Intersect the running result with one sender at a time.
        aligned = set(mdu.recv_psi(fields[1], int(fields[2]), aligned))

    # Only the receiver can obtain the result; share it with the senders.
    _send_align_result(aligned, senders)
    return aligned
Beispiel #2
0
def one_hot_encoding_map(input_set, host_addr, is_client=True):
    """
    A protocol to get agreement between 2 parties for encoding one
    discrete feature to one hot vector via OT-PSI.

    Index layout shared by both parties (k = size of the intersection):
    intersection values get [0, k) in sorted order, server-only values get
    [k, k + server_only_count), and client-only values get the indices
    after that. Each party only learns the indices of values it owns.

    Args:
        input_set (set:str): The set of possible feature values owned by
        this party. Elements of the set must be str; convert before
        passing in.

        host_addr (str): The info of host_addr, e.g., ip:port. The same
        ip/port is used for the PSI exchange and the follow-up socket.

        is_client (bool): True if this party plays as socket client (PSI
        receiver), otherwise it plays as socket server (PSI sender).

    Return Val: dict, int.
        dict key: feature values in input_set,
        dict value: corresponding idx in one hot vector.

        int: length of one hot vector for this feature.

    Raises:
        RuntimeError: if the PSI send lib returns a non-zero error code.

    Examples:
        .. code-block:: python

            import paddle_fl.mpc.data_utils
            import sys

            is_client = sys.argv[1] == "1"

            a = set([str(x) for x in range(7)])
            b = set([str(x) for x in range(5, 10)])

            addr = "127.0.0.1:33784"

            ins = a if is_client else b

            x, y = paddle_fl.mpc.data_utils.one_hot_encoding_map(ins, addr, is_client)
            # y = 10
            # x['5'] = 0, x['6'] = 1
            # for those feature vals owned only by one party, dict vals
            # shall not be conflicting.
            print(x, y)
    """
    ip = host_addr.split(":")[0]
    port = int(host_addr.split(":")[1])

    if is_client:
        # Only the PSI receiver (client) obtains the intersection; it is
        # forwarded to the server over the connection opened below.
        intersection = sorted(mdu.recv_psi(ip, port, input_set))
    else:
        ret_code = mdu.send_psi(port, input_set)
        if ret_code != 0:
            raise RuntimeError("Errors occurred in PSI send lib, "
                               "error code = {}".format(ret_code))

    # NOTE(review): the client connects as soon as its PSI call returns;
    # if the server has not created its Listener yet, the connect may be
    # refused. Confirm the PSI lib's timing guarantees.
    server = None if is_client else Listener((ip, port))
    try:
        conn = Client((ip, port)) if is_client else server.accept()
        try:
            if is_client:
                conn.send(intersection)
                diff_size_local = len(input_set) - len(intersection)
                conn.send(diff_size_local)
                diff_size_remote = conn.recv()
            else:
                intersection = conn.recv()
                diff_size_local = len(input_set) - len(intersection)
                # Receive before sending, mirroring the client's order.
                diff_size_remote = conn.recv()
                conn.send(diff_size_local)
        finally:
            # Always release the socket, even if the exchange fails.
            conn.close()
    finally:
        if server is not None:
            server.close()

    # Shared indices: intersection values in sorted order.
    ret = {x: idx for idx, x in enumerate(intersection)}
    cnt = len(intersection)
    if is_client:
        # Server-only values occupy the next diff_size_remote slots, so
        # client-only values start after them and never collide.
        cnt += diff_size_remote

    # Set lookup instead of a linear scan of the intersection list.
    shared = set(intersection)
    for x in input_set:
        if x not in shared:
            ret[x] = cnt
            cnt += 1

    return ret, len(intersection) + diff_size_local + diff_size_remote