async def test_call_peer_different_processes():
    handler_name = "square"
    test_input = 2

    server_side, client_side = mp.Pipe()
    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
    response_received.value = 0

    proc = mp.Process(target=server_target,
                      args=(handler_name, server_side, client_side,
                            response_received))
    proc.start()

    peer_id = client_side.recv()
    peer_port = client_side.recv()

    nodes = [bootstrap_addr(peer_port, peer_id)]
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)

    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(peer_id, handler_name,
                                                 test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert np.allclose(result, test_input**2)
    response_received.value = 1

    await client.shutdown()
    assert not is_process_running(client_pid)

    proc.join()

async def test_call_peer_single_process(test_input,
                                        expected,
                                        handle,
                                        handler_name="handle"):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle)
    assert is_process_running(server_pid)

    nodes = bootstrap_from([server])
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)

    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(server.id, handler_name,
                                                 test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert result == expected

    await server.stop_listening()
    await server.shutdown()
    assert not is_process_running(server_pid)

    await client.shutdown()
    assert not is_process_running(client_pid)

Example #3

def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
    schema_dicts = [{
        field_name: str(field_value)
        for field_name, field_value in asdict(
            TensorDescriptor.from_tensor(tensor)).items()
    } for tensor in tensors]
    return DHTID.generate(
        source=MSGPackSerializer.dumps(schema_dicts)).to_bytes()
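
# Usage sketch (an illustration, not part of the original snippet; assumes torch
# and the helpers used above are importable): the hash depends only on tensor
# metadata, so identical shapes/dtypes/devices collide regardless of the values.
import torch

assert compute_schema_hash([torch.zeros(2, 3)]) == compute_schema_hash([torch.ones(2, 3)])
assert compute_schema_hash([torch.zeros(2, 3)]) != compute_schema_hash([torch.zeros(3, 2)])
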
def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()
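
# Local sanity check, purely illustrative (assumes torch, runtime_pb2 and the
# tensor (de)serialization helpers are importable): the handler takes a
# msgpack-packed list of serialized tensors and returns their serialized sum.
payload = MSGPackSerializer.dumps(
    [serialize_torch_tensor(t).SerializeToString()
     for t in (torch.ones(2, 3), torch.ones(2, 3))])
result_pb = runtime_pb2.Tensor()
result_pb.ParseFromString(handle_add_torch(payload))
assert torch.allclose(deserialize_torch_tensor(result_pb), 2 * torch.ones(2, 3))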

Example #5

    @classmethod
    def generate(cls, source: Optional[Any] = None, nbits: int = 255):
        """
        Generates random uid based on SHA1

        :param source: if provided, converts this value to bytes and uses it as input for hashing function;
            by default, generates a random dhtid from :nbits: random bits
        """
        if source is None:
            source = random.getrandbits(nbits).to_bytes(nbits, byteorder='big')
        if not isinstance(source, bytes):
            source = MSGPackSerializer.dumps(source)
        raw_uid = cls.HASH_FUNC(source).digest()
        return cls(int(raw_uid.hex(), 16))
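
# Illustrative behaviour of generate() above (assumption: DHTID wraps the
# resulting integer and compares by value, and HASH_FUNC is deterministic):
# hashing the same source is reproducible, while omitting the source draws
# fresh random bits each time.
assert DHTID.generate(source='some key') == DHTID.generate(source='some key')
assert DHTID.generate() != DHTID.generate()  # equal only with negligible probability
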
def test_serialize_tuple():
    test_pairs = (
        ((1, 2, 3), [1, 2, 3]),
        (('1', False, 0), ['1', False, 0]),
        (('1', False, 0), ('1', 0, 0)),
        (('1', b'qq', (2, 5, '0')), ['1', b'qq', (2, 5, '0')]),
    )

    for first, second in test_pairs:
        assert MSGPackSerializer.loads(MSGPackSerializer.dumps(first)) == first
        assert MSGPackSerializer.loads(
            MSGPackSerializer.dumps(second)) == second
        assert MSGPackSerializer.dumps(first) != MSGPackSerializer.dumps(
            second)

async def test_call_peer_error(replicate, handler_name="handle"):
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)

    nodes = bootstrap_from([server])
    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)

    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString()
           for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    assert result == b'something went wrong :('

    await server.stop_listening()
    await server_primary.shutdown()
    await client_primary.shutdown()
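
# handle_add_torch_with_exc is not shown in this listing; a plausible sketch
# consistent with the assertion above would wrap handle_add_torch and report
# any failure (here, adding (2, 3) and (3, 2) tensors) as a fixed error payload:
def handle_add_torch_with_exc(args):
    try:
        return handle_add_torch(args)
    except Exception:
        return b'something went wrong :('
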
async def test_call_peer_torch_add(test_input,
                                   expected,
                                   handler_name="handle"):
    handle = handle_add_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = bootstrap_from([server])
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)

    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result_pb = await client.call_peer_handler(server.id, handler_name,
                                               inp_msgp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()

def handle_add(args):
    args = MSGPackSerializer.loads(args)
    result = args[0]
    for i in range(1, len(args)):
        result = result + args[i]
    return MSGPackSerializer.dumps(result)
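
# Quick round-trip check for the handler above (illustrative only): operands
# arrive msgpack-encoded and the accumulated result is returned the same way.
assert MSGPackSerializer.loads(handle_add(MSGPackSerializer.dumps([1, 2, 3]))) == 6
assert MSGPackSerializer.loads(handle_add(MSGPackSerializer.dumps([[1, 2], [3]]))) == [1, 2, 3]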

Example #10

def handle_square(x):
    x = MSGPackSerializer.loads(x)
    return MSGPackSerializer.dumps(x**2)
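
# Matching local check (illustrative only); this is the handler behind the
# "square_integer" parametrization shown further below.
assert MSGPackSerializer.loads(handle_square(MSGPackSerializer.dumps(10))) == 100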

Example #11

    assert err.message == 'boom'

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()


@pytest.mark.parametrize("test_input,expected,handle", [
    pytest.param(10, 100, handle_square, id="square_integer"),
    pytest.param((1, 2), 3, handle_add, id="add_integers"),
    pytest.param(
        ([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
    pytest.param(
        2,
        8,
        lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x)**3),
        id="lambda")
])
@pytest.mark.asyncio
async def test_call_peer_single_process(test_input,
                                        expected,
                                        handle,
                                        handler_name="handle"):
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle)
    assert is_process_running(server_pid)

    nodes = bootstrap_from([server])
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    client_pid = client._child.pid
    assert is_process_running(client_pid)

    await client.wait_for_at_least_n_peers(1)

    test_input_msgp = MSGPackSerializer.dumps(test_input)
    result_msgp = await client.call_peer_handler(server.id, handler_name,
                                                 test_input_msgp)
    result = MSGPackSerializer.loads(result_msgp)
    assert result == expected

    await server.stop_listening()
    await server.shutdown()
    assert not is_process_running(server_pid)

    await client.shutdown()
    assert not is_process_running(client_pid)

Example #12

    def _serialize_record(self, record: DHTRecord) -> bytes:
        return MSGPackSerializer.dumps(dataclasses.astuple(record))
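
    # A hypothetical inverse (not part of the original snippet): the record is
    # packed as a tuple of its fields, so it can be rebuilt by unpacking them.
    def _deserialize_record(self, raw: bytes) -> DHTRecord:
        return DHTRecord(*MSGPackSerializer.loads(raw))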

Example #13

    async def receive_msgpack(reader):
        return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))

Example #14

    async def send_msgpack(data, writer):
        raw_data = MSGPackSerializer.dumps(data)
        await P2P.send_raw_data(raw_data, writer)
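
# Hypothetical pairing of the two helpers above (assumption: they are reachable
# e.g. as P2P.send_msgpack / P2P.receive_msgpack, matching their indentation):
# an echo-style handler decodes an incoming message and sends it straight back.
async def echo_handler(reader, writer):
    request = await P2P.receive_msgpack(reader)
    await P2P.send_msgpack(request, writer)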