コード例 #1
0
def clients():
    """Create two virtual machines and return a non-root client for each.

    Returns:
        A two-element list ``[alice_client, bob_client]`` obtained via
        ``get_client()`` on the "alice" and "bob" VMs respectively.
    """
    vm_alice = sy.VirtualMachine(name="alice")
    vm_bob = sy.VirtualMachine(name="bob")

    # List elements are evaluated left to right: alice's client first.
    return [vm_alice.get_client(), vm_bob.get_client()]
コード例 #2
0
def get_preset_nodes() -> Tuple[Node, Node, Node]:
    """Build the OpenMined network node plus two VMs used by signaling tests.

    The network gets the push/pull signaling services and the Duet peer
    registration service appended, after which its service registry is
    rebuilt so the additions take effect.

    Returns:
        The tuple ``(om_network, bob_vm, alice_vm)``.
    """
    network = sy.Network(name="OpenMined")

    # Services that do not send a reply.
    network.immediate_services_without_reply.append(PushSignalingService)
    # Services that do send a reply.
    network.immediate_services_with_reply.append(PullSignalingService)
    network.immediate_services_with_reply.append(RegisterDuetPeerService)

    # Re-register everything so the signaling services are included.
    network._register_services()

    vm_bob = sy.VirtualMachine(name="Bob")
    vm_alice = sy.VirtualMachine(name="Alice")
    return network, vm_bob, vm_alice
コード例 #3
0
ファイル: sympc_test.py プロジェクト: stoic-signs/PySyft
def test_load_sympc() -> None:
    """Adding a public tensor to an MPC-shared secret reconstructs to the plain sum."""
    first_vm = sy.VirtualMachine()
    first_party = first_vm.get_root_client()
    second_vm = sy.VirtualMachine()
    second_party = second_vm.get_root_client()

    mpc_session = Session(parties=[first_party, second_party])
    SessionManager.setup_mpc(mpc_session)

    public_tensor = th.Tensor([-5, 0, 1, 2, 3])
    secret_tensor = th.Tensor([30])
    shared = MPCTensor(secret=secret_tensor, shape=(1,), session=mpc_session)

    reconstructed = (shared + public_tensor).reconstruct()
    # 30 broadcast-added to each public value.
    expected = th.Tensor([25.0, 30.0, 31.0, 32.0, 33.0])
    assert (reconstructed == expected).all()
コード例 #4
0
def test_torch_no_read_permissions() -> None:
    """A guest client may not ``get`` an object root stored; root itself may."""
    bob = sy.VirtualMachine(name="bob")
    root_client = bob.get_root_client()
    guest_client = bob.get_client()

    first_tensor = th.tensor([1, 2, 3, 4])

    # Root stores a tensor on Bob's machine...
    pointer = first_tensor.send(root_client)

    # ...and the guest takes over a pointer to it (as if it had guessed the ID).
    pointer.client = guest_client

    # The guest's read must be rejected.
    with pytest.raises(AuthorizationException):
        pointer.get()

    second_tensor = th.tensor([1, 2, 3, 4])

    # Root stores another tensor and reads it back itself, which is allowed.
    pointer = second_tensor.send(root_client)
    fetched = pointer.get()

    assert (second_tensor == fetched).all()
    assert second_tensor.grad == fetched.grad
コード例 #5
0
def test_signaling_answer_message_serde() -> None:
    """Round-trip a ``SignalingAnswerMessage`` through serialize/deserialize.

    Every content assertion is made on the deserialized copy ``msg2``, so a
    field that is lost or corrupted during serde is actually detected.
    """
    bob_vm = sy.VirtualMachine(name="Bob")
    target = Address(name="Alice")
    target_id = secrets.token_hex(nbytes=16)
    host_id = secrets.token_hex(nbytes=16)

    msg = SignalingAnswerMessage(
        address=target,
        payload="SDP",
        host_metadata=bob_vm.get_metadata_for_client(),
        target_peer=target_id,
        host_peer=host_id,
    )
    msg_metadata = bob_vm.get_metadata_for_client()

    blob = msg.serialize()
    msg2 = sy.deserialize(blob=blob)
    msg2_metadata = msg2.host_metadata

    assert msg.id == msg2.id
    # Check the deserialized address; ``msg.address == target`` is trivially true.
    assert msg2.address == target
    assert msg.payload == msg2.payload
    assert msg2.payload == "SDP"
    assert msg2.host_peer == host_id
    assert msg2.target_peer == target_id
    assert msg == msg2

    # Host metadata must survive the round trip field by field.
    assert msg_metadata.name == msg2_metadata.name
    assert msg_metadata.node == msg2_metadata.node
    assert msg_metadata.id == msg2_metadata.id
コード例 #6
0
ファイル: gc_test.py プロジェクト: znreza/PySyft
def test_same_var_for_ptr_gc() -> None:
    """Overwriting the last reference to a pointer triggers remote cleanup.

    The same tensor is sent 100 times into the same variable. After a
    collection only the object behind the final pointer should remain in
    Alice's store, and ``get()`` on that pointer should empty the store.
    """
    x = torch.tensor([1, 2, 3, 4])

    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    # Rebind ``ptr`` repeatedly so each previous pointer loses its last
    # reference, giving the garbage collector something to clean up.
    for _ in range(100):
        ptr = x.send(alice_client)

    gc.collect()

    assert len(alice.store) == 1

    # Fetching the final object removes it from the store as well.
    ptr.get()
    gc.collect()

    assert len(alice.store) == 0
コード例 #7
0
def test_run_function_or_constructor_action_serde() -> None:
    """``RunFunctionOrConstructorAction`` keeps its fields across serde."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    lhs_ptr = th.tensor([1, 2, 3]).send(alice_client)
    rhs_ptr = th.tensor([4, 5, 5]).send(alice_client)

    action = RunFunctionOrConstructorAction(
        path="torch.Tensor.add",
        args=(lhs_ptr, rhs_ptr),
        kwargs={},
        id_at_location=UID(),
        address=alice_client.address,
        msg_id=UID(),
    )

    restored = sy.deserialize(blob=action.serialize())

    # FIXME this cannot be checked before we fix the Pointer serde problem (see _proto2object in Pointer)
    # so ``args`` is deliberately absent from the list below.
    for field_name in ("path", "kwargs", "address", "id", "id_at_location"):
        assert getattr(restored, field_name) == getattr(action, field_name)
コード例 #8
0
ファイル: node_test.py プロジェクト: znreza/PySyft
def test_send_message_from_domain_client_to_vm() -> None:
    """Send a ``ReprMessage`` addressed to a VM from a Domain client.

    Set-up: register the VM with a Device, then register the Device with a
    Domain. Before the second registration, the VM's root key is switched
    to the Domain client's verify key.
    """

    def rooted_client(node):
        # Get a client and inject its verify key as the node's root key.
        client = node.get_client()
        node.root_verify_key = client.verify_key
        return client

    bob_vm = sy.VirtualMachine(name="Bob")
    bob_vm_client = rooted_client(bob_vm)

    bob_phone = sy.Device(name="Bob's iPhone")
    bob_phone_client = rooted_client(bob_phone)

    # Register the VM with the Device.
    bob_phone_client.register(client=bob_vm_client)
    assert bob_vm.device is not None
    assert bob_vm_client.device is not None

    bob_domain = sy.Domain(name="Bob's Domain")
    bob_domain_client = rooted_client(bob_domain)

    # Let the Domain client's key act as the VM's root key, then register
    # the Device with the Domain.
    bob_vm.root_verify_key = bob_domain_client.verify_key
    bob_domain_client.register(client=bob_phone_client)
    assert bob_phone.domain is not None
    assert bob_phone_client.domain is not None

    # Send a message from the Domain client addressed to the VM.
    bob_domain_client.send_immediate_msg_without_reply(
        msg=sy.ReprMessage(address=bob_vm.address)
    )
コード例 #9
0
ファイル: module_serde_test.py プロジェクト: yashlamba/PySyft
def test_user_module() -> None:
    """A user-defined ``nn.Module`` can be sent, updated remotely and fetched back.

    A second, freshly initialised ``M`` provides a state dict. It is sent,
    loaded into the remote model via ``load_state_dict``, and the model
    retrieved with ``get()`` must carry exactly those parameters.
    """
    alice = sy.VirtualMachine()
    alice_client = alice.get_root_client()

    # user defined model
    class M(th.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = th.nn.Linear(4, 2)
            self.fc2 = th.nn.Linear(2, 1)

        def forward(self, x: Any) -> Any:
            x = self.fc1(x)
            x = self.fc2(x)
            return x

    m = M()

    # send
    m_ptr = m.send(alice_client)

    # remote update of the state dict
    sd = OrderedDict(M().state_dict())
    sd_ptr = sd.send(alice_client)
    m_ptr.load_state_dict(sd_ptr)

    # get
    sd2 = m_ptr.get().state_dict()

    # Compare every entry rather than a hard-coded subset, so missing, extra
    # or newly added parameters are also caught.
    assert sd.keys() == sd2.keys()
    for name, expected in sd.items():
        assert (expected == sd2[name]).all()
コード例 #10
0
ファイル: client_torch_test.py プロジェクト: znreza/PySyft
def test_torch_function() -> None:
    """Calling ``torch.zeros_like`` through the client yields zeros remotely."""
    bob = sy.VirtualMachine(name="Bob")
    client = bob.get_client()

    source = th.tensor([[-0.1, 0.1], [0.2, 0.3]])
    source_ptr = source.send(client)

    zeros_ptr = client.torch.zeros_like(source_ptr)
    zeros = zeros_ptr.get()

    expected = th.tensor([[0.0, 0.0], [0.0, 0.0]])
    assert (zeros == expected).all()
コード例 #11
0
ファイル: node_test.py プロジェクト: znreza/PySyft
def test_send_message_from_vm_client_to_vm() -> None:
    """A plain ``get_client()`` client of a VM is refused when it sends a ``ReprMessage``."""
    vm = sy.VirtualMachine(name="Bob")
    vm_client = vm.get_client()

    # Freshly created VM: not attached to any device.
    assert vm.device is None

    with pytest.raises(AuthorizationException):
        repr_msg = sy.ReprMessage(address=vm_client.address)
        vm_client.send_immediate_msg_without_reply(msg=repr_msg)
コード例 #12
0
ファイル: device_test.py プロジェクト: znreza/PySyft
def test_device_init() -> None:
    """``torch.device`` called remotely on a String pointer returns a device pointer."""
    bob = sy.VirtualMachine(name="Bob")
    client = bob.get_client()
    remote_torch = client.torch

    device_name = String("cuda:0")
    device_name_ptr = device_name.send(client)

    device_ptr = remote_torch.device(device_name_ptr)

    # The pointer class name and the type of its remote id are both checked.
    assert type(device_ptr).__name__ == "devicePointer"
    assert isinstance(device_ptr.id_at_location, UID)
コード例 #13
0
ファイル: serde_test.py プロジェクト: znreza/PySyft
def test_send() -> None:
    """A syft ``Complex`` is sent as a ``ComplexPointer`` and fetched back equal."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    value = Complex("2+3j")
    pointer = value.send(alice_client)

    # The pointer class is specific to the sent type.
    assert type(pointer).__name__ == "ComplexPointer"

    # Round trip: the retrieved object equals the original.
    fetched = pointer.get()
    assert fetched == value
コード例 #14
0
ファイル: serde_test.py プロジェクト: znreza/PySyft
def test_send() -> None:
    """A syft ``Float`` is sent as a ``FloatPointer`` and fetched back equal."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    value = Float(5)
    pointer = value.send(alice_client)

    # The pointer class is specific to the sent type.
    assert type(pointer).__name__ == "FloatPointer"

    # Round trip: the retrieved object equals the original.
    fetched = pointer.get()
    assert fetched == value
コード例 #15
0
ファイル: node_test.py プロジェクト: znreza/PySyft
def test_known_child_nodes() -> None:
    """Each VM registered with a Device is recorded in ``known_child_nodes``."""

    def rooted_client(node):
        # Get a client and inject its verify key as the node's root key.
        client = node.get_client()
        node.root_verify_key = client.verify_key
        return client

    first_vm = sy.VirtualMachine(name="Bob VM")
    first_vm_client = rooted_client(first_vm)

    second_vm = sy.VirtualMachine(name="Bob VM 2")
    second_vm_client = rooted_client(second_vm)

    phone = sy.Device(name="Bob's iPhone")
    phone_client = rooted_client(phone)

    phone_client.register(client=first_vm_client)
    assert len(phone.known_child_nodes) == 1
    assert first_vm in phone.known_child_nodes

    phone_client.register(client=second_vm_client)
    assert len(phone.known_child_nodes) == 2
    assert second_vm in phone.known_child_nodes
コード例 #16
0
def test_string_send() -> None:
    """A syft ``String`` is sent as a ``StringPointer`` and fetched back equal."""
    alice = syft.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    greeting = String("Hello OpenMined!")
    pointer = greeting.send(alice_client)

    # The pointer class is specific to the sent type.
    assert type(pointer).__name__ == "StringPointer"

    # Round trip: the retrieved object equals the original.
    fetched = pointer.get()
    assert fetched == greeting
コード例 #17
0
ファイル: statement_test.py プロジェクト: yashlamba/PySyft
def test_statement_zk_proof() -> None:
    """A zksk OR-proof still verifies after its inputs round-trip through a VM.

    Peggy builds a proof that a Pedersen-committed bit is 0 or 1. The
    generator parameters, the commitment and the proof are each sent to a VM
    and fetched back. Victor then rebuilds the statement from the fetched
    values and verifies the fetched proof.
    """
    vm = sy.VirtualMachine()
    client = vm.get_root_client()

    sy.load("zksk")

    # third party
    from zksk import DLRep
    from zksk import Secret
    from zksk import utils

    num = 2
    seed = 42
    num_sy = sy.lib.python.Int(num)
    seed_sy = sy.lib.python.Int(seed)

    # Setup: Peggy and Victor agree on two group generators.
    G, H = utils.make_generators(num=num, seed=seed)
    # Setup: generate a secret randomizer.
    r = Secret(utils.get_random_num(bits=128))

    # This is Peggy's secret bit.
    top_secret_bit = 1

    # A Pedersen commitment to the secret bit.
    C = top_secret_bit * G + r.value * H

    # Peggy's definition of the proof statement, and proof generation.
    # (The first or-clause corresponds to the secret value 0, and the second to the value 1. Because
    # the real value of the bit is 1, the clause that corresponds to zero is marked as simulated.)
    stmt = DLRep(C, r * H, simulated=True) | DLRep(C - G, r * H)
    zk_proof = stmt.prove()

    # send over the network and get back
    num_ptr = num_sy.send(client)
    seed_ptr = seed_sy.send(client)
    c_ptr = C.send(client)
    zk_proof_ptr = zk_proof.send(client)

    # upcast() turns the syft Int wrappers back into plain Python values.
    num2 = num_ptr.get().upcast()
    seed2 = seed_ptr.get().upcast()
    C2 = c_ptr.get()
    zk_proof2 = zk_proof_ptr.get()

    # Setup: get the agreed group generators.
    G, H = utils.make_generators(num=num2, seed=seed2)
    # Setup: define a randomizer with an unknown value.
    r = Secret()

    stmt = DLRep(C2, r * H) | DLRep(C2 - G, r * H)
    assert stmt.verify(zk_proof2)
コード例 #18
0
def test_secret_serde() -> None:
    """A zksk ``Secret`` survives a send/get round trip through a VM."""
    vm = sy.VirtualMachine()
    client = vm.get_root_client()

    # third party
    import zksk as zk

    sy.load("zksk")

    random_value = zk.utils.get_random_num(bits=128)
    secret = zk.Secret(random_value)

    secret_ptr = secret.send(client)
    restored = secret_ptr.get()

    assert secret == restored
コード例 #19
0
ファイル: node_test.py プロジェクト: znreza/PySyft
def test_register_vm_on_device_succeeds() -> None:
    """Registering a VM with a Device succeeds when each node's root key is its client's key."""

    def rooted_client(node):
        # Get a client and inject its verify key as the node's root key.
        client = node.get_client()
        node.root_verify_key = client.verify_key
        return client

    bob_vm = sy.VirtualMachine(name="Bob")
    bob_vm_client = rooted_client(bob_vm)

    bob_phone = sy.Device(name="Bob's iPhone")
    bob_phone_client = rooted_client(bob_phone)

    bob_phone_client.register(client=bob_vm_client)

    # Both the VM and its client now know about the device.
    assert bob_vm.device is not None
    assert bob_vm_client.device is not None
コード例 #20
0
ファイル: node_test.py プロジェクト: znreza/PySyft
def test_register_vm_on_device_fails() -> None:
    """Without injected root verify keys, registering a VM with a Device is refused."""
    vm = sy.VirtualMachine(name="Bob")
    vm_client = vm.get_client()

    phone = sy.Device(name="Bob's iPhone")
    phone_client = phone.get_client()

    with pytest.raises(AuthorizationException):
        phone_client.register(client=vm_client)

    # The VM itself did not get a device...
    assert vm.device is None

    # ...but the client-side attribute was still set.
    # TODO: prevent device being set when Authorization fails
    assert vm_client.device is not None
コード例 #21
0
def test_list_send() -> None:
    """A syft ``Dict`` of Strings is sent as a ``DictPointer`` and fetched back intact.

    (Despite the function name, this exercises ``Dict``, not ``List``.)
    """
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    syft_dict = Dict({
        String("t1"): String("test"),
        String("t2"): String("test")
    })
    ptr = syft_dict.send(alice_client)
    # Check pointer type
    assert ptr.__class__.__name__ == "DictPointer"

    # Check that we can get back the object
    res = ptr.get()
    # zip() stops at the shorter input, so check the sizes first; otherwise
    # an empty or truncated result would pass the loop vacuously.
    assert len(res) == len(syft_dict)
    # Compare (key, value) pairs: iterating a dict alone only yields keys.
    for res_item, original_item in zip(res.items(), syft_dict.items()):
        assert res_item == original_item
コード例 #22
0
def test_list_send() -> None:
    """A syft ``List`` of tensors is sent as a ``ListPointer`` and fetched back intact."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    t1 = th.tensor([1, 2])
    t2 = th.tensor([1, 3])

    syft_list = List([t1, t2])
    ptr = syft_list.send(alice_client)
    # Check pointer type
    assert ptr.__class__.__name__ == "ListPointer"

    # Check that we can get back the object
    res = ptr.get()
    # zip() stops at the shorter input, so check the sizes first; otherwise
    # an empty or truncated result would pass the loop vacuously.
    assert len(res) == len(syft_list)
    for res_el, original_el in zip(res, syft_list):
        assert (res_el == original_el).all()
コード例 #23
0
def test_torch_garbage_method_creates_pointer() -> None:
    """Calling an operator on a pointer adds new objects to the remote store.

    With the cyclic garbage collector disabled, ``x_ptr + 2`` is expected to
    leave three objects in Alice's store. The collector's previous state is
    restored afterwards so that later tests are not affected.
    """
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    x = th.tensor([-1, 0, 1, 2, 3, 4])
    x_ptr = x.send(alice_client)

    assert len(alice.store) == 1

    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        # The result is intentionally discarded.
        x_ptr + 2
        # NOTE(review): presumably the sent tensor, the operand ``2`` and
        # the result -- confirm what the three entries are.
        assert len(alice.store) == 3
    finally:
        # Never leak a disabled collector into the rest of the test session.
        if gc_was_enabled:
            gc.enable()
コード例 #24
0
def test_torch_remote_tensor_register() -> None:
    """A sent tensor is registered remotely; rebinding the pointer or ``get`` frees it."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    tensor = th.tensor([-1, 0, 1, 2, 3, 4])
    pointer = tensor.send(alice_client)
    assert len(alice.store) == 1

    # Rebinding ``pointer`` drops the last reference to the first pointer,
    # which sends a message to delete that first object.
    pointer = tensor.send(alice_client)
    gc.collect()
    assert len(alice.store) == 1

    pointer.get()
    # get() also removes the object from the store.
    assert len(alice.store) == 0
コード例 #25
0
def test_torch_remote_tensor_with_alias_send() -> None:
    """Sending with the ``send_to`` alias stores the tensor remotely; ``get`` returns it."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    tensor = th.tensor([-1, 0, 1, 2, 3, 4])
    pointer = tensor.send_to(alice_client)
    assert len(alice.store) == 1

    # TODO: sending again into the same variable
    # (``pointer = tensor.send_to(alice_client)``) currently deletes the stored
    # object, because the overwritten pointer triggers deletion.

    received = pointer.get()
    # get() removes the object from the store.
    assert len(alice.store) == 0

    # The data received must equal the data sent.
    assert tensor.equal(received)
コード例 #26
0
def test_relu_module() -> None:
    """A ``ReLU`` module can be sent, called remotely and fetched back intact.

    The input comes from ``th.rand`` (values in [0, 1)), so applying the
    in-place ReLU locally leaves it unchanged. The remote and local outputs
    can therefore be compared directly.
    """
    alice = sy.VirtualMachine()
    alice_client = alice.get_root_client()

    # ReLU
    relu = th.nn.ReLU(inplace=True)

    # send
    relu_ptr = relu.send(alice_client)

    # remote call
    rand_data = th.rand([1, 4])
    res_ptr = relu_ptr(rand_data)
    rand_output = res_ptr.get()
    assert rand_output.shape == th.Size((1, 4))

    # The fetched module must be of exactly the same class as the original.
    relu2 = relu_ptr.get()
    assert type(relu2) is type(relu)
    rand_output2 = relu2(rand_data)
    assert (rand_output2 == rand_output).all()
コード例 #27
0
def test_torch_garbage_collect() -> None:
    """Dropping the last pointer to a sent tensor removes it from the remote store."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    tensor = th.tensor([-1, 0, 1, 2, 3, 4])
    pointer = tensor.send(alice_client)
    assert len(alice.store) == 1

    # ``del`` only removes the name; collect explicitly so that the
    # Pointer's ``__del__`` is guaranteed to run before the check.
    del pointer
    gc.collect()

    assert len(alice.store) == 0
コード例 #28
0
def test_child_node_lifecycle_message_serde() -> None:
    """``RegisterChildNodeMessage`` keeps its id and addresses across serde."""
    bob_vm = sy.VirtualMachine(name="Bob")
    bob_vm_client = bob_vm.get_client()

    bob_phone = sy.Device(name="Bob's iPhone")
    bob_phone_client = bob_phone.get_client()

    # Build the message that ``bob_phone_client.register(client=bob_vm_client)``
    # generates, without performing the registration.
    original = RegisterChildNodeMessage(
        lookup_id=bob_vm_client.id,  # TODO: not sure if this is needed anymore
        child_node_client_address=bob_vm_client.address,
        address=bob_phone_client.address,
    )

    restored = sy.deserialize(blob=original.serialize())

    for field_name in ("id", "address", "child_node_client_address"):
        assert getattr(original, field_name) == getattr(restored, field_name)
    assert original == restored
コード例 #29
0
def test_signaling_answer_pull_request_message_serde() -> None:
    """Round-trip an ``AnswerPullRequestMessage`` through serialize/deserialize.

    Content assertions are made on the deserialized copy ``msg2``, so a field
    that is lost or corrupted during serde is actually detected.
    """
    bob_vm = sy.VirtualMachine(name="Bob")
    target = Address(name="Alice")

    target_id = secrets.token_hex(nbytes=16)
    host_id = secrets.token_hex(nbytes=16)

    msg = AnswerPullRequestMessage(
        address=target,
        target_peer=target_id,
        host_peer=host_id,
        reply_to=bob_vm.address,
    )

    blob = msg.serialize()
    msg2 = sy.deserialize(blob=blob)

    assert msg.id == msg2.id
    # Check the deserialized address; ``msg.address == target`` is trivially true.
    assert msg2.address == target
    assert msg == msg2
    assert msg2.host_peer == host_id
    assert msg2.target_peer == target_id
コード例 #30
0
def test_get_copy() -> None:
    """``get_copy`` fetches a value without deleting it from the remote store."""
    alice = sy.VirtualMachine(name="alice")
    alice_client = alice.get_client()

    param = th.nn.Parameter(th.randn(3, 3))
    param_ptr = param.send(alice_client)

    doubled_ptr = param_ptr + param_ptr

    # Both the sent parameter and the remote sum live in Alice's store.
    assert len(alice.store._objects) == 2

    doubled_ptr.get_copy()

    # get_copy must leave the remote objects in place.
    assert len(alice.store._objects) == 2

    # Dropping the only pointer to the parameter frees it remotely.
    del param_ptr
    gc.collect()

    assert len(alice.store._objects) == 1