def clients():
    """Create the virtual machines "alice" and "bob" and return their clients.

    Returns:
        A two-element list ``[alice_client, bob_client]`` where each entry
        comes from ``get_client()`` on the corresponding machine.
    """
    # Both machines are built before any client is requested, and the list
    # of machines stays referenced until the function returns.
    machines = [sy.VirtualMachine(name=label) for label in ("alice", "bob")]
    return [machine.get_client() for machine in machines]
def get_preset_nodes() -> Tuple[Node, Node, Node]:
    """Build an "OpenMined" network with signaling services plus two VMs.

    Returns:
        A tuple ``(network, bob_vm, alice_vm)``.
    """
    network = sy.Network(name="OpenMined")

    # Attach the signaling-related services to the network's service lists.
    network.immediate_services_without_reply.append(PushSignalingService)
    for reply_service in (PullSignalingService, RegisterDuetPeerService):
        network.immediate_services_with_reply.append(reply_service)

    # re-register all services including SignalingService
    network._register_services()

    bob_node = sy.VirtualMachine(name="Bob")
    alice_node = sy.VirtualMachine(name="Alice")
    return network, bob_node, alice_node
def test_load_sympc() -> None:
    """Add a plain tensor to an MPCTensor shared by two VMs and reconstruct it."""
    alice_vm = sy.VirtualMachine()
    alice_root = alice_vm.get_root_client()
    bob_vm = sy.VirtualMachine()
    bob_root = bob_vm.get_root_client()

    # Set up an MPC session whose parties are the two root clients.
    mpc_session = Session(parties=[alice_root, bob_root])
    SessionManager.setup_mpc(mpc_session)

    plain_values = th.Tensor([-5, 0, 1, 2, 3])
    secret_value = th.Tensor([30])
    shared = MPCTensor(secret=secret_value, shape=(1,), session=mpc_session)

    reconstructed = (shared + plain_values).reconstruct()
    expected = th.Tensor([25.0, 30.0, 31.0, 32.0, 33.0])
    assert (reconstructed == expected).all()
def test_torch_no_read_permissions() -> None:
    """A guest client's ``get()`` must raise; the root client's must succeed."""
    bob = sy.VirtualMachine(name="bob")
    root_client = bob.get_root_client()
    guest_client = bob.get_client()

    # Root user of Bob's machine sends a tensor.
    tensor = th.tensor([1, 2, 3, 4])
    ptr = tensor.send(root_client)

    # Simulate a guest holding a pointer to that object (as if the guest had
    # guessed or inferred its ID) by swapping the pointer's client.
    ptr.client = guest_client

    # Fetching through the guest client is expected to be refused.
    with pytest.raises(AuthorizationException):
        ptr.get()

    # Start over: root sends a tensor again. The same names are deliberately
    # rebound so the previous pointer loses its last reference here.
    tensor = th.tensor([1, 2, 3, 4])
    ptr = tensor.send(root_client)

    # When root asks for the object, it should come back intact.
    fetched = ptr.get()
    assert (tensor == fetched).all()
    assert tensor.grad == fetched.grad
def test_signaling_answer_message_serde() -> None:
    """Round-trip a SignalingAnswerMessage through serialize/deserialize."""
    host_vm = sy.VirtualMachine(name="Bob")
    destination = Address(name="Alice")
    target_peer_id = secrets.token_hex(nbytes=16)
    host_peer_id = secrets.token_hex(nbytes=16)

    original = SignalingAnswerMessage(
        address=destination,
        payload="SDP",
        host_metadata=host_vm.get_metadata_for_client(),
        target_peer=target_peer_id,
        host_peer=host_peer_id,
    )
    expected_metadata = host_vm.get_metadata_for_client()

    serialized = original.serialize()
    restored = sy.deserialize(blob=serialized)
    restored_metadata = restored.host_metadata

    # Message-level fields survive the round trip.
    assert original.id == restored.id
    assert original.address == destination
    assert original.payload == restored.payload
    assert restored.payload == "SDP"
    assert restored.host_peer == host_peer_id
    assert restored.target_peer == target_peer_id
    assert original == restored

    # Host metadata fields survive the round trip as well.
    assert expected_metadata.name == restored_metadata.name
    assert expected_metadata.node == restored_metadata.node
    assert expected_metadata.id == restored_metadata.id
def test_same_var_for_ptr_gc() -> None:
    """
    Verify that garbage collection is triggered correctly when the last
    reference to a pointer is overwritten.
    """
    tensor = torch.tensor([1, 2, 3, 4])

    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    # Rebind the same name many times to make sure the gc is triggered:
    # every rebinding drops the only reference to the previous pointer.
    for _ in range(100):
        ptr = tensor.send(alice_client)

    gc.collect()
    # The test expects a single object to be left in the store.
    assert len(alice.store) == 1

    ptr.get()
    gc.collect()
    assert len(alice.store) == 0
def test_run_function_or_constructor_action_serde() -> None:
    """Round-trip a RunFunctionOrConstructorAction through serialize/deserialize."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    lhs_ptr = th.tensor([1, 2, 3]).send(client)
    rhs_ptr = th.tensor([4, 5, 5]).send(client)

    action = RunFunctionOrConstructorAction(
        path="torch.Tensor.add",
        args=(lhs_ptr, rhs_ptr),
        kwargs={},
        id_at_location=UID(),
        address=client.address,
        msg_id=UID(),
    )

    restored = sy.deserialize(blob=action.serialize())

    # FIXME: `args` cannot be compared until the Pointer serde problem is
    # fixed (see _proto2object in Pointer), so it is left out of this list.
    for field in ("path", "kwargs", "address", "id", "id_at_location"):
        assert getattr(restored, field) == getattr(action, field)
def test_send_message_from_domain_client_to_vm() -> None:
    """Register VM -> device -> domain, then send a message from the domain client to the VM."""
    # Step 1: a VM whose root key is its own client's verify key.
    vm = sy.VirtualMachine(name="Bob")
    vm_client = vm.get_client()
    vm.root_verify_key = vm_client.verify_key  # inject client key as root key

    # Step 2: a device set up the same way, with the VM registered on it.
    phone = sy.Device(name="Bob's iPhone")
    phone_client = phone.get_client()
    phone.root_verify_key = phone_client.verify_key  # inject client key as root key

    phone_client.register(client=vm_client)
    assert vm.device is not None
    assert vm_client.device is not None

    # Step 3: a domain set up the same way.
    domain = sy.Domain(name="Bob's Domain")
    domain_client = domain.get_client()
    domain.root_verify_key = domain_client.verify_key  # inject client key as root key

    # Switch keys: the VM's root key becomes the domain client's verify key.
    vm.root_verify_key = domain_client.verify_key

    domain_client.register(client=phone_client)
    assert phone.domain is not None
    assert phone_client.domain is not None

    # Step 4: send a message addressed to the VM through the domain client.
    message = sy.ReprMessage(address=vm.address)
    domain_client.send_immediate_msg_without_reply(msg=message)
def test_user_module() -> None:
    """Send a user-defined module, load a state dict remotely, and fetch it back."""
    alice = sy.VirtualMachine()
    alice_client = alice.get_root_client()

    # User-defined model: two stacked linear layers.
    class M(th.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = th.nn.Linear(4, 2)
            self.fc2 = th.nn.Linear(2, 1)

        def forward(self, x: Any) -> Any:
            hidden = self.fc1(x)
            return self.fc2(hidden)

    model = M()

    # Send the model to the remote machine.
    model_ptr = model.send(alice_client)

    # Update the remote model with the state dict of a second, fresh instance.
    state = OrderedDict(M().state_dict())
    state_ptr = state.send(alice_client)
    model_ptr.load_state_dict(state_ptr)

    # Fetch the model back and compare every parameter with what was loaded.
    fetched_state = model_ptr.get().state_dict()
    for key in ("fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"):
        assert (state[key] == fetched_state[key]).all()
def test_torch_function() -> None:
    """Call ``torch.zeros_like`` through the client on a pointer and check the result."""
    bob = sy.VirtualMachine(name="Bob")
    bob_client = bob.get_client()

    source = th.tensor([[-0.1, 0.1], [0.2, 0.3]])
    source_ptr = source.send(bob_client)

    zeros_ptr = bob_client.torch.zeros_like(source_ptr)
    zeros = zeros_ptr.get()

    expected = th.tensor([[0.0, 0.0], [0.0, 0.0]])
    assert (zeros == expected).all()
def test_send_message_from_vm_client_to_vm() -> None:
    """Sending a ReprMessage through a plain VM client is expected to be refused."""
    vm = sy.VirtualMachine(name="Bob")
    vm_client = vm.get_client()

    # The VM has not been registered on any device.
    assert vm.device is None

    with pytest.raises(AuthorizationException):
        # The message is built inside the block, as in the original flow.
        vm_client.send_immediate_msg_without_reply(
            msg=sy.ReprMessage(address=vm_client.address)
        )
def test_device_init() -> None:
    """Build a ``torch.device`` through the client from a String pointer."""
    bob = sy.VirtualMachine(name="Bob")
    bob_client = bob.get_client()
    remote_torch = bob_client.torch

    device_name = String("cuda:0")
    name_ptr = device_name.send(bob_client)
    device_ptr = remote_torch.device(name_ptr)

    # The result should be a device pointer carrying a UID.
    assert type(device_ptr).__name__ == "devicePointer"
    assert isinstance(device_ptr.id_at_location, UID)
def test_send() -> None:
    """Send a Complex to a VM, check the pointer type, and get it back."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    original = Complex("2+3j")
    pointer = original.send(client)

    # The pointer class is expected to be named after the wrapped type.
    assert pointer.__class__.__name__ == "ComplexPointer"

    # Getting the object back should yield an equal value.
    fetched = pointer.get()
    assert fetched == original
def test_send() -> None:
    """Send a Float to a VM, check the pointer type, and get it back."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    original = Float(5)
    pointer = original.send(client)

    # The pointer class is expected to be named after the wrapped type.
    assert pointer.__class__.__name__ == "FloatPointer"

    # Getting the object back should yield an equal value.
    fetched = pointer.get()
    assert fetched == original
def test_known_child_nodes() -> None:
    """Registering VMs on a device should add them to ``known_child_nodes``."""
    first_vm = sy.VirtualMachine(name="Bob VM")
    first_vm_client = first_vm.get_client()
    first_vm.root_verify_key = first_vm_client.verify_key  # inject client key as root key

    second_vm = sy.VirtualMachine(name="Bob VM 2")
    second_vm_client = second_vm.get_client()
    second_vm.root_verify_key = second_vm_client.verify_key  # inject client key as root key

    phone = sy.Device(name="Bob's iPhone")
    phone_client = phone.get_client()
    phone.root_verify_key = phone_client.verify_key  # inject client key as root key

    # After the first registration the device knows exactly one child.
    phone_client.register(client=first_vm_client)
    assert len(phone.known_child_nodes) == 1
    assert first_vm in phone.known_child_nodes

    # After the second registration it knows both.
    phone_client.register(client=second_vm_client)
    assert len(phone.known_child_nodes) == 2
    assert second_vm in phone.known_child_nodes
def test_string_send() -> None:
    """Send a String to a VM, check the pointer type, and get it back."""
    alice = syft.VirtualMachine(name="alice")
    client = alice.get_client()

    original = String("Hello OpenMined!")
    pointer = original.send(client)

    # The pointer class is expected to be named after the wrapped type.
    assert pointer.__class__.__name__ == "StringPointer"

    # Getting the object back should yield an equal value.
    fetched = pointer.get()
    assert fetched == original
def test_statement_zk_proof() -> None:
    """Send a zksk commitment and proof through a VM and verify them on return."""
    vm = sy.VirtualMachine()
    client = vm.get_root_client()

    sy.load("zksk")

    # third party
    from zksk import DLRep, Secret, utils

    num = 2
    seed = 42
    num_sy = sy.lib.python.Int(num)
    seed_sy = sy.lib.python.Int(seed)

    # Setup: Peggy and Victor agree on two group generators.
    G, H = utils.make_generators(num=num, seed=seed)

    # Setup: generate a secret randomizer.
    randomizer = Secret(utils.get_random_num(bits=128))

    # This is Peggy's secret bit.
    top_secret_bit = 1

    # A Pedersen commitment to the secret bit.
    C = top_secret_bit * G + randomizer.value * H

    # Peggy's definition of the proof statement, and proof generation.
    # (The first or-clause corresponds to the secret value 0, and the second
    # to the value 1. Because the real value of the bit is 1, the clause that
    # corresponds to zero is marked as simulated.)
    prover_stmt = DLRep(C, randomizer * H, simulated=True) | DLRep(
        C - G, randomizer * H
    )
    zk_proof = prover_stmt.prove()

    # Send everything over the network ...
    num_ptr = num_sy.send(client)
    seed_ptr = seed_sy.send(client)
    c_ptr = C.send(client)
    zk_proof_ptr = zk_proof.send(client)

    # ... and get it back.
    num2 = num_ptr.get().upcast()
    seed2 = seed_ptr.get().upcast()
    C2 = c_ptr.get()
    zk_proof2 = zk_proof_ptr.get()

    # Setup: get the agreed group generators from the round-tripped values.
    G2, H2 = utils.make_generators(num=num2, seed=seed2)

    # Setup: define a randomizer with an unknown value.
    unknown = Secret()

    verifier_stmt = DLRep(C2, unknown * H2) | DLRep(C2 - G2, unknown * H2)
    assert verifier_stmt.verify(zk_proof2)
def test_secret_serde() -> None:
    """Send a zksk Secret to a VM and check that it comes back equal."""
    vm = sy.VirtualMachine()
    root_client = vm.get_root_client()

    # third party
    import zksk as zk

    sy.load("zksk")

    random_value = zk.utils.get_random_num(bits=128)
    secret = zk.Secret(random_value)

    secret_ptr = secret.send(root_client)
    fetched = secret_ptr.get()

    assert secret == fetched
def test_register_vm_on_device_succeeds() -> None:
    """Registering a VM on a device should set ``device`` on the VM and its client."""
    # Both nodes get their own client's verify key installed as root key
    # before the registration is attempted.
    vm = sy.VirtualMachine(name="Bob")
    vm_client = vm.get_client()
    vm.root_verify_key = vm_client.verify_key  # inject client key as root key

    phone = sy.Device(name="Bob's iPhone")
    phone_client = phone.get_client()
    phone.root_verify_key = phone_client.verify_key  # inject client key as root key

    # Register the VM with the device.
    phone_client.register(client=vm_client)

    assert vm.device is not None
    assert vm_client.device is not None
def test_register_vm_on_device_fails() -> None:
    """Registering a VM on a device without injected root keys is expected to raise."""
    vm = sy.VirtualMachine(name="Bob")
    vm_client = vm.get_client()

    phone = sy.Device(name="Bob's iPhone")
    phone_client = phone.get_client()

    # No root verify keys were injected, so registration should be refused.
    with pytest.raises(AuthorizationException):
        phone_client.register(client=vm_client)

    # The VM node itself must stay unregistered ...
    assert vm.device is None

    # ... but the client currently ends up with a device anyway.
    # TODO: prevent device being set when Authorization fails
    assert vm_client.device is not None
def test_list_send() -> None:
    """Send a Dict to a VM, check the pointer type, and compare what comes back."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    syft_dict = Dict(
        {
            String("t1"): String("test"),
            String("t2"): String("test"),
        }
    )
    pointer = syft_dict.send(client)

    # The pointer class is expected to be named after the wrapped type.
    assert pointer.__class__.__name__ == "DictPointer"

    # Getting the object back: iterate both containers in lockstep and
    # compare the elements pairwise.
    fetched = pointer.get()
    for fetched_item, original_item in zip(fetched, syft_dict):
        assert fetched_item == original_item
def test_list_send() -> None:
    """Send a List of tensors to a VM, check the pointer type, and compare what comes back."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    first = th.tensor([1, 2])
    second = th.tensor([1, 3])
    tensors = List([first, second])

    pointer = tensors.send(client)

    # The pointer class is expected to be named after the wrapped type.
    assert pointer.__class__.__name__ == "ListPointer"

    # Getting the object back: every tensor must match element-wise.
    fetched = pointer.get()
    for fetched_item, original_item in zip(fetched, tensors):
        assert (fetched_item == original_item).all()
def test_torch_garbage_method_creates_pointer() -> None:
    """
    Test that calling a method on a pointer creates new objects in the
    remote store.

    With the garbage collector disabled, the store is expected to grow from
    one object (the sent tensor) to three after ``x_ptr + 2``.
    NOTE(review): the two extra entries are presumably the operand ``2`` and
    the result of the addition -- confirm against the store implementation.
    """
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    x = th.tensor([-1, 0, 1, 2, 3, 4])
    x_ptr = x.send(alice_client)

    assert len(alice.store) == 1

    # gc.disable() flips a process-wide switch. Remember the previous state
    # and restore it afterwards so that a disabled collector does not leak
    # into the tests that run after this one, even if the assertion fails.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        x_ptr + 2

        assert len(alice.store) == 3
    finally:
        if gc_was_enabled:
            gc.enable()
def test_torch_remote_tensor_register() -> None:
    """Test that a sent tensor is registered in the remote store."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    tensor = th.tensor([-1, 0, 1, 2, 3, 4])

    ptr = tensor.send(client)
    assert len(alice.store) == 1

    # Sending again while rebinding the same name: the previous pointer is
    # overwritten, which sends a message to delete the object it referred
    # to, so the store is expected to hold one object again.
    ptr = tensor.send(client)
    gc.collect()
    assert len(alice.store) == 1

    # Get removes the object from the store.
    ptr.get()
    assert len(alice.store) == 0
def test_torch_remote_tensor_with_alias_send() -> None:
    """Test sending a tensor to the remote worker with the ``send_to`` alias."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    tensor = th.tensor([-1, 0, 1, 2, 3, 4])

    pointer = tensor.send_to(client)
    assert len(alice.store) == 1

    # TODO: Fix this from deleting the object in the store due to the variable
    # see above
    # pointer = tensor.send_to(client)

    received = pointer.get()

    # Get removes the object from the store.
    assert len(alice.store) == 0
    # The data that was sent and the data that was received must be equal.
    assert tensor.equal(received)
def test_relu_module() -> None:
    """Send a ReLU module, call it through its pointer, then fetch it and compare outputs."""
    alice = sy.VirtualMachine()
    root_client = alice.get_root_client()

    # The module under test.
    relu = th.nn.ReLU(inplace=True)

    # Send it to the remote machine.
    relu_ptr = relu.send(root_client)

    # Call the module through its pointer with local random data.
    sample = th.rand([1, 4])
    output_ptr = relu_ptr(sample)
    remote_output = output_ptr.get()
    assert remote_output.shape == th.Size((1, 4))

    # Fetch the module itself and check that it is the same kind of object.
    fetched_relu = relu_ptr.get()
    assert type(relu) == type(fetched_relu)

    # The fetched module must produce the same output as the remote call.
    local_output = fetched_relu(sample)
    assert (local_output == remote_output).all()
def test_torch_garbage_collect() -> None:
    """
    Test that sending a tensor and then deleting its pointer removes the
    object from the remote worker.
    """
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    tensor = th.tensor([-1, 0, 1, 2, 3, 4])
    ptr = tensor.send(client)
    assert len(alice.store) == 1

    # "del" only decrements the reference counter; the garbage collector
    # plays the role of the reaper.
    del ptr

    # Make sure __del__ from Pointer is called.
    gc.collect()

    assert len(alice.store) == 0
def test_child_node_lifecycle_message_serde() -> None:
    """Round-trip a RegisterChildNodeMessage through serialize/deserialize."""
    vm = sy.VirtualMachine(name="Bob")
    vm_client = vm.get_client()

    phone = sy.Device(name="Bob's iPhone")
    phone_client = phone.get_client()

    # Calling phone_client.register(client=vm_client) generates this message;
    # here it is built by hand instead.
    original = RegisterChildNodeMessage(
        lookup_id=vm_client.id,  # TODO: not sure if this is needed anymore
        child_node_client_address=vm_client.address,
        address=phone_client.address,
    )

    serialized = original.serialize()
    restored = sy.deserialize(blob=serialized)

    assert original.id == restored.id
    assert original.address == restored.address
    assert original.child_node_client_address == restored.child_node_client_address
    assert original == restored
def test_signaling_answer_pull_request_message_serde() -> None:
    """Round-trip an AnswerPullRequestMessage through serialize/deserialize."""
    host_vm = sy.VirtualMachine(name="Bob")
    destination = Address(name="Alice")
    target_peer_id = secrets.token_hex(nbytes=16)
    host_peer_id = secrets.token_hex(nbytes=16)

    original = AnswerPullRequestMessage(
        address=destination,
        target_peer=target_peer_id,
        host_peer=host_peer_id,
        reply_to=host_vm.address,
    )

    serialized = original.serialize()
    restored = sy.deserialize(blob=serialized)

    assert original.id == restored.id
    assert original.address == destination
    assert original == restored
    assert restored.host_peer == host_peer_id
    assert restored.target_peer == target_peer_id
def test_get_copy() -> None:
    """``get_copy()`` must leave the remote object in place, unlike deleting a pointer."""
    alice = sy.VirtualMachine(name="alice")
    client = alice.get_client()

    param = th.nn.Parameter(th.randn(3, 3))
    param_ptr = param.send(client)
    sum_ptr = param_ptr + param_ptr

    # The store holds the sent parameter and the result of the addition.
    assert len(alice.store._objects) == 2

    # Fetching a copy does not delete anything remotely.
    sum_ptr.get_copy()
    assert len(alice.store._objects) == 2

    # Dropping the parameter pointer and collecting leaves one object.
    del param_ptr
    gc.collect()
    assert len(alice.store._objects) == 1