예제 #1
0
    def get_messages(self, topic: int, force: Optional[bool] = False) -> iter:
        """Return the messages to send for the current image, as bytes.

        When asynchronous message generation is enabled and re-computation is
        not forced, the pre-computed ``self.messages`` is returned unchanged.
        Otherwise a header message carrying the image metadata is built,
        followed by one VIDEO_STREAM message per chunk returned by
        ``self.split_image()``.

        :param topic: The topic associated to the image.
        :param force: Whether the messages must be re-computed even when
            asynchronous generation is enabled. ``None`` is treated like
            ``False``.
        :return messages: An iterator yielding the messages to send as bytes:
            the header first, then the image chunks (subtopics start at 1).
        """
        if self.async_msg_generation and not force:
            return self.messages
        # split_image() is called before the image metadata is read, as in the
        # original ordering, in case splitting touches image state.
        img_split = self.split_image()
        image = self.current_image
        # Product of the image dimensions, reused for the packet count and the
        # total size. NOTE(review): the packet count is derived from this and
        # max_packet_size, not from len(img_split); confirm they always agree.
        total_bytes = int(np.array(image.shape).prod())
        header = ImageManager.get_header_msg(
            topic,
            math.ceil(total_bytes / self.max_packet_size),
            total_bytes,
            image.shape[0],
            image.shape[1],
            self.get_pixel_size(),
            encoding=self.encoding)
        # Lazily serialize each chunk; subtopic 0 is not used for data chunks.
        img_messages = (UDPMessage(code=codes.VIDEO_STREAM,
                                   payload=chunk,
                                   topic=topic,
                                   subtopic=index).to_bytes()
                        for index, chunk in enumerate(img_split, start=1))
        return chain([header], img_messages)
예제 #2
0
def test_next_message_return_authentication_required_message_when_connection_step_4_and_role_is_server_with_password():
    """A password-protected server asks the client to authenticate.

    After two client messages and one server reply, a server configured with
    password authentication answers with an AUTHENTICATION_REQUIRED handshake
    message.
    """
    salt = os.urandom(16)
    derived = derive_password_scrypt(password_salt=salt,
                                     password_to_derive=b"test")
    methods = ["password"]
    server_auth_info = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: derived,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
        }
    }
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods,
                       authentication_information=server_auth_info)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods)
    expected = UDPMessage(code=codes.HANDSHAKE,
                          topic=Handshake.AUTHENTICATION_REQUIRED_TOPIC)

    # Drive the handshake up to the step where the server must reply.
    server.add_message(client.next_message())
    client.add_message(server.next_message())
    server.add_message(client.next_message())

    reply = server.next_message()

    assert reply.msg_id == expected.msg_id
    assert reply.topic == expected.topic
예제 #3
0
    def get_header_msg(topic: int,
                       nb_packet: int,
                       total_bytes: int,
                       height: int,
                       length: int,
                       pixel_size: int,
                       encoding: Optional[int] = 0) -> bytes:
        """Return the serialized header message carrying image metadata.

        The payload is the concatenation, in this order, of ``nb_packet``,
        ``total_bytes``, ``height``, ``length``, ``pixel_size`` and
        ``encoding``. Each value is encoded little-endian on the width given
        by the matching ``ImageManager.*_SIZE`` constant. The message uses
        code VIDEO_STREAM and subtopic ``ImageManager.NB_MSG_HEADER``.

        :param topic: The topic associated to the image.
        :param nb_packet: The total number of data packets that will be sent.
        :param total_bytes: The total number of bytes of the image.
        :param height: The height of the image.
        :param length: The length of the image.
        :param pixel_size: The size of a pixel.
        :param encoding: The encoding of the pixels (default 0 = none);
            ``None`` is treated as 0.
        :return header_msg: The bytes of the UDPMessage containing the image
            metadata.
        :raises OverflowError: If a value is negative or does not fit in its
            field width (raised by ``int.to_bytes``).
        """
        if encoding is None:
            # Optional[int] allows None; treat it as "no encoding" (0).
            encoding = 0
        fields = (
            (nb_packet, ImageManager.NB_PACKET_SIZE),
            (total_bytes, ImageManager.TOTAL_BYTES_SIZE),
            (height, ImageManager.HEIGHT_SIZE),
            (length, ImageManager.LENGTH_SIZE),
            (pixel_size, ImageManager.SIZE_PIXEL_SIZE),
            (encoding, ImageManager.ENCODING_SIZE),
        )
        payload = b"".join(
            value.to_bytes(size, 'little') for value, size in fields)
        return UDPMessage(code=codes.VIDEO_STREAM,
                          topic=topic,
                          subtopic=ImageManager.NB_MSG_HEADER,
                          payload=payload).to_bytes()
예제 #4
0
    def _nxt_msg_srv_approve_connection(self) -> UDPMessage:
        """Mark the connection as approved and build the approval message.

        Side effect: sets ``self._connection_status`` to
        ``Handshake.CONNECTION_STATUS_APPROVED``.

        :return next_message: A HANDSHAKE UDPMessage on the
            CONNECTION_APPROVED topic, to send to the remote host.
        """
        self._connection_status = Handshake.CONNECTION_STATUS_APPROVED
        approval = UDPMessage(code=codes.HANDSHAKE,
                              topic=Handshake.CONNECTION_APPROVED_TOPIC)
        return approval
예제 #5
0
    def _nxt_msg_connection_failed(self) -> UDPMessage:
        """Mark the connection as failed and build the failure message.

        Side effect: sets ``self._connection_status`` to
        ``Handshake.CONNECTION_STATUS_FAILED``.

        :return next_message: A HANDSHAKE UDPMessage on the
            CONNECTION_FAILED topic, to send to the remote host.
        """
        self._connection_status = Handshake.CONNECTION_STATUS_FAILED
        failure = UDPMessage(code=codes.HANDSHAKE,
                             topic=Handshake.CONNECTION_FAILED_TOPIC)
        return failure
예제 #6
0
def test_new_udp_message_throw_error_when_input_message_nb_too_big():
    """A 6-byte subtopic is rejected with ValueError."""
    oversized_subtopic = bytes([1] + [0] * 5)

    with pytest.raises(ValueError):
        UDPMessage(subtopic=oversized_subtopic)
예제 #7
0
def test_new_udp_message_throw_error_when_input_topic_too_big():
    """A 6-byte topic is rejected with ValueError."""
    oversized_topic = bytes([1] + [0] * 5)

    with pytest.raises(ValueError):
        UDPMessage(topic=oversized_topic)
예제 #8
0
def test_new_udp_message_throw_error_when_input_message_id_too_big():
    """A 6-byte message code is rejected with ValueError."""
    oversized_code = bytes([1] + [0] * 5)

    with pytest.raises(ValueError):
        UDPMessage(code=oversized_code)
예제 #9
0
def test_new_udp_message_created_with_correct_msg_id_when_type_bytes():
    """A 4-byte ``code`` given as bytes is exposed unchanged as ``msg_id``."""
    code_bytes = b"0000"

    message = UDPMessage(code=code_bytes)

    assert message.msg_id == code_bytes
예제 #10
0
def test_new_udp_message_throw_error_when_input_payload_too_big():
    """A payload one byte over PAYLOAD_MAX_SIZE is rejected with ValueError."""
    too_long_payload = bytes(UDPMessage.PAYLOAD_MAX_SIZE + 1)

    with pytest.raises(ValueError):
        UDPMessage(payload=too_long_payload)
예제 #11
0
def test_new_udp_message_created_with_correct_payload_when_type_bytes():
    """A bytes payload is stored unchanged."""
    raw_payload = b"0000"

    message = UDPMessage(payload=raw_payload)

    assert message.payload == raw_payload
예제 #12
0
def test_new_udp_message_created_with_correct_topic_when_type_bytes():
    """A 4-byte topic given as bytes is stored unchanged."""
    raw_topic = b"\x01\x00\x00\x00"

    message = UDPMessage(topic=raw_topic)

    assert message.topic == raw_topic
예제 #13
0
def test_new_udp_message_created_with_correct_message_nb_when_type_bytes():
    """A full-width bytes subtopic is exposed unchanged as ``message_nb``."""
    raw_subtopic = bytes([1] * UDPMessage.MSG_NUMBER_LENGTH)

    message = UDPMessage(subtopic=raw_subtopic)

    assert message.message_nb == raw_subtopic
예제 #14
0
def test_new_udp_message_created_with_correct_topic_length_when_input_too_small():
    """A 2-byte topic ends up TOPIC_LENGTH bytes long."""
    short_topic = b"\x01\x00"

    message = UDPMessage(topic=short_topic)

    assert len(message.topic) == UDPMessage.TOPIC_LENGTH
예제 #15
0
def test_new_udp_message_created_with_correct_message_id_length_when_input_too_small():
    """A 2-byte code yields a ``msg_id`` of MSG_ID_LENGTH bytes."""
    short_code = b"\x01\x00"

    message = UDPMessage(code=short_code)

    assert len(message.msg_id) == UDPMessage.MSG_ID_LENGTH
예제 #16
0
def test_new_udp_message_created_with_correct_message_nb_length_when_input_too_small():
    """A 1-byte subtopic yields a ``message_nb`` of MSG_NUMBER_LENGTH bytes."""
    short_subtopic = b"\x01"

    message = UDPMessage(subtopic=short_subtopic)

    assert len(message.message_nb) == UDPMessage.MSG_NUMBER_LENGTH
예제 #17
0
def test_new_udp_message_created_with_correct_payload_when_string():
    """A str payload is stored as its byte encoding."""
    text_payload = "0000"

    message = UDPMessage(payload=text_payload)

    assert message.payload == b"0000"
예제 #18
0
def test_new_udp_message_created_with_correct_message_nb_when_type_int():
    """An int subtopic is stored little-endian on MSG_NUMBER_LENGTH bytes."""
    padding = b"\x00" * (UDPMessage.MSG_NUMBER_LENGTH - 1)
    wanted = b"\x01" + padding

    message = UDPMessage(subtopic=1)

    assert message.message_nb == wanted
예제 #19
0
def test_new_udp_message_created_with_correct_time_creation():
    """``time_creation`` decodes (little-endian) to "now" in microseconds, within 2 s."""
    tolerance_us = 2_000_000
    now_us = int(time.time() * 1_000_000)

    message = UDPMessage()
    created_us = int.from_bytes(message.time_creation, 'little')

    assert now_us - tolerance_us < created_us < now_us + tolerance_us
예제 #20
0
def test_new_udp_message_created_with_correct_msg_id_when_type_int():
    """An int code is stored as a 4-byte little-endian ``msg_id``."""
    message = UDPMessage(code=1)

    assert message.msg_id == b"\x01\x00\x00\x00"
예제 #21
0
def test_add_messages_correctly_add_messages_to_rcv_messages():
    """Added messages are found at rcv_messages[0] and [1], with matching payload bytes."""
    video_topic = VideoTopic(nb_packet=2,
                             total_bytes=50,
                             height=10,
                             length=10,
                             pixel_size=3,
                             time_creation=1000)
    sent = (UDPMessage(subtopic=1), UDPMessage(subtopic=2, payload=bytes([1])))

    for message in sent:
        video_topic.add_message(message)

    # Compare payloads as multisets of byte values.
    for index, original in enumerate(sent):
        stored = video_topic.rcv_messages[index]
        assert collections.Counter(stored.payload) == collections.Counter(original.payload)
예제 #22
0
    def _nxt_msg_clt_connection_request(self) -> UDPMessage:
        """Build the client's connection request message.

        The payload is a UTF-8 JSON object mapping
        ``Handshake.PROTOCOL_VERSIONS_AVAILABLE_KEY_NAME`` to this client's
        allowed protocol versions.

        :return next_message: A UDPMessage to send to remote host to continue handshake process.
        """
        offer = {
            Handshake.PROTOCOL_VERSIONS_AVAILABLE_KEY_NAME:
            self._allowed_protocol_versions
        }
        encoded_offer = json.dumps(offer).encode("utf8")
        return UDPMessage(code=codes.HANDSHAKE,
                          topic=Handshake.CONNECTION_REQUEST_TOPIC,
                          payload=encoded_offer)
예제 #23
0
    def _nxt_msg_srv_authentication_required(self) -> UDPMessage:
        """Return the authentication-required message (server side).

        The payload is a UTF-8 JSON object mapping
        ``Handshake.AUTHENTICATION_METHODS_AVAILABLE_KEY_NAME`` to the
        authentication methods this handshake allows.

        :return next_message: A UDPMessage to send to remote host to continue handshake process.
        """
        methods = {
            Handshake.AUTHENTICATION_METHODS_AVAILABLE_KEY_NAME:
            self._allowed_authentication_methods
        }
        payload = json.dumps(methods).encode("utf8")
        return UDPMessage(code=codes.HANDSHAKE,
                          topic=Handshake.AUTHENTICATION_REQUIRED_TOPIC,
                          payload=payload)
예제 #24
0
def test_check_crc_returns_true_when_crc_correct():
    """A freshly built message passes its own CRC check."""
    message = UDPMessage(subtopic=b"\x00\x00",
                         payload=b"0000",
                         code=b"0000",
                         topic=b"\x01\x00\x00\x00")

    assert message.check_crc() is True
예제 #25
0
    def _nxt_msg_clt_key_share(self) -> UDPMessage:
        """Build the client key share message.

        The client's public key is serialized as PEM (SubjectPublicKeyInfo),
        decoded as ASCII, and sent in a UTF-8 JSON object under
        ``Handshake.CLIENT_PUBLIC_KEY_KEY_NAME``.

        :return next_message: A UDPMessage to send to remote host to continue handshake process.
        """
        public_key = self._private_key.public_key()
        pem = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo)
        share = {Handshake.CLIENT_PUBLIC_KEY_KEY_NAME: pem.decode("ascii")}
        encoded_share = json.dumps(share).encode("utf8")
        return UDPMessage(code=codes.HANDSHAKE,
                          topic=Handshake.CLIENT_KEY_SHARE_TOPIC,
                          payload=encoded_share)
예제 #26
0
def test_to_bytes_returns_full_message_as_bytes():
    """to_bytes() is the message content followed by its CRC."""
    message = UDPMessage(subtopic=b"\x00\x00",
                         payload=b"0000",
                         code=b"0000",
                         topic=b"\x01\x00\x00\x00")
    wanted = message.full_content + message.crc

    assert message.to_bytes() == wanted
예제 #27
0
def test_from_bytes_returns_none_if_message_is_corrupted():
    """from_bytes() returns None for a message whose CRC was blanked out."""
    message = UDPMessage(subtopic=b"\x02\x00",
                         payload=b"1111",
                         code=b"0000",
                         topic=b"\x01\x00\x00\x00")
    # Blank out the CRC to corrupt the message before serializing it.
    message.crc = b""

    assert UDPMessage.from_bytes(message.to_bytes()) is None
예제 #28
0
def test_new_udp_message_created_with_correct_crc():
    """The CRC is the little-endian CRC-32 of id, time, topic, number and payload."""
    message = UDPMessage(subtopic=b"\x00\x00",
                         payload=b"0000",
                         code=b"0000",
                         topic=b"\x01\x00\x00\x00")
    covered = (message.msg_id + message.time_creation + message.topic +
               message.message_nb + message.payload)
    checksum = zlib.crc32(covered)

    assert message.crc == checksum.to_bytes(UDPMessage.CRC_LENGTH, 'little')
예제 #29
0
def test_add_messages_set_rcv_error_to_true_if_message_nb_greater_than_nb_packet():
    """Adding subtopic 4 to a 2-packet topic sets ``rcv_error``."""
    video_topic = VideoTopic(nb_packet=2,
                             total_bytes=50,
                             height=10,
                             length=10,
                             pixel_size=3,
                             time_creation=1000)
    out_of_range = UDPMessage(subtopic=4)

    video_topic.add_message(out_of_range)

    assert video_topic.rcv_error is True
예제 #30
0
def test_total_bytes_correct_return_true_if_expected_number_of_bytes_is_the_same_than_the_received_number():
    """Two 25-byte chunks satisfy a 50-byte total."""
    packet_count = 2
    video_topic = VideoTopic(nb_packet=packet_count,
                             total_bytes=50,
                             height=10,
                             length=10,
                             pixel_size=3,
                             time_creation=1000)
    zero_chunk = bytes(25)

    for subtopic in range(1, packet_count + 1):
        video_topic.add_message(UDPMessage(subtopic=subtopic, payload=zero_chunk))

    assert video_topic.total_bytes_correct()