def test_signaling_answer_message_serde(node: sy.VirtualMachine) -> None:
    """Round-trip a SignalingAnswerMessage through serialize/deserialize.

    Verifies that the id, payload, peer ids and host metadata of the restored
    message match the values the original message was built with, and that
    the two messages compare equal.
    """
    destination = Address(name="Alice")
    target_peer_id = secrets.token_hex(nbytes=16)
    host_peer_id = secrets.token_hex(nbytes=16)

    original = SignalingAnswerMessage(
        address=destination,
        payload="SDP",
        host_metadata=node.get_metadata_for_client(),
        target_peer=target_peer_id,
        host_peer=host_peer_id,
    )
    # Fetched independently of the message so the restored metadata is
    # compared against what the node itself reports.
    expected_metadata = node.get_metadata_for_client()

    restored = sy.deserialize(blob=serialize(original))
    restored_metadata = restored.host_metadata

    assert original.id == restored.id
    assert original.address == destination
    assert original.payload == restored.payload
    assert restored.payload == "SDP"
    assert restored.host_peer == host_peer_id
    assert restored.target_peer == target_peer_id
    assert original == restored

    for field in ("name", "node", "id"):
        assert getattr(expected_metadata, field) == getattr(restored_metadata, field)
# Example #2
# 0
def test_gc_simple_strategy(node: sy.VirtualMachine) -> None:
    """Check that the store empties once the only pointer is deleted.

    One tensor is sent to the node; the store must hold one object while the
    pointer is alive and none after ``del``.
    """
    vm_client = node.get_client()

    tensor = torch.tensor([1, 2, 3, 4])
    pointer = tensor.send(vm_client, pointable=False)
    assert len(node.store) == 1

    # Dropping the last reference to the pointer is what the test exercises.
    del pointer
    assert len(node.store) == 0
# Example #3
# 0
def test_gc_change_default_gc_strategy(node: sy.VirtualMachine) -> None:
    """Check that a client created after changing the default uses GCBatched.

    The default GC strategy is module-level state, so the revert runs in a
    ``finally`` block: if creating or inspecting the client raises, the
    "gcbatched" default must not leak into tests that run afterwards.
    """
    gc_prev_strategy = gc_get_default_strategy()
    gc_set_default_strategy("gcbatched")

    try:
        client = node.get_client()
        res = isinstance(client.gc.gc_strategy, GCBatched)
    finally:
        # Revert
        gc_set_default_strategy(gc_prev_strategy)
        # NOTE(review): this assignment overrides whatever the call above
        # just restored with the GCSimple class -- confirm it is still needed.
        sy.core.pointer.garbage_collection.GC_DEFAULT_STRATEGY = GCSimple

    # Asserted only after the revert so a failure cannot skip the cleanup.
    assert res
# Example #4
# 0
def test_gc_batched_strategy_gc_constructor(node: sy.VirtualMachine) -> None:
    """Exercise a "gcbatched" GarbageCollection built via its constructor.

    Four sends whose pointers are discarded must leave four objects in the
    store; after the fifth send the store must be empty.
    """
    vm_client = node.get_client()
    vm_client.gc = GarbageCollection("gcbatched", 5)

    tensor = torch.tensor([1, 2, 3, 4])

    # The returned pointers are never bound, so each one is dropped at once.
    initial_sends = 4
    for _ in range(initial_sends):
        tensor.send(vm_client, pointable=False)
    assert len(node.store) == initial_sends

    # One more send/drop is expected to clear everything from the store.
    tensor.send(vm_client, pointable=False)
    assert len(node.store) == 0
# Example #5
# 0
def test_gc_batched_strategy_setter(node: sy.VirtualMachine) -> None:
    """Exercise GCBatched installed through the ``gc_strategy`` setter.

    With ``threshold=10``, nine sends whose pointers are discarded must leave
    nine objects in the store; after the tenth send the store must be empty.
    """
    threshold = 10
    vm_client = node.get_client()
    vm_client.gc.gc_strategy = GCBatched(threshold=threshold)

    tensor = torch.tensor([1, 2, 3, 4])

    # Stay one send short of the threshold; pointers are dropped immediately.
    below_threshold = threshold - 1
    for _ in range(below_threshold):
        tensor.send(vm_client, pointable=False)
    assert len(node.store) == below_threshold

    # The send that reaches the threshold is expected to empty the store.
    tensor.send(vm_client, pointable=False)
    assert len(node.store) == 0
# Example #6
# 0
def test_gc_batched_delete_at_change(node: sy.VirtualMachine) -> None:
    """Check that replacing GCBatched with GCSimple empties the store.

    Three objects are sent under a default-constructed GCBatched and must
    still be in the store; after the strategy is swapped to GCSimple the
    store must be empty.
    """
    vm_client = node.get_client()

    # Switch to the batched strategy before sending anything.
    vm_client.gc.gc_strategy = GCBatched()

    tensor = torch.tensor([1, 2, 3, 4])
    pending = 3
    for _ in range(pending):
        tensor.send(vm_client, pointable=False)
    assert len(node.store) == pending

    # Changing the strategy should force GCBatched to delete all the
    # objects it had cached for deletion.
    vm_client.gc.gc_strategy = GCSimple()
    assert len(node.store) == 0
def test_child_node_lifecycle_message_serde(
        node: sy.VirtualMachine, client: sy.VirtualMachineClient) -> None:
    """Round-trip a RegisterChildNodeMessage through serialize/deserialize.

    The restored message must match the original on id, address and
    child_node_client_address, and compare equal as a whole.
    """
    recipient = node.get_client()

    # This is the message that
    # bob_phone_client.register(client=bob_vm_client) generates.
    original = RegisterChildNodeMessage(
        lookup_id=client.id,  # TODO: not sure if this is needed anymore
        child_node_client_address=client.address,
        address=recipient.address,
    )

    restored = sy.deserialize(blob=serialize(original))

    for attr in ("id", "address", "child_node_client_address"):
        assert getattr(original, attr) == getattr(restored, attr)
    assert original == restored
# Example #8
# 0
def test_psi(loadlib_before_client: bool, reveal_intersection: bool,
             node: sy.VirtualMachine) -> None:
    """Run a PSI exchange with every message passed through the node's store.

    Two root clients of the same ``node`` play the "server" and "client"
    roles. Each protocol value (reveal flag, false-positive rate, item count,
    setup message, request, response) is sent with a tag, fetched back by tag,
    and compared with what was sent. The final intersection is then checked.

    Args:
        loadlib_before_client: if true, ``sy.load("openmined_psi")`` is called
            before the clients are created; otherwise after.
        reveal_intersection: passed to ``CreateWithNewKey`` on both sides;
            selects whether the intersection members or only its size are
            checked at the end.
        node: the virtual machine whose root clients and store are used.
    """
    # third party
    import openmined_psi as psi

    # Loading the library must work both before and after the clients are
    # created.
    if loadlib_before_client:
        sy.load("openmined_psi")
        server_vm = node.get_root_client()
        client_vm = node.get_root_client()
    else:
        server_vm = node.get_root_client()
        client_vm = node.get_root_client()
        sy.load("openmined_psi")

    # server sends reveal_intersection
    s_reveal_intersection = reveal_intersection
    s_sy_reveal_intersection = sy.lib.python.Bool(s_reveal_intersection)
    s_sy_reveal_intersection.send(
        server_vm,
        pointable=True,
        tags=["reveal_intersection"],
        description="reveal intersection value",
    )
    assert (server_vm.store["reveal_intersection"].description ==
            "reveal intersection value")

    # client gets reveal_intersection
    c_reveal_intersection = server_vm.store["reveal_intersection"].get()
    assert c_reveal_intersection == s_reveal_intersection

    # server sends fpr (false positive rate)
    s_fpr = 1e-6
    s_sy_fpr = sy.lib.python.Float(s_fpr)
    s_sy_fpr.send(server_vm,
                  pointable=True,
                  tags=["fpr"],
                  description="false positive rate")

    # client gets fpr; compared with approx() because it is a float
    c_fpr = server_vm.store["fpr"].get()
    assert c_fpr == approx(s_fpr)

    # client sends client_items_len
    psi_client = psi.client.CreateWithNewKey(c_reveal_intersection)
    c_items = ["Element " + str(i) for i in range(1000)]
    c_sy_client_items_len = sy.lib.python.Int(len(c_items))
    c_sy_client_items_len.send(
        client_vm,
        pointable=True,
        tags=["client_items_len"],
        description="client items length",
    )

    # server gets client_items_len
    # NOTE(review): delete_obj=False presumably keeps the object in the
    # store after the get -- confirm against the pointer API.
    s_sy_client_items_len = client_vm.store["client_items_len"].get(
        delete_obj=False)
    assert s_sy_client_items_len == c_sy_client_items_len

    # server sends setup message; its items are "Element 0", "Element 2", ...
    s_items = ["Element " + str(2 * i) for i in range(1000)]
    psi_server = psi.server.CreateWithNewKey(s_reveal_intersection)
    s_setup = psi_server.CreateSetupMessage(s_fpr, s_sy_client_items_len,
                                            s_items)
    s_setup.send(
        server_vm,
        pointable=True,
        tags=["setup"],
        description="psi.server Setup Message",
    )
    assert server_vm.store["setup"].description == "psi.server Setup Message"

    # client gets setup message
    c_setup = server_vm.store["setup"].get()
    assert c_setup == s_setup

    # client sends request
    c_request = psi_client.CreateRequest(c_items)
    c_request.send(client_vm,
                   tags=["request"],
                   pointable=True,
                   description="client request")

    # server gets request
    s_request = client_vm.store["request"].get()
    assert s_request == c_request

    # server sends response
    s_response = psi_server.ProcessRequest(s_request)
    s_response.send(server_vm,
                    pointable=True,
                    tags=["response"],
                    description="psi.server response")

    # client gets response
    c_response = server_vm.store["response"].get()
    assert c_response == s_response

    # client computes the result
    if c_reveal_intersection:
        # The result is checked as a collection of indices into c_items.
        # The server holds only the even-numbered elements, so exactly the
        # even indices of c_items must be reported.
        intersection = psi_client.GetIntersection(c_setup, c_response)
        iset = set(intersection)
        for idx in range(len(c_items)):
            if idx % 2 == 0:
                assert idx in iset
            else:
                assert idx not in iset
    else:
        # Half of the client items are shared. The upper bound allows 10%
        # slack -- presumably for false positives; TODO confirm.
        intersection = psi_client.GetIntersectionSize(c_setup, c_response)
        assert intersection >= (len(c_items) / 2.0)
        assert intersection <= (1.1 * len(c_items) / 2.0)
# Example #9
# 0
def root_client(node: sy.VirtualMachine) -> sy.VirtualMachineClient:
    """Return the client produced by ``node.get_root_client()``."""
    vm_root_client = node.get_root_client()
    return vm_root_client
# Example #10
# 0
def test_to_string(node: sy.VirtualMachine) -> None:
    """Check that ``str()`` and ``__repr__()`` both render the node's id.

    Both must equal ``"VirtualMachine: Bob: <node.id>"``.
    """
    rendered_str = str(node)
    assert rendered_str == f"VirtualMachine: Bob: {node.id}"

    rendered_repr = node.__repr__()
    assert rendered_repr == f"VirtualMachine: Bob: {node.id}"