示例#1
0
def test_recv_msg():
    """Exercise ``recv_msg`` on a virtual worker in two steps.

    Step 1 hands alice a serialized ``ObjectMessage`` wrapping a tensor and
    checks that the tensor's id is then present in her object store.

    Step 2 hands alice a serialized ``ObjectRequestMessage`` for that same
    id and checks that the reply is ``bytes`` which deserialize to an
    object carrying the original id.
    """

    # STEP 1: send a tensor to alice

    # Suffix the worker name with a fresh id from the provider
    # (presumably so it cannot clash with another worker called "alice").
    worker_id = sy.ID_PROVIDER.pop()
    alice = VirtualWorker(sy.torch.hook, id=f"alice{worker_id}")

    # the payload that will travel to alice
    obj = torch.Tensor([100, 100])

    # wrap the payload in a message and turn it into its wire format
    message = ObjectMessage(obj)
    bin_msg = serde.serialize(message)

    # deliver the message; the return value is not needed for this step
    alice.recv_msg(bin_msg)

    # the tensor must now be registered with alice under its id
    assert obj.id in alice.object_store._objects

    # STEP 2: ask alice to hand the tensor back

    # NOTE(review): the meaning of the trailing ``None`` and ``""``
    # arguments is not visible here -- confirm against ObjectRequestMessage.
    message = ObjectRequestMessage(obj.id, None, "")
    bin_msg = serde.serialize(message)

    # this time the return value of recv_msg is the reply under test
    resp = alice.recv_msg(bin_msg)

    # Check the wire type *before* deserializing, so that a wrong reply
    # type fails on this assertion rather than inside deserialize().
    assert isinstance(resp, bytes)

    obj_2 = sy.serde.deserialize(resp)

    # the object that came back must be the one that was sent
    assert obj_2.id == obj.id
示例#2
0
def test_recv_msg():
    """Exercise ``recv_msg`` on a virtual worker in two steps.

    Step 1 hands alice a serialized ``(MSGTYPE.OBJ, tensor)`` tuple and
    checks that the tensor's id is then present in ``alice._objects``.

    Step 2 hands alice a serialized ``(MSGTYPE.OBJ_REQ, id)`` tuple for
    that same id and checks that the reply is ``bytes`` which deserialize
    to an object carrying the original id.
    """

    # STEP 1: send a tensor to alice

    # the worker that will receive the data
    alice = VirtualWorker(sy.torch.hook)

    # the payload that will travel to alice
    obj = torch.Tensor([100, 100])

    # a message here is a (type, content) tuple; turn it into wire format
    msg = (MSGTYPE.OBJ, obj)
    bin_msg = serde.serialize(msg)

    # deliver the message; the return value is not needed for this step
    alice.recv_msg(bin_msg)

    # the tensor must now be registered with alice under its id
    assert obj.id in alice._objects

    # STEP 2: ask alice to hand the tensor back

    # request the object by the id it was stored under
    msg = (MSGTYPE.OBJ_REQ, obj.id)
    bin_msg = serde.serialize(msg)

    # this time the return value of recv_msg is the reply under test
    resp = alice.recv_msg(bin_msg)

    # Check the wire type *before* deserializing, so that a wrong reply
    # type fails on this assertion rather than inside deserialize().
    assert isinstance(resp, bytes)

    obj_2 = serde.deserialize(resp)

    # the object that came back must be the one that was sent
    assert obj_2.id == obj.id