Example 1
async def test_consumer_request() -> None:
    """Push a signed, byte-serialized ``ReprMessage`` through ``consumer``.

    The domain's root verify key is replaced with the verify key of the key
    used for signing, so the message carries the domain root's signature.
    """
    node = Domain(name="test")
    connection = WebRTCConnection(node=node)

    repr_msg = ReprMessage(address=node.address)
    root_key = SigningKey.generate()
    node.root_verify_key = root_key.verify_key

    payload = serialize(repr_msg.sign(signing_key=root_key), to_bytes=True)
    await connection.consumer(msg=payload)
Example 2
async def test_set_answer_on_message() -> None:
    """Check when the answering side's message handler invokes ``consumer``.

    A start-sign chunk on its own must not call ``consumer``. After a
    following chunk with index 0 arrives, ``consumer`` has been called once.
    """
    node = Domain(name="test")
    offerer = WebRTCConnection(node=node)
    offer = await offerer._set_offer()

    answerer = WebRTCConnection(node=node)
    await answerer._set_answer(payload=offer)

    # NOTE: handler lookup depends on the registration order inside the
    # private ``_events`` mapping of the peer connection.
    pc_handlers = [*answerer.peer_connection._events.values()]
    on_channel = [*pc_handlers[1].values()][0]

    async_stub = AsyncMock()
    consumer_patch = patch(
        "syft.grid.connections.webrtc.WebRTCConnection.consumer",
        return_value=async_stub(),
    )
    with consumer_patch as consumer_mock:
        on_channel(answerer.peer_connection.createDataChannel("datachannel"))

        dc_handlers = [*answerer.channel._events.values()]
        on_message = [*dc_handlers[1].values()][0]

        await on_message(OrderedChunk(1, DC_CHUNK_START_SIGN).save())
        assert consumer_mock.call_count == 0

        await on_message(OrderedChunk(0, b"a").save())
        assert consumer_mock.call_count == 1
Example 3
async def test_signaling_process() -> None:
    """Run a full offer/answer exchange between two connections on one domain.

    Both the offer and the answer payloads must be JSON with ``sdp`` and
    ``type`` keys and must decode to an ``RTCSessionDescription``. Processing
    the answer on the offering side returns ``None``.
    """

    def check_description(payload: str, expected_type: str) -> None:
        # Validate one signaling payload as both raw JSON and an aiortc object.
        as_dict = json.loads(payload)
        session = object_from_string(payload)
        assert "sdp" in as_dict
        assert "type" in as_dict
        assert as_dict["type"] == expected_type
        assert isinstance(session, RTCSessionDescription)

    node = Domain(name="test")
    offerer = WebRTCConnection(node=node)
    offer = await offerer._set_offer()
    check_description(offer, "offer")

    answerer = WebRTCConnection(node=node)
    answer = await answerer._set_answer(payload=offer)
    check_description(answer, "answer")

    outcome = await offerer._process_answer(payload=answer)
    assert outcome is None
Example 4
async def test_init_raise_exception(monkeypatch: MonkeyPatch) -> None:
    """Check that a failing ``RTCPeerConnection`` is reported via ``traceback_and_raise``.

    ``traceback_and_raise`` is replaced by a mock, so calling it records the
    call instead of raising. ``RTCPeerConnection`` is patched to raise, and
    the test then requires that the reporter was invoked.
    """
    # ``monkeypatch`` is unused but kept so the test signature stays unchanged.
    with patch(
            "syft.grid.connections.webrtc.traceback_and_raise") as mock_logger:
        with patch("syft.grid.connections.webrtc.RTCPeerConnection",
                   side_effect=Exception()):
            domain = Domain(name="test")
            WebRTCConnection(node=domain)
            # The old ``assert mock_logger.assert_called`` tested the
            # truthiness of a bound method (always True), so it could never
            # fail. Calling the method performs the real check.
            mock_logger.assert_called()
Example 5
async def test_set_offer_raise_exception() -> None:
    """``_set_offer`` must propagate an error raised by ``createDataChannel``."""
    connection = WebRTCConnection(node=Domain(name="test"))

    failing_create = patch(
        "syft.grid.connections.webrtc.RTCPeerConnection.createDataChannel",
        side_effect=Exception(),
    )
    with failing_create, pytest.raises(Exception):
        await connection._set_offer()
Example 6
async def test_init() -> None:
    """A fresh connection keeps its node and sets up loop, queues and peer connection.

    The client address must start out falsy.
    """
    node = Domain(name="test")
    connection = WebRTCConnection(node=node)

    assert connection is not None
    assert connection.node == node
    assert connection.loop is not None
    for pool in (connection.producer_pool, connection.consumer_pool):
        assert isinstance(pool, asyncio.Queue)
    assert isinstance(connection.peer_connection, RTCPeerConnection)
    assert not connection._client_address
Example 7
async def test_close_raise_exception() -> None:
    """An error from ``RTCDataChannel.close`` during ``close`` must reach ``traceback_and_raise``."""
    connection = WebRTCConnection(node=Domain(name="test"))

    reporter_patch = patch("syft.grid.connections.webrtc.traceback_and_raise")
    failing_close = patch(
        "syft.grid.connections.webrtc.RTCDataChannel.close",
        side_effect=Exception(),
    )
    with reporter_patch as reporter, failing_close:
        connection.close()
        assert reporter.called
Example 8
async def test_close() -> None:
    """After an offer, ``close`` calls ``RTCDataChannel.send`` and ``_finish_coroutines`` once each."""
    connection = WebRTCConnection(node=Domain(name="test"))
    await connection._set_offer()

    send_patch = patch("syft.grid.connections.webrtc.RTCDataChannel.send")
    finish_patch = patch(
        "syft.grid.connections.webrtc.WebRTCConnection._finish_coroutines"
    )
    with send_patch as send_mock, finish_patch as finish_mock:
        connection.close()
        assert send_mock.call_count == 1
        assert finish_mock.call_count == 1
Example 9
async def test_set_answer_raise_exception() -> None:
    """Feeding a connection its own offer as a remote answer must fail and be reported."""
    connection = WebRTCConnection(node=Domain(name="test"))
    offer = await connection._set_offer()

    # FIXME: Nahua is not happy with this test because it "indirectly" triggered exception
    # https://github.com/OpenMined/PySyft/issues/5126
    reporter_patch = patch("syft.grid.connections.webrtc.traceback_and_raise")
    with reporter_patch as reporter:
        # The connection already applied its local offer ('have-local-offer'),
        # so answering with that same offer is expected to fail.
        with pytest.raises(Exception):
            await connection._set_answer(payload=offer)
        assert reporter.called
Example 10
async def test_set_offer_on_open() -> None:
    """The data channel's open handler calls ``producer`` exactly once."""
    connection = WebRTCConnection(node=Domain(name="test"))
    await connection._set_offer()

    # NOTE: handler position depends on registration order inside the
    # channel's private ``_events`` mapping.
    handlers = [*connection.channel._events.values()]
    on_open = [*handlers[1].values()][0]

    async_stub = AsyncMock()
    producer_patch = patch(
        "syft.grid.connections.webrtc.WebRTCConnection.producer",
        return_value=async_stub(),
    )
    with producer_patch as producer_mock:
        await on_open()
        assert producer_mock.call_count == 1
Example 11
async def test_set_offer_on_message() -> None:
    """Check when the offering side's message handler invokes ``consumer``.

    A start-sign chunk on its own must not call ``consumer``. After a
    following chunk with index 0 arrives, ``consumer`` has been called once.
    """
    connection = WebRTCConnection(node=Domain(name="test"))
    await connection._set_offer()

    # NOTE: handler position depends on registration order inside the
    # channel's private ``_events`` mapping.
    handlers = [*connection.channel._events.values()]
    on_message = [*handlers[2].values()][0]

    async_stub = AsyncMock()
    consumer_patch = patch(
        "syft.grid.connections.webrtc.WebRTCConnection.consumer",
        return_value=async_stub(),
    )
    with consumer_patch as consumer_mock:
        await on_message(OrderedChunk(1, DC_CHUNK_START_SIGN).save())
        assert consumer_mock.call_count == 0

        await on_message(OrderedChunk(0, b"a").save())
        assert consumer_mock.call_count == 1
Example 12
async def test_set_answer_on_datachannel() -> None:
    """The answering side's datachannel handler calls ``producer`` exactly once."""
    node = Domain(name="test")
    offerer = WebRTCConnection(node=node)
    offer = await offerer._set_offer()

    answerer = WebRTCConnection(node=node)
    await answerer._set_answer(payload=offer)

    # NOTE: handler lookup depends on the registration order inside the
    # private ``_events`` mapping of the peer connection.
    pc_handlers = [*answerer.peer_connection._events.values()]
    on_datachannel = [*pc_handlers[1].values()][0]

    async_stub = AsyncMock()
    producer_patch = patch(
        "syft.grid.connections.webrtc.WebRTCConnection.producer",
        return_value=async_stub(),
    )
    with producer_patch as producer_mock:
        on_datachannel(answerer.peer_connection.createDataChannel("datachannel"))
        assert producer_mock.call_count == 1
Example 13
async def test_set_offer_sets_channel() -> None:
    """``_set_offer`` creates an ``RTCDataChannel`` whose low-buffer threshold is 4 chunks."""
    connection = WebRTCConnection(node=Domain(name="test"))
    await connection._set_offer()

    channel = connection.channel
    expected_threshold = DC_MAX_CHUNK_SIZE * 4
    assert isinstance(channel, RTCDataChannel)
    assert channel.bufferedAmountLowThreshold == expected_threshold
Example 14
def test_init_without_event_loop() -> None:
    """A connection can be constructed from a plain synchronous (non-async) test."""
    connection = WebRTCConnection(node=Domain(name="test"))
    assert connection is not None
Example 15
import json
import pickle

# third party
from flask import Flask
from flask import request

# syft absolute
from syft.core.common.message import ImmediateSyftMessageWithReply
from syft.core.common.message import ImmediateSyftMessageWithoutReply
from syft.core.node.domain.domain import Domain

# Flask application whose routes expose the domain node over HTTP.
app = Flask(__name__)


# Single in-process Domain node, named "ucsf", used by the HTTP handlers.
domain = Domain(name="ucsf")


@app.route("/")
def get_client() -> str:
    """Return the domain's client metadata, pickled and hex-encoded as a string."""
    pickled = pickle.dumps(domain.get_metadata_for_client())
    return pickled.hex()


@app.route("/recv", methods=["POST"])
def recv() -> str:
    hex_msg = request.get_json()["data"]
    msg = pickle.loads(binascii.unhexlify(hex_msg))  # nosec # TODO make less insecure
    reply = None
    print(str(msg))
    if isinstance(msg, ImmediateSyftMessageWithReply):