예제 #1
0
def test_message_key_limits():  # Note: slow test
    """Check that Bob rejects a message from far behind the current counter.

    Alice encrypts ``max_message_keys + 300`` messages which Bob does not
    receive in order.  Bob can still decrypt message 1000 and the very last
    message, but decrypting message 5 afterwards is expected to fail with a
    "message with old counter" error.
    """
    alice_session, bob_session = initialize_sessions_v3()

    alice_addr = address.ProtocolAddress("+14159999999", 1)
    bob_addr = address.ProtocolAddress("+14158888888", 1)

    alice_keys = identity_key.IdentityKeyPair.generate()
    bob_keys = identity_key.IdentityKeyPair.generate()
    # TODO: generate the registration ids (1 for Alice, 2 for Bob).
    alice_store = storage.InMemSignalProtocolStore(alice_keys, 1)
    bob_store = storage.InMemSignalProtocolStore(bob_keys, 2)

    alice_store.store_session(bob_addr, alice_session)
    bob_store.store_session(alice_addr, bob_session)

    max_message_keys = 2000
    message_count = max_message_keys + 300

    def plaintext_for(index):
        # UTF-8 payload of the index-th message Alice sends.
        return f"It's over {index}".encode("utf8")

    # Everything is encrypted up front; nothing is decrypted yet.
    inflight = [
        session_cipher.message_encrypt(alice_store, bob_addr, plaintext_for(n))
        for n in range(message_count)
    ]

    def bob_decrypts(index):
        return session_cipher.message_decrypt(bob_store, alice_addr,
                                              inflight[index])

    assert bob_decrypts(1000) == b"It's over 1000"

    last_index = message_count - 1
    assert bob_decrypts(last_index) == plaintext_for(last_index)

    with pytest.raises(SignalProtocolException,
                       match="message with old counter"):
        bob_decrypts(5)
예제 #2
0
def run_interaction(
    alice_store: storage.InMemSignalProtocolStore,
    alice_address: address.ProtocolAddress,
    bob_store: storage.InMemSignalProtocolStore,
    bob_address: address.ProtocolAddress,
):
    """Exercise an established Alice <-> Bob session in both directions.

    Sends one message each way, then batches of ten each way.  It also
    encrypts a batch of ten Alice messages that is held back, sends more
    traffic in both directions, and finally checks that Bob can still
    decrypt the held-back batch.  Every message delivered immediately is
    asserted to be a Whisper message (type 2).
    """
    whisper_type = 2  # CiphertextMessageType::Whisper => 2

    def deliver(sender_store, recipient_address, recipient_store,
                sender_address, plaintext):
        # Encrypt on the sender's side and decrypt right away on the recipient's.
        ciphertext = session_cipher.message_encrypt(
            sender_store, recipient_address, plaintext
        )
        assert ciphertext.message_type() == whisper_type
        assert (
            session_cipher.message_decrypt(recipient_store, sender_address, ciphertext)
            == plaintext
        )

    def alice_to_bob(plaintext):
        deliver(alice_store, bob_address, bob_store, alice_address, plaintext)

    def bob_to_alice(plaintext):
        deliver(bob_store, alice_address, alice_store, bob_address, plaintext)

    alice_to_bob(b"It's rabbit season")
    bob_to_alice(b"It's duck season")

    for n in range(10):
        alice_to_bob(f"A->B message {n}".encode("utf8"))

    for n in range(10):
        bob_to_alice(f"B->A message {n}".encode("utf8"))

    # Encrypted now, but only delivered after more traffic has flowed.
    held_back = []
    for n in range(10):
        text = f"A->B OOO message {n}"
        held_back.append(
            (
                text,
                session_cipher.message_encrypt(
                    alice_store, bob_address, text.encode("utf8")
                ),
            )
        )

    for n in range(10):
        alice_to_bob(f"A->B post-OOO message {n}".encode("utf8"))

    for n in range(10):
        bob_to_alice(f"B->A message post-OOO {n}".encode("utf8"))

    # Now we check that messages can be decrypted when delivered out of order
    for text, ciphertext in held_back:
        decrypted = session_cipher.message_decrypt(bob_store, alice_address, ciphertext)
        assert decrypted == text.encode("utf8")
예제 #3
0
def run_session_interaction(alice_session, bob_session):
    """Drive two pre-built session records through a shuffled exchange.

    Installs ``alice_session`` / ``bob_session`` into fresh in-memory stores
    and performs one round trip in each direction.  Each side then encrypts
    a batch of 50 messages, which is shuffled and decrypted by the other
    side in two halves, with the halves of the two batches interleaved.
    """
    alice_address = address.ProtocolAddress("+14159999999", 1)
    bob_address = address.ProtocolAddress("+14158888888", 1)

    alice_keys = identity_key.IdentityKeyPair.generate()
    bob_keys = identity_key.IdentityKeyPair.generate()

    # TODO: generate the registration ids (1 for Alice, 2 for Bob).
    alice_store = storage.InMemSignalProtocolStore(alice_keys, 1)
    bob_store = storage.InMemSignalProtocolStore(bob_keys, 2)

    alice_store.store_session(bob_address, alice_session)
    bob_store.store_session(alice_address, bob_session)

    def round_trip(sender_store, recipient_address, recipient_store,
                   sender_address, plaintext):
        ciphertext = session_cipher.message_encrypt(
            sender_store, recipient_address, plaintext
        )
        decrypted = session_cipher.message_decrypt(
            recipient_store, sender_address, ciphertext
        )
        assert decrypted == plaintext

    def check_batch(receiver_store, sender_address, messages):
        # Each entry is (expected text, ciphertext); decrypt in list order.
        for expected_text, ciphertext in messages:
            plaintext = session_cipher.message_decrypt(
                receiver_store, sender_address, ciphertext
            )
            assert plaintext.decode("utf8") == expected_text

    round_trip(alice_store, bob_address, bob_store, alice_address,
               b"This is Alice's message")
    round_trip(bob_store, alice_address, alice_store, bob_address,
               b"This is Bob's reply")

    alice_message_count = 50
    bob_message_count = 50

    alice_messages = []
    for n in range(alice_message_count):
        text = f"смерть за смерть {n}"
        alice_messages.append(
            (text, session_cipher.message_encrypt(
                alice_store, bob_address, text.encode("utf8")))
        )

    random.shuffle(alice_messages)
    alice_half = alice_message_count // 2
    check_batch(bob_store, alice_address, alice_messages[:alice_half])

    bob_messages = []
    for n in range(bob_message_count):
        text = f"Relax in the safety of your own delusions. {n}"
        bob_messages.append(
            (text, session_cipher.message_encrypt(
                bob_store, alice_address, text.encode("utf8")))
        )

    random.shuffle(bob_messages)
    bob_half = bob_message_count // 2
    check_batch(alice_store, bob_address, bob_messages[:bob_half])

    check_batch(bob_store, alice_address, alice_messages[alice_half:])
    check_batch(alice_store, bob_address, bob_messages[bob_half:])
예제 #4
0
def test_simultaneous_initiate_lost_message_repeated_messages():
    """Simultaneous initiation, repeated traffic, then a late "lost" message.

    Alice first builds a session from Bob's bundle and encrypts a message
    that is only delivered at the very end.  Both sides then process each
    other's fresh bundles 15 times and exchange PreKey greetings.  Next they
    exchange 50 rounds of Whisper greetings.  The session ids only match
    after Alice decrypts a Bob reply.  Delivering the lost PreKey message
    to Bob makes them differ again until Alice decrypts another Bob reply.
    """
    alice_address = address.ProtocolAddress("+14151111111", 1)
    bob_address = address.ProtocolAddress("+14151111112", 1)

    alice_keys = identity_key.IdentityKeyPair.generate()
    bob_keys = identity_key.IdentityKeyPair.generate()
    # TODO: generate the registration ids (1 for Alice, 2 for Bob).
    alice_store = storage.InMemSignalProtocolStore(alice_keys, 1)
    bob_store = storage.InMemSignalProtocolStore(bob_keys, 2)

    def sessions_match():
        return is_session_id_equal(alice_store, alice_address, bob_store,
                                   bob_address)

    def encrypt(sender_store, recipient_address, plaintext):
        return session_cipher.message_encrypt(sender_store, recipient_address,
                                              plaintext)

    def decrypt(receiver_store, sender_address, message_class, ciphertext):
        # Round-trip through the wire format, parsing as ``message_class``.
        return session_cipher.message_decrypt(
            receiver_store,
            sender_address,
            message_class.try_from(ciphertext.serialize()),
        )

    def exchange_greetings(expected_type, message_class):
        # Both sides encrypt before either side decrypts.
        message_for_bob = encrypt(alice_store, bob_address, b"hi bob")
        message_for_alice = encrypt(bob_store, alice_address, b"hi alice")

        assert message_for_bob.message_type() == expected_type
        assert message_for_alice.message_type() == expected_type

        assert not sessions_match()

        assert decrypt(alice_store, bob_address, message_class,
                       message_for_alice) == b"hi alice"
        assert decrypt(bob_store, alice_address, message_class,
                       message_for_bob) == b"hi bob"

        assert alice_store.load_session(bob_address).session_version() == 3
        assert bob_store.load_session(alice_address).session_version() == 3

        assert not sessions_match()

    session.process_prekey_bundle(bob_address, alice_store,
                                  create_pre_key_bundle(bob_store))

    lost_message_for_bob = encrypt(alice_store, bob_address,
                                   b"it was so long ago")

    for _ in range(15):
        alice_pre_key_bundle = create_pre_key_bundle(alice_store)
        bob_pre_key_bundle = create_pre_key_bundle(bob_store)

        session.process_prekey_bundle(bob_address, alice_store,
                                      bob_pre_key_bundle)
        session.process_prekey_bundle(alice_address, bob_store,
                                      alice_pre_key_bundle)

        # 3 == CiphertextMessageType::PreKey
        exchange_greetings(3, protocol.PreKeySignalMessage)

    for _ in range(50):
        # 2 == CiphertextMessageType::Whisper
        exchange_greetings(2, protocol.SignalMessage)

    alice_response = encrypt(alice_store, bob_address, b"nice to see you")
    assert alice_response.message_type() == 2  # CiphertextMessageType::Whisper
    assert not sessions_match()

    bob_response = encrypt(bob_store, alice_address, b"you as well")
    assert bob_response.message_type() == 2  # CiphertextMessageType::Whisper
    assert decrypt(alice_store, bob_address, protocol.SignalMessage,
                   bob_response) == b"you as well"
    assert sessions_match()

    blast_from_the_past = decrypt(bob_store, alice_address,
                                  protocol.PreKeySignalMessage,
                                  lost_message_for_bob)
    assert blast_from_the_past == b"it was so long ago"
    assert not sessions_match()

    bob_response = encrypt(bob_store, alice_address, b"so it was")
    assert bob_response.message_type() == 2  # CiphertextMessageType::Whisper
    assert decrypt(alice_store, bob_address, protocol.SignalMessage,
                   bob_response) == b"so it was"
    assert sessions_match()
예제 #5
0
def test_optional_one_time_prekey():
    """Establish a session from a bundle that carries no one-time pre-key.

    Bob's bundle advertises only a signed pre-key: the one-time pre-key id
    and public key are both ``None``.  Alice's first message must still be a
    PreKey message.  Bob must be able to decrypt it once the matching
    signed pre-key record is saved in his store.
    """
    alice_address = address.ProtocolAddress("+14151111111", DEVICE_ID)
    bob_address = address.ProtocolAddress("+14151111112", DEVICE_ID)

    alice_identity_key_pair = identity_key.IdentityKeyPair.generate()
    bob_identity_key_pair = identity_key.IdentityKeyPair.generate()

    alice_registration_id = 1  # TODO: generate these
    bob_registration_id = 2

    alice_store = storage.InMemSignalProtocolStore(alice_identity_key_pair,
                                                   alice_registration_id)
    bob_store = storage.InMemSignalProtocolStore(bob_identity_key_pair,
                                                 bob_registration_id)

    # Bob signs his signed pre-key's public half with his identity key.
    bob_signed_pre_key_pair = curve.KeyPair.generate()
    bob_signed_pre_key_public = bob_signed_pre_key_pair.public_key().serialize(
    )
    bob_signed_pre_key_signature = (bob_store.get_identity_key_pair(
    ).private_key().calculate_signature(bob_signed_pre_key_public))

    signed_pre_key_id = 22

    bob_pre_key_bundle = state.PreKeyBundle(
        bob_store.get_local_registration_id(),
        DEVICE_ID,
        None,  # No prekey
        None,  # No prekey
        signed_pre_key_id,
        bob_signed_pre_key_pair.public_key(),
        bob_signed_pre_key_signature,
        bob_store.get_identity_key_pair().identity_key(),
    )

    session.process_prekey_bundle(
        bob_address,
        alice_store,
        bob_pre_key_bundle,
    )

    assert alice_store.load_session(bob_address).session_version() == 3

    original_message = b"Hobgoblins hold themselves to high standards of military honor"

    outgoing_message = session_cipher.message_encrypt(alice_store, bob_address,
                                                      original_message)
    assert outgoing_message.message_type() == 3  # 3 == CiphertextMessageType::PreKey

    incoming_message = protocol.PreKeySignalMessage.try_from(
        outgoing_message.serialize())

    # Bob only needs the signed pre-key record; no one-time pre-key is saved.
    signed_prekey = state.SignedPreKeyRecord(
        signed_pre_key_id,
        42,
        bob_signed_pre_key_pair,
        bob_signed_pre_key_signature,
    )
    bob_store.save_signed_pre_key(signed_pre_key_id, signed_prekey)

    plaintext = session_cipher.message_decrypt(bob_store, alice_address,
                                               incoming_message)
    assert original_message == plaintext
예제 #6
0
def test_bad_message_bundle():
    """A corrupted PreKey message must fail without consuming Bob's pre-key.

    Alice encrypts a PreKey message, and a single bit near its end is
    flipped.  Decrypting the corrupted copy must raise, and Bob's one-time
    pre-key must still be present afterwards.  The intact message then
    decrypts correctly, after which the pre-key is gone from Bob's store.
    """
    alice_address = address.ProtocolAddress("+14151111111", DEVICE_ID)
    bob_address = address.ProtocolAddress("+14151111112", DEVICE_ID)

    alice_identity_key_pair = identity_key.IdentityKeyPair.generate()
    bob_identity_key_pair = identity_key.IdentityKeyPair.generate()

    alice_registration_id = 1  # TODO: generate these
    bob_registration_id = 2

    alice_store = storage.InMemSignalProtocolStore(alice_identity_key_pair,
                                                   alice_registration_id)
    bob_store = storage.InMemSignalProtocolStore(bob_identity_key_pair,
                                                 bob_registration_id)

    bob_pre_key_pair = curve.KeyPair.generate()
    bob_signed_pre_key_pair = curve.KeyPair.generate()

    bob_signed_pre_key_public = bob_signed_pre_key_pair.public_key().serialize(
    )

    bob_signed_pre_key_signature = (bob_store.get_identity_key_pair(
    ).private_key().calculate_signature(bob_signed_pre_key_public))

    pre_key_id = 31337
    signed_pre_key_id = 22

    bob_pre_key_bundle = state.PreKeyBundle(
        bob_store.get_local_registration_id(),
        DEVICE_ID,
        pre_key_id,
        bob_pre_key_pair.public_key(),
        signed_pre_key_id,
        bob_signed_pre_key_pair.public_key(),
        bob_signed_pre_key_signature,
        bob_store.get_identity_key_pair().identity_key(),
    )

    session.process_prekey_bundle(
        bob_address,
        alice_store,
        bob_pre_key_bundle,
    )

    bob_prekey = state.PreKeyRecord(pre_key_id, bob_pre_key_pair)
    bob_store.save_pre_key(pre_key_id, bob_prekey)

    signed_prekey = state.SignedPreKeyRecord(
        signed_pre_key_id,
        42,
        bob_signed_pre_key_pair,
        bob_signed_pre_key_signature,
    )
    bob_store.save_signed_pre_key(signed_pre_key_id, signed_prekey)

    assert alice_store.load_session(bob_address)
    assert alice_store.load_session(bob_address).session_version() == 3

    original_message = b"Hobgoblins hold themselves to high standards of military honor"

    assert bob_store.get_pre_key(pre_key_id)

    outgoing_message = session_cipher.message_encrypt(alice_store, bob_address,
                                                      original_message)
    assert outgoing_message.message_type() == 3  # 3 == CiphertextMessageType::PreKey
    outgoing_message_wire = outgoing_message.serialize()

    # Flip the lowest bit of the byte 10 positions from the end of the wire form.
    edit_point = len(outgoing_message_wire) - 10
    corrupted_message = (outgoing_message_wire[:edit_point] +
                         bytes([outgoing_message_wire[edit_point] ^ 0x01]) +
                         outgoing_message_wire[edit_point + 1:])

    incoming_message = protocol.PreKeySignalMessage.try_from(corrupted_message)

    # This incoming message is corrupted, so we expect an exception to be raised
    with pytest.raises(SignalProtocolException):
        session_cipher.message_decrypt(bob_store, alice_address,
                                       incoming_message)

    # The failed decryption must not have consumed the one-time pre-key.
    assert bob_store.get_pre_key(pre_key_id)

    incoming_message = protocol.PreKeySignalMessage.try_from(
        outgoing_message_wire)

    plaintext = session_cipher.message_decrypt(bob_store, alice_address,
                                               incoming_message)

    assert original_message == plaintext

    # Trying to get the prekey will now fail, as the prekey has been used and removed from the store
    with pytest.raises(SignalProtocolException,
                       match="invalid prekey identifier"):
        bob_store.get_pre_key(pre_key_id)
예제 #7
0
def test_repeat_bundle_message_v3():
    """Two PreKey messages built from the same bundle both decrypt.

    Alice encrypts two messages right after processing Bob's bundle, so both
    are PreKey messages.  Bob decrypts the first and replies, and Alice
    decrypts the reply.  Bob then decrypts the second PreKey message and
    replies again.  Both of Bob's replies are Whisper messages that Alice
    can decrypt.
    """
    alice_address = address.ProtocolAddress("+14151111111", DEVICE_ID)
    bob_address = address.ProtocolAddress("+14151111112", DEVICE_ID)

    alice_identity_key_pair = identity_key.IdentityKeyPair.generate()
    bob_identity_key_pair = identity_key.IdentityKeyPair.generate()

    alice_registration_id = 1  # TODO: generate these
    bob_registration_id = 2

    alice_store = storage.InMemSignalProtocolStore(alice_identity_key_pair,
                                                   alice_registration_id)
    bob_store = storage.InMemSignalProtocolStore(bob_identity_key_pair,
                                                 bob_registration_id)

    bob_pre_key_pair = curve.KeyPair.generate()
    bob_signed_pre_key_pair = curve.KeyPair.generate()

    bob_signed_pre_key_public = bob_signed_pre_key_pair.public_key().serialize(
    )

    bob_signed_pre_key_signature = (bob_store.get_identity_key_pair(
    ).private_key().calculate_signature(bob_signed_pre_key_public))

    pre_key_id = 31337
    signed_pre_key_id = 22

    bob_pre_key_bundle = state.PreKeyBundle(
        bob_store.get_local_registration_id(),
        DEVICE_ID,
        pre_key_id,
        bob_pre_key_pair.public_key(),
        signed_pre_key_id,
        bob_signed_pre_key_pair.public_key(),
        bob_signed_pre_key_signature,
        bob_store.get_identity_key_pair().identity_key(),
    )

    session.process_prekey_bundle(
        bob_address,
        alice_store,
        bob_pre_key_bundle,
    )

    assert alice_store.load_session(bob_address)
    assert alice_store.load_session(bob_address).session_version() == 3

    original_message = b"Hobgoblins hold themselves to high standards of military honor"

    outgoing_message1 = session_cipher.message_encrypt(alice_store,
                                                       bob_address,
                                                       original_message)
    outgoing_message2 = session_cipher.message_encrypt(alice_store,
                                                       bob_address,
                                                       original_message)
    assert outgoing_message1.message_type() == 3  # 3 == CiphertextMessageType::PreKey
    assert outgoing_message2.message_type() == 3  # 3 == CiphertextMessageType::PreKey

    incoming_message = protocol.PreKeySignalMessage.try_from(
        outgoing_message1.serialize())

    bob_prekey = state.PreKeyRecord(pre_key_id, bob_pre_key_pair)
    bob_store.save_pre_key(pre_key_id, bob_prekey)

    signed_prekey = state.SignedPreKeyRecord(
        signed_pre_key_id,
        42,
        bob_signed_pre_key_pair,
        bob_signed_pre_key_signature,
    )
    bob_store.save_signed_pre_key(signed_pre_key_id, signed_prekey)

    ptext = session_cipher.message_decrypt(bob_store, alice_address,
                                           incoming_message)
    assert original_message == ptext

    bob_outgoing = session_cipher.message_encrypt(bob_store, alice_address,
                                                  original_message)
    assert bob_outgoing.message_type(
    ) == 2  # 2 == CiphertextMessageType::Whisper

    alice_decrypts = session_cipher.message_decrypt(alice_store, bob_address,
                                                    bob_outgoing)
    assert alice_decrypts == original_message

    # Verify the second message can be processed

    incoming_message2 = protocol.PreKeySignalMessage.try_from(
        outgoing_message2.serialize())
    ptext = session_cipher.message_decrypt(bob_store, alice_address,
                                           incoming_message2)
    assert original_message == ptext

    bob_outgoing = session_cipher.message_encrypt(bob_store, alice_address,
                                                  original_message)
    assert bob_outgoing.message_type(
    ) == 2  # 2 == CiphertextMessageType::Whisper
    alice_decrypts = session_cipher.message_decrypt(alice_store, bob_address,
                                                    bob_outgoing)
    assert alice_decrypts == original_message
예제 #8
0
def test_basic_prekey_v3():
    """Basic v3 session setup via a pre-key bundle, plus identity checks.

    The test runs in three phases:

    1. Alice processes Bob's bundle and sends a PreKey message.  Bob decrypts
       it and replies with a Whisper message, after which ``run_interaction``
       exercises the session.
    2. Alice is replaced by a fresh identity and starts a new session from a
       new bundle.  Bob rejects her message as an untrusted identity until
       her new identity key is saved, after which the message decrypts.
    3. A bundle whose signed pre-key signature (made with Bob's identity key)
       is paired with Alice's identity key must be rejected by
       ``process_prekey_bundle``.
    """
    alice_address = address.ProtocolAddress("+14151111111", DEVICE_ID)
    bob_address = address.ProtocolAddress("+14151111112", DEVICE_ID)

    alice_identity_key_pair = identity_key.IdentityKeyPair.generate()
    bob_identity_key_pair = identity_key.IdentityKeyPair.generate()

    alice_registration_id = 1  # TODO: generate these
    bob_registration_id = 2

    alice_store = storage.InMemSignalProtocolStore(alice_identity_key_pair,
                                                   alice_registration_id)
    bob_store = storage.InMemSignalProtocolStore(bob_identity_key_pair,
                                                 bob_registration_id)

    bob_pre_key_pair = curve.KeyPair.generate()
    bob_signed_pre_key_pair = curve.KeyPair.generate()

    bob_signed_pre_key_public = bob_signed_pre_key_pair.public_key().serialize(
    )

    bob_signed_pre_key_signature = (bob_store.get_identity_key_pair(
    ).private_key().calculate_signature(bob_signed_pre_key_public))

    pre_key_id = 31337
    signed_pre_key_id = 22

    bob_pre_key_bundle = state.PreKeyBundle(
        bob_store.get_local_registration_id(),
        DEVICE_ID,
        pre_key_id,
        bob_pre_key_pair.public_key(),
        signed_pre_key_id,
        bob_signed_pre_key_pair.public_key(),
        bob_signed_pre_key_signature,
        bob_store.get_identity_key_pair().identity_key(),
    )

    assert alice_store.load_session(bob_address) is None

    # Below standalone function would make more sense as a method on alice_store?
    session.process_prekey_bundle(
        bob_address,
        alice_store,
        bob_pre_key_bundle,
    )

    assert alice_store.load_session(bob_address)
    assert alice_store.load_session(bob_address).session_version() == 3

    original_message = b"Hobgoblins hold themselves to high standards of military honor"

    outgoing_message = session_cipher.message_encrypt(alice_store, bob_address,
                                                      original_message)
    assert outgoing_message.message_type() == 3  # 3 == CiphertextMessageType::PreKey
    outgoing_message_wire = outgoing_message.serialize()

    # Now over to fake Bob for processing the first message

    incoming_message = protocol.PreKeySignalMessage.try_from(
        outgoing_message_wire)

    bob_prekey = state.PreKeyRecord(pre_key_id, bob_pre_key_pair)
    bob_store.save_pre_key(pre_key_id, bob_prekey)

    signed_prekey = state.SignedPreKeyRecord(
        signed_pre_key_id,
        42,
        bob_signed_pre_key_pair,
        bob_signed_pre_key_signature,
    )

    bob_store.save_signed_pre_key(signed_pre_key_id, signed_prekey)

    assert bob_store.load_session(alice_address) is None

    plaintext = session_cipher.message_decrypt(bob_store, alice_address,
                                               incoming_message)

    assert original_message == plaintext

    bobs_response = b"Who watches the watchers?"

    assert bob_store.load_session(alice_address)

    bobs_session_with_alice = bob_store.load_session(alice_address)
    assert bobs_session_with_alice.session_version() == 3
    # Serialized public key: 32 key bytes plus a 1-byte type prefix.
    assert len(bobs_session_with_alice.alice_base_key()) == 32 + 1

    bob_outgoing = session_cipher.message_encrypt(bob_store, alice_address,
                                                  bobs_response)
    assert bob_outgoing.message_type(
    ) == 2  # 2 == CiphertextMessageType::Whisper

    # Now back to fake alice

    alice_decrypts = session_cipher.message_decrypt(alice_store, bob_address,
                                                    bobs_response if False else bob_outgoing)
    assert alice_decrypts == bobs_response

    run_interaction(alice_store, alice_address, bob_store, bob_address)

    # Phase 2: Alice comes back with a brand-new identity key and store.
    alice_identity_key_pair = identity_key.IdentityKeyPair.generate()
    alice_store = storage.InMemSignalProtocolStore(alice_identity_key_pair,
                                                   alice_registration_id)

    bob_pre_key_pair = curve.KeyPair.generate()
    bob_signed_pre_key_pair = curve.KeyPair.generate()
    bob_signed_pre_key_public = bob_signed_pre_key_pair.public_key().serialize(
    )

    bob_signed_pre_key_signature = (bob_store.get_identity_key_pair(
    ).private_key().calculate_signature(bob_signed_pre_key_public))

    # The new bundle uses ids one past the originals so they don't collide.
    bob_pre_key_bundle = state.PreKeyBundle(
        bob_store.get_local_registration_id(),
        DEVICE_ID,
        pre_key_id + 1,
        bob_pre_key_pair.public_key(),
        signed_pre_key_id + 1,
        bob_signed_pre_key_pair.public_key(),
        bob_signed_pre_key_signature,
        bob_store.get_identity_key_pair().identity_key(),
    )

    bob_prekey = state.PreKeyRecord(pre_key_id + 1, bob_pre_key_pair)
    bob_store.save_pre_key(pre_key_id + 1, bob_prekey)

    signed_prekey = state.SignedPreKeyRecord(
        signed_pre_key_id + 1,
        42,
        bob_signed_pre_key_pair,
        bob_signed_pre_key_signature,
    )
    bob_store.save_signed_pre_key(signed_pre_key_id + 1, signed_prekey)

    session.process_prekey_bundle(
        bob_address,
        alice_store,
        bob_pre_key_bundle,
    )

    outgoing_message = session_cipher.message_encrypt(alice_store, bob_address,
                                                      original_message)

    # Bob still has Alice's old identity key, so the new one is untrusted.
    with pytest.raises(SignalProtocolException, match="untrusted identity"):
        session_cipher.message_decrypt(bob_store, alice_address,
                                       outgoing_message)

    assert bob_store.save_identity(
        alice_address,
        alice_store.get_identity_key_pair().identity_key())

    decrypted = session_cipher.message_decrypt(bob_store, alice_address,
                                               outgoing_message)
    assert decrypted == original_message

    # Phase 3: the signed pre-key signature was made with Bob's identity key,
    # but this bundle advertises Alice's identity key, so verification fails.
    bob_pre_key_bundle = state.PreKeyBundle(
        bob_store.get_local_registration_id(),
        DEVICE_ID,
        pre_key_id,
        bob_pre_key_pair.public_key(),
        signed_pre_key_id,
        bob_signed_pre_key_pair.public_key(),
        bob_signed_pre_key_signature,
        alice_store.get_identity_key_pair().identity_key(),
    )

    with pytest.raises(SignalProtocolException):
        session.process_prekey_bundle(bob_address, alice_store,
                                      bob_pre_key_bundle)
예제 #9
0
def test_basic_large_message():
    """A large (~1 MB) plaintext survives a PreKey-message round trip.

    Alice builds a session from Bob's bundle and encrypts 1,024,000 zero
    bytes.  Bob, after saving the matching pre-key records, must decrypt the
    message back to the identical bytes.
    """
    alice_address = address.ProtocolAddress("+14151111111", DEVICE_ID)
    bob_address = address.ProtocolAddress("+14151111112", DEVICE_ID)

    alice_identity_key_pair = identity_key.IdentityKeyPair.generate()
    bob_identity_key_pair = identity_key.IdentityKeyPair.generate()

    alice_registration_id = 1  # TODO: generate these
    bob_registration_id = 2

    alice_store = storage.InMemSignalProtocolStore(alice_identity_key_pair,
                                                   alice_registration_id)
    bob_store = storage.InMemSignalProtocolStore(bob_identity_key_pair,
                                                 bob_registration_id)

    bob_pre_key_pair = curve.KeyPair.generate()
    bob_signed_pre_key_pair = curve.KeyPair.generate()

    bob_signed_pre_key_public = bob_signed_pre_key_pair.public_key().serialize(
    )

    bob_signed_pre_key_signature = (bob_store.get_identity_key_pair(
    ).private_key().calculate_signature(bob_signed_pre_key_public))

    pre_key_id = 31337
    signed_pre_key_id = 22

    bob_pre_key_bundle = state.PreKeyBundle(
        bob_store.get_local_registration_id(),
        DEVICE_ID,
        pre_key_id,
        bob_pre_key_pair.public_key(),
        signed_pre_key_id,
        bob_signed_pre_key_pair.public_key(),
        bob_signed_pre_key_signature,
        bob_store.get_identity_key_pair().identity_key(),
    )

    assert alice_store.load_session(bob_address) is None

    # Below standalone function would make more sense as a method on alice_store?
    session.process_prekey_bundle(
        bob_address,
        alice_store,
        bob_pre_key_bundle,
    )

    assert alice_store.load_session(bob_address)
    assert alice_store.load_session(bob_address).session_version() == 3

    original_message = bytes(1024 * 1000)  # 1,024,000 zero bytes (~1 MB attachment)

    outgoing_message = session_cipher.message_encrypt(alice_store, bob_address,
                                                      original_message)
    assert outgoing_message.message_type() == 3  # 3 == CiphertextMessageType::PreKey
    outgoing_message_wire = outgoing_message.serialize()

    incoming_message = protocol.PreKeySignalMessage.try_from(
        outgoing_message_wire)

    bob_prekey = state.PreKeyRecord(pre_key_id, bob_pre_key_pair)
    bob_store.save_pre_key(pre_key_id, bob_prekey)

    signed_prekey = state.SignedPreKeyRecord(
        signed_pre_key_id,
        42,
        bob_signed_pre_key_pair,
        bob_signed_pre_key_signature,
    )

    bob_store.save_signed_pre_key(signed_pre_key_id, signed_prekey)

    plaintext = session_cipher.message_decrypt(bob_store, alice_address,
                                               incoming_message)

    assert original_message == plaintext