def test_signaling_answer_message_serde(node: sy.VirtualMachine) -> None:
    """Check that a SignalingAnswerMessage survives a serialize/deserialize round trip.

    The deserialized copy must keep the id, address, payload, target/host
    peer ids and host metadata of the message that was serialized.
    """
    target = Address(name="Alice")
    target_id = secrets.token_hex(nbytes=16)
    host_id = secrets.token_hex(nbytes=16)
    # Fetch the metadata once and use the same object both to build the
    # message and as the reference for comparison, so the check compares the
    # round-tripped metadata against what was actually serialized.
    msg_metadata = node.get_metadata_for_client()

    msg = SignalingAnswerMessage(
        address=target,
        payload="SDP",
        host_metadata=msg_metadata,
        target_peer=target_id,
        host_peer=host_id,
    )

    blob = serialize(msg)
    msg2 = sy.deserialize(blob=blob)
    msg2_metadata = msg2.host_metadata

    assert msg.id == msg2.id
    assert msg.address == target
    # The original assertion only inspected ``msg``; the deserialized copy must
    # carry the same address too.
    assert msg2.address == target
    assert msg.payload == msg2.payload
    assert msg2.payload == "SDP"
    assert msg2.host_peer == host_id
    assert msg2.target_peer == target_id
    assert msg == msg2

    assert msg_metadata.name == msg2_metadata.name
    assert msg_metadata.node == msg2_metadata.node
    assert msg_metadata.id == msg2_metadata.id
def test_gc_simple_strategy(node: sy.VirtualMachine) -> None:
    """Check that deleting a pointer removes its remote object from the store right away.

    Uses the client's default GC strategy.
    NOTE(review): presumably GCSimple; confirm against the module default.
    """
    vm_client = node.get_client()
    payload = torch.tensor([1, 2, 3, 4])

    remote = payload.send(vm_client, pointable=False)
    assert len(node.store) == 1

    # Dropping the only reference to the pointer should trigger deletion.
    del remote
    assert len(node.store) == 0
def test_gc_change_default_gc_strategy(node: sy.VirtualMachine) -> None:
    """Check that switching the default GC strategy applies to newly created clients.

    The default strategy is module-global state shared by every test. It is
    therefore restored in a ``finally`` block: if client creation or the
    isinstance check raised, "gcbatched" would otherwise remain the default
    and break unrelated GC tests that run afterwards.
    """
    gc_prev_strategy = gc_get_default_strategy()
    gc_set_default_strategy("gcbatched")
    try:
        client = node.get_client()
        res = isinstance(client.gc.gc_strategy, GCBatched)
    finally:
        # Revert
        gc_set_default_strategy(gc_prev_strategy)
        # NOTE(review): this unconditionally overwrites whatever the previous
        # line restored with GCSimple. Kept from the original; confirm whether
        # restoring ``gc_prev_strategy`` alone is sufficient.
        sy.core.pointer.garbage_collection.GC_DEFAULT_STRATEGY = GCSimple

    assert res
def test_gc_batched_strategy_gc_constructor(node: sy.VirtualMachine) -> None:
    """Check a GCBatched collector built via GarbageCollection.

    The batch flushes only once enough deletions are pending. Each send's
    returned pointer is discarded immediately, so each becomes a pending
    deletion. After four sends the objects are still stored; the fifth send
    empties the store.
    """
    vm_client = node.get_client()
    # NOTE(review): 5 is presumably the batch threshold (cf. GCBatched(threshold=...)).
    vm_client.gc = GarbageCollection("gcbatched", 5)
    payload = torch.tensor([1, 2, 3, 4])

    sent = 0
    while sent < 4:
        payload.send(vm_client, pointable=False)
        sent += 1
    assert len(node.store) == 4

    # This send reaches the batch size and flushes every pending deletion.
    payload.send(vm_client, pointable=False)
    assert len(node.store) == 0
def test_gc_batched_strategy_setter(node: sy.VirtualMachine) -> None:
    """Check a GCBatched strategy assigned through the ``gc_strategy`` setter.

    With ``threshold=10``, nine discarded pointers leave their objects in the
    store; the tenth triggers the batched deletion of all of them.
    """
    vm_client = node.get_client()
    vm_client.gc.gc_strategy = GCBatched(threshold=10)
    payload = torch.tensor([1, 2, 3, 4])

    remaining = 9
    while remaining:
        payload.send(vm_client, pointable=False)
        remaining -= 1
    assert len(node.store) == 9

    # Tenth pending deletion hits the threshold.
    payload.send(vm_client, pointable=False)
    assert len(node.store) == 0
def test_gc_batched_delete_at_change(node: sy.VirtualMachine) -> None:
    """Check that replacing a GCBatched strategy flushes its pending deletions.

    Three discarded pointers are queued under a default GCBatched. Swapping in
    GCSimple must delete all three queued objects from the store.
    """
    vm_client = node.get_client()

    # Change the strategy
    vm_client.gc.gc_strategy = GCBatched()

    payload = torch.tensor([1, 2, 3, 4])
    for _attempt in range(3):
        payload.send(vm_client, pointable=False)
    assert len(node.store) == 3

    # Switching away from GCBatched should make it delete every object it had
    # queued for deletion.
    vm_client.gc.gc_strategy = GCSimple()
    assert len(node.store) == 0
def test_child_node_lifecycle_message_serde(
    node: sy.VirtualMachine, client: sy.VirtualMachineClient
) -> None:
    """Check that a RegisterChildNodeMessage round-trips through serialization unchanged."""
    other_client = node.get_client()

    # This is the message produced by e.g.
    # bob_phone_client.register(client=bob_vm_client)
    original = RegisterChildNodeMessage(
        lookup_id=client.id,  # TODO: not sure if this is needed anymore
        child_node_client_address=client.address,
        address=other_client.address,
    )

    restored = sy.deserialize(blob=serialize(original))

    assert original.id == restored.id
    assert original.address == restored.address
    assert original.child_node_client_address == restored.child_node_client_address
    assert original == restored
def test_psi(
    loadlib_before_client: bool, reveal_intersection: bool, node: sy.VirtualMachine
) -> None:
    """Run a full OpenMined PSI exchange between two root clients of ``node``.

    Parameters, the PSI setup message, the request and the response are all
    exchanged through the VM store as tagged, pointable objects. The final
    result is checked against the known overlap of the two item sets: the
    client holds "Element 0".."Element 999" and the server holds the even
    elements "Element 0".."Element 1998", so exactly the even client indices
    intersect.

    Args:
        loadlib_before_client: whether ``sy.load("openmined_psi")`` runs before
            or after the root clients are created; both orders must work.
        reveal_intersection: if true, the client computes the intersection
            itself (``GetIntersection``); otherwise, only its size
            (``GetIntersectionSize``).
        node: the VM hosting both root clients.
    """
    # third party
    import openmined_psi as psi

    # It should work whether the library is loaded before or after the
    # clients are created.
    if loadlib_before_client:
        sy.load("openmined_psi")
    server_vm = node.get_root_client()
    client_vm = node.get_root_client()
    if not loadlib_before_client:
        sy.load("openmined_psi")

    # server send reveal_intersection
    s_reveal_intersection = reveal_intersection
    s_sy_reveal_intersection = sy.lib.python.Bool(s_reveal_intersection)
    s_sy_reveal_intersection.send(
        server_vm,
        pointable=True,
        tags=["reveal_intersection"],
        description="reveal intersection value",
    )
    assert (
        server_vm.store["reveal_intersection"].description
        == "reveal intersection value"
    )

    # client get reveal_intersection
    c_reveal_intersection = server_vm.store["reveal_intersection"].get()
    assert c_reveal_intersection == s_reveal_intersection

    # server send fpr
    s_fpr = 1e-6
    s_sy_fpr = sy.lib.python.Float(s_fpr)
    s_sy_fpr.send(
        server_vm, pointable=True, tags=["fpr"], description="false positive rate"
    )

    # client get fpr
    c_fpr = server_vm.store["fpr"].get()
    assert c_fpr == approx(s_fpr)

    # client send client_items_len
    psi_client = psi.client.CreateWithNewKey(c_reveal_intersection)
    c_items = ["Element " + str(i) for i in range(1000)]
    c_sy_client_items_len = sy.lib.python.Int(len(c_items))
    c_sy_client_items_len.send(
        client_vm,
        pointable=True,
        tags=["client_items_len"],
        description="client items length",
    )

    # server get client_items_len
    s_sy_client_items_len = client_vm.store["client_items_len"].get(
        delete_obj=False
    )
    assert s_sy_client_items_len == c_sy_client_items_len

    # server send setup message
    s_items = ["Element " + str(2 * i) for i in range(1000)]
    psi_server = psi.server.CreateWithNewKey(s_reveal_intersection)
    s_setup = psi_server.CreateSetupMessage(s_fpr, s_sy_client_items_len, s_items)
    # The description must be a single-line literal: the assertion below
    # compares it to exactly "psi.server Setup Message".
    s_setup.send(
        server_vm,
        pointable=True,
        tags=["setup"],
        description="psi.server Setup Message",
    )
    assert server_vm.store["setup"].description == "psi.server Setup Message"

    # client get setup message
    c_setup = server_vm.store["setup"].get()
    assert c_setup == s_setup

    # client send request
    c_request = psi_client.CreateRequest(c_items)
    c_request.send(
        client_vm, tags=["request"], pointable=True, description="client request"
    )

    # server get request
    s_request = client_vm.store["request"].get()
    assert s_request == c_request

    # server send response
    s_response = psi_server.ProcessRequest(s_request)
    s_response.send(
        server_vm, pointable=True, tags=["response"], description="psi.server response"
    )

    # client get response
    c_response = server_vm.store["response"].get()
    assert c_response == s_response

    # client get result
    if c_reveal_intersection:
        intersection = psi_client.GetIntersection(c_setup, c_response)
        iset = set(intersection)
        # Only the even client indices correspond to server items.
        for idx in range(len(c_items)):
            if idx % 2 == 0:
                assert idx in iset
            else:
                assert idx not in iset
    else:
        intersection = psi_client.GetIntersectionSize(c_setup, c_response)
        # Exact overlap is half the client items; allow 10% slack above it
        # for false positives.
        assert intersection >= (len(c_items) / 2.0)
        assert intersection <= (1.1 * len(c_items) / 2.0)
def root_client(node: sy.VirtualMachine) -> sy.VirtualMachineClient:
    """Return a root client for ``node`` (via ``node.get_root_client()``)."""
    vm_root_client = node.get_root_client()
    return vm_root_client
def test_to_string(node: sy.VirtualMachine) -> None:
    """Check that str() and repr() of the VM both read 'VirtualMachine: Bob: <id>'."""
    rendered = f"VirtualMachine: Bob: {node.id}"
    assert str(node) == rendered
    assert repr(node) == rendered