def test_next_message_return_authentication_required_message_when_connection_step_4_and_role_is_server_with_password():
    """A password-protected server answers the client key share with AUTHENTICATION_REQUIRED."""
    # Given: the server stores a scrypt-derived password; the client offers no credentials.
    secret = b"test"
    salt = os.urandom(16)
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=secret)
    methods = ["password"]
    server_info = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
        }
    }
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods,
                       authentication_information=server_info)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods)
    reference = UDPMessage(code=codes.HANDSHAKE,
                           topic=Handshake.AUTHENTICATION_REQUIRED_TOPIC)

    # Steps 1-3: connection request, server key share, client key share.
    for sender, receiver in ((client, server), (server, client), (client, server)):
        receiver.add_message(sender.next_message())

    # When: the server emits its step-4 message.
    reply = server.next_message()

    # Then: it carries the same id and topic as the reference message.
    assert reply.msg_id == reference.msg_id
    assert reply.topic == reference.topic
def test_encrypt_return_different_bytes_than_input_when_connection_step_4_for_both_roles():
    """Once both peers have exchanged keys, ``_encrypt`` never echoes the plaintext back."""
    # Given
    plaintext = b"A very secret message"
    salt = os.urandom(16)
    secret = b"test"
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=secret)
    methods = ["password"]
    server_info = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
        }
    }
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods,
                       authentication_information=server_info)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods)

    # Steps 1-3 give both sides the peer public key.
    for sender, receiver in ((client, server), (server, client), (client, server)):
        receiver.add_message(sender.next_message())

    # When
    server_ciphertext = server._encrypt(data=plaintext)
    client_ciphertext = client._encrypt(data=plaintext)

    # Then
    assert client_ciphertext != plaintext
    assert server_ciphertext != plaintext
def test_client_status_is_failed_when_authentication_method_is_custom_and_disapprove_is_called():
    """Calling ``disapprove`` on a pending custom authentication fails the handshake.

    NOTE(review): despite the test name, the status asserted here is the
    server's (the side that calls ``disapprove``); the client's status is not
    checked.
    """
    # Given: both peers only accept the "custom" method.
    client = Handshake(role=Handshake.CLIENT,
                       authentication_information={},
                       allowed_authentication_methods=["custom"])
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=["custom"])

    # Relay five messages, alternating client -> server -> client -> ...
    sender, receiver = client, server
    for _ in range(5):
        receiver.add_message(sender.next_message())
        sender, receiver = receiver, sender

    # When
    server.disapprove()

    # Then
    assert server.get_status() == Handshake.CONNECTION_STATUS_FAILED
def test_get_authentication_information_return_given_authentication_information_when_custom_method_is_used():
    """With the custom method, the server exposes exactly what the client supplied."""
    # Given
    client_info = {"test": "test"}
    client = Handshake(role=Handshake.CLIENT,
                       authentication_information=client_info,
                       allowed_authentication_methods=["custom"])
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=["custom"])

    # Five hops: request, key share, key share, auth required, authentication.
    hops = [(client, server), (server, client)] * 2 + [(client, server)]
    for origin, target in hops:
        target.add_message(origin.next_message())

    # When
    received_info = server.get_authentication_information()

    # Then
    assert received_info == client_info
def test_authentication_is_required_if_password_is_provided_as_authentication_method():
    """Configuring the "password" method makes the server demand authentication."""
    # Given
    methods = ["password"]
    salt = os.urandom(16)
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=b"test")
    server = Handshake(
        role=Handshake.SERVER,
        allowed_authentication_methods=methods,
        authentication_information={
            "password": {
                Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
                Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
            }
        })
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods)

    server.add_message(client.next_message())  # connection request
    client.add_message(server.next_message())  # server key share
    server.add_message(client.next_message())  # client key share

    # When
    reply = server.next_message()

    # Then
    reply_id = int.from_bytes(reply.msg_id, "little")
    reply_topic = int.from_bytes(reply.topic, "little")
    assert reply_id == codes.HANDSHAKE
    assert reply_topic == Handshake.AUTHENTICATION_REQUIRED_TOPIC
def test_client_status_is_waiting_approval_when_authentication_method_is_custom():
    """After receiving custom authentication data, the handshake waits for approval.

    NOTE(review): despite the test name, the status asserted here is the
    server's; the client's status is not checked.
    """
    # Given
    client = Handshake(role=Handshake.CLIENT,
                       authentication_information={},
                       allowed_authentication_methods=["custom"])
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=["custom"])

    # Relay five messages, starting with the client's connection request.
    sender, receiver = client, server
    for _ in range(5):
        receiver.add_message(sender.next_message())
        sender, receiver = receiver, sender

    # When
    current = server.get_status()

    # Then
    assert current == Handshake.CONNECTION_STATUS_WAIT_APPROVAL
def test_authentication_message_select_custom_method_if_it_is_the_only_authentication_method_for_both_instances():
    """When "custom" is the only shared method, the client's authentication message selects it."""
    # Given
    methods = ["custom"]
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods)
    client = Handshake(role=Handshake.CLIENT,
                       authentication_information={},
                       allowed_authentication_methods=methods)

    for sender, receiver in ((client, server), (server, client),
                             (client, server), (server, client)):
        receiver.add_message(sender.next_message())
    auth_message = client.next_message()

    # When: the payload is encrypted, so decrypt it before parsing the JSON.
    clear_payload = client._decrypt(auth_message.payload)
    decoded = json.loads(bytes.decode(clear_payload, "utf8"))

    # Then
    selected = decoded[Handshake.SELECTED_AUTHENTICATION_METHOD_KEY_NAME]
    assert selected == "custom"
def test_allowed_authentication_methods_default_value_is_no_authentication():
    """Without explicit authentication methods, the server approves right after key exchange."""
    # Given
    server = Handshake(role=Handshake.SERVER)
    client = Handshake(role=Handshake.CLIENT)
    for sender, receiver in ((client, server), (server, client), (client, server)):
        receiver.add_message(sender.next_message())

    # When
    reply = server.next_message()

    # Then
    assert int.from_bytes(reply.msg_id, "little") == codes.HANDSHAKE
    assert (int.from_bytes(reply.topic, "little")
            == Handshake.CONNECTION_APPROVED_TOPIC)
def test_next_message_return_authentication_message_when_connection_step_4_and_role_is_client_with_password():
    """The client's authentication message carries its plain password in the method info."""
    # Given
    secret = b"test"
    salt = os.urandom(16)
    methods = ["password"]
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=secret)
    client_info = {Handshake.PASSWORD_AUTH_METHOD_PASSWORD_KEY: secret.decode("utf8")}
    server_info = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
        }
    }
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods,
                       authentication_information=server_info)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods,
                       authentication_information=client_info)

    server.add_message(client.next_message())
    client.add_message(server.next_message())
    server.add_message(client.next_message())
    client.add_message(server.next_message())

    # When
    auth_message = client.next_message()
    body = json.loads(bytes.decode(client._decrypt(auth_message.payload), "utf8"))
    sent_password = body[Handshake.AUTH_METHOD_INFO_KEY]["password"]

    # Then
    assert int.from_bytes(auth_message.msg_id, "little") == codes.HANDSHAKE
    assert (int.from_bytes(auth_message.topic, "little")
            == Handshake.AUTHENTICATION_TOPIC)
    assert sent_password == secret.decode("utf8")
def test_authentication_message_contain_random_bits_of_correct_length():
    """The authentication payload includes base64 random bits of RANDOM_BITS_LENGTH bytes."""
    # Given
    salt = os.urandom(16)
    secret = b"test"
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=secret)
    methods = ["password"]
    client_info = {Handshake.PASSWORD_AUTH_METHOD_PASSWORD_KEY: secret.decode("utf8")}
    server_info = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
        }
    }
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods,
                       authentication_information=server_info)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods,
                       authentication_information=client_info)

    for sender, receiver in ((client, server), (server, client),
                             (client, server), (server, client)):
        receiver.add_message(sender.next_message())

    # When
    auth_message = client.next_message()
    body = json.loads(bytes.decode(client._decrypt(auth_message.payload), "utf8"))

    # Then
    assert int.from_bytes(auth_message.msg_id, "little") == codes.HANDSHAKE
    assert (int.from_bytes(auth_message.topic, "little")
            == Handshake.AUTHENTICATION_TOPIC)
    assert Handshake.AUTHENTICATION_RANDOM_BITS_KEY in body
    random_bits = base64.b64decode(body[Handshake.AUTHENTICATION_RANDOM_BITS_KEY])
    assert len(random_bits) == Handshake.RANDOM_BITS_LENGTH
def test_both_server_and_client_can_generate_shared_key_when_peer_public_key_has_been_received():
    """After the key shares, both peers derive the same ECDH secret."""
    # Given
    server = Handshake(role=Handshake.SERVER)
    client = Handshake(role=Handshake.CLIENT)
    # Reference secret computed directly from the server private key and the
    # client public key.
    client_public = client._private_key.public_key()
    reference_secret = server._private_key.exchange(ec.ECDH(), client_public)

    server.add_message(client.next_message())
    client.add_message(server.next_message())
    server.add_message(client.next_message())

    # When
    server_secret = server.get_shared_key()
    client_secret = client.get_shared_key()

    # Then
    assert reference_secret == server_secret == client_secret
def test_server_key_share_message_inform_selected_protocol_is_1_dot_0_if_clients_allowed_protocol_not_sorted():
    """The server selects "1.0" even when the client lists its versions out of order."""
    # Given
    server = Handshake(role=Handshake.SERVER)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_protocol_versions=['1.0', 'alpha'])
    server.add_message(client.next_message())
    key_share = server.next_message()

    # When
    decoded = json.loads(bytes.decode(key_share.payload, "utf8"))

    # Then
    assert decoded[Handshake.SELECTED_PROTOCOL_VERSION_KEY_NAME] == "1.0"
def test_server_key_share_message_inform_selected_protocol_is_alpha_if_it_is_the_only_available_for_server():
    """A server restricted to "alpha" selects "alpha" in its key share."""
    # Given
    server = Handshake(role=Handshake.SERVER,
                       allowed_protocol_versions=['alpha'])
    client = Handshake(role=Handshake.CLIENT)
    server.add_message(client.next_message())
    key_share = server.next_message()

    # When
    decoded = json.loads(bytes.decode(key_share.payload, "utf8"))

    # Then
    assert decoded[Handshake.SELECTED_PROTOCOL_VERSION_KEY_NAME] == "alpha"
def test_get_status_return_complete_when_and_handshake_was_successful_without_authentication():
    """Without authentication, four messages are enough for both peers to be approved."""
    # Given
    server = Handshake(role=Handshake.SERVER)
    client = Handshake(role=Handshake.CLIENT)
    sender, receiver = client, server
    for _ in range(4):
        receiver.add_message(sender.next_message())
        sender, receiver = receiver, sender

    # When
    client_state = client.get_status()
    server_state = server.get_status()

    # Then
    assert client_state == Handshake.CONNECTION_STATUS_APPROVED
    assert server_state == Handshake.CONNECTION_STATUS_APPROVED
def test_server_key_share_message_contain_selected_protocol_version_which_is_the_latest_available():
    """By default, the server picks the last entry of PROTOCOL_VERSIONS_AVAILABLE."""
    # Given
    server = Handshake(role=Handshake.SERVER)
    client = Handshake(role=Handshake.CLIENT)
    server.add_message(client.next_message())
    key_share = server.next_message()

    # When
    decoded = json.loads(bytes.decode(key_share.payload, "utf8"))

    # Then
    key = Handshake.SELECTED_PROTOCOL_VERSION_KEY_NAME
    assert key in decoded
    assert decoded[key] == Handshake.PROTOCOL_VERSIONS_AVAILABLE[-1]
def test_authentication_required_message_contain_a_list_of_authentication_methods_available():
    """The AUTHENTICATION_REQUIRED payload lists every method the server allows."""
    # Given
    password_to_derive = b"test"
    password_salt = os.urandom(16)
    derived_password = derive_password_scrypt(
        password_salt=password_salt, password_to_derive=password_to_derive)
    allowed_authentication_method = Handshake.AUTHENTICATION_METHODS_AVAILABLE
    authentication_information_server = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: derived_password,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: password_salt,
        }
    }
    server = Handshake(
        role=Handshake.SERVER,
        allowed_authentication_methods=allowed_authentication_method,
        authentication_information=authentication_information_server)
    # The client password is a str, as in the other password tests. It is
    # sent inside a JSON authentication payload, and bytes are not
    # JSON-serializable.
    authentication_information_client = {
        Handshake.PASSWORD_AUTH_METHOD_PASSWORD_KEY: password_to_derive.decode("utf8")
    }
    client = Handshake(
        role=Handshake.CLIENT,
        allowed_authentication_methods=["password"],
        authentication_information=authentication_information_client)
    server.add_message(client.next_message())
    client.add_message(server.next_message())
    server.add_message(client.next_message())
    authentication_required_message = server.next_message()

    # When
    result = json.loads(
        bytes.decode(authentication_required_message.payload, "utf8"))

    # Then
    assert Handshake.AUTHENTICATION_METHODS_AVAILABLE_KEY_NAME in result
    methods = result[Handshake.AUTHENTICATION_METHODS_AVAILABLE_KEY_NAME]
    assert isinstance(methods, list)
    assert len(methods) > 0
    assert methods == Handshake.AUTHENTICATION_METHODS_AVAILABLE
def test_connection_fail_if_abort_is_called_on_client_after_a_connection_request():
    """A client abort emits CONNECTION_FAILED and leaves both peers failed."""
    # Given: the client has sent its request and received the server key share.
    server = Handshake(role=Handshake.SERVER)
    client = Handshake(role=Handshake.CLIENT)
    for sender, receiver in ((client, server), (server, client)):
        receiver.add_message(sender.next_message())

    # When
    client.abort()
    failure_notice = client.next_message()
    server.add_message(failure_notice)

    # Then
    assert server.get_status() == Handshake.CONNECTION_STATUS_FAILED
    assert client.get_status() == Handshake.CONNECTION_STATUS_FAILED
    assert (int.from_bytes(failure_notice.topic, "little")
            == Handshake.CONNECTION_FAILED_TOPIC)
def test_authentication_message_select_password_method_if_it_is_the_only_authentication_method_for_both_instances():
    """When "password" is the only shared method, the client's authentication message selects it."""
    # Given
    salt = os.urandom(16)
    secret = b"test_password"
    methods = ["password"]
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=secret)
    client_info = {Handshake.PASSWORD_AUTH_METHOD_PASSWORD_KEY: secret.decode("utf8")}
    server_info = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
        }
    }
    server = Handshake(role=Handshake.SERVER,
                       authentication_information=server_info,
                       allowed_authentication_methods=methods)
    client = Handshake(role=Handshake.CLIENT,
                       authentication_information=client_info,
                       allowed_authentication_methods=methods)

    sender, receiver = client, server
    for _ in range(4):
        receiver.add_message(sender.next_message())
        sender, receiver = receiver, sender
    auth_message = client.next_message()

    # When
    decoded = json.loads(bytes.decode(client._decrypt(auth_message.payload), "utf8"))

    # Then
    key = Handshake.SELECTED_AUTHENTICATION_METHOD_KEY_NAME
    assert key in decoded
    assert decoded[key] == "password"
def test_next_message_return_a_message_with_an_ec_public_key_when_connection_step_2_and_role_is_server():
    """The server key share carries the server EC public key as PEM SubjectPublicKeyInfo."""
    # Given
    server = Handshake(role=Handshake.SERVER)
    client = Handshake(role=Handshake.CLIENT)
    pem_bytes = server._private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo)
    expected_pem = bytes.decode(pem_bytes, "ascii")
    server.add_message(client.next_message())

    # When
    key_share = server.next_message()
    decoded = json.loads(bytes.decode(key_share.payload, "utf8"))

    # Then
    assert int.from_bytes(key_share.msg_id, "little") == codes.HANDSHAKE
    assert (int.from_bytes(key_share.topic, "little")
            == Handshake.SERVER_KEY_SHARE_TOPIC)
    assert decoded[Handshake.SERVER_PUBLIC_KEY_KEY_NAME] == expected_pem
def test_next_message_returns_none_when_no_connection_request_and_cm_is_server():
    """A server that has not received a connection request has nothing to send."""
    # Given
    server = Handshake(role=Handshake.SERVER)

    # When
    result = server.next_message()

    # Then: compare against the None singleton by identity (PEP 8), so an
    # object with a permissive __eq__ cannot make the test pass spuriously.
    assert result is None
def test_no_authentication_is_required_when_no_allowed_authentication_method_provided():
    """An empty method list means the server approves straight after the key shares."""
    # Given
    methods = []
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods)
    for sender, receiver in ((client, server), (server, client), (client, server)):
        receiver.add_message(sender.next_message())

    # When
    reply = server.next_message()

    # Then
    assert int.from_bytes(reply.msg_id, "little") == codes.HANDSHAKE
    assert (int.from_bytes(reply.topic, "little")
            == Handshake.CONNECTION_APPROVED_TOPIC)
def test_client_next_message_is_connection_failed_if_no_common_authentication_method_auth_request():
    """A client with no method in common with the server answers with CONNECTION_FAILED."""
    # Given: the server only allows "password", the client only allows "custom".
    allowed_authentication_methods_server = ["password"]
    allowed_authentication_methods_client = ["custom"]
    password_salt = os.urandom(16)
    password_to_derive = b"test_password"
    derived_password = derive_password_scrypt(
        password_salt=password_salt, password_to_derive=password_to_derive)
    # Password passed as str, as in the other password tests. The client sends
    # it inside a JSON payload, where bytes would not be serializable.
    authentication_information_client = {
        Handshake.PASSWORD_AUTH_METHOD_PASSWORD_KEY: password_to_derive.decode("utf8")
    }
    client = Handshake(
        role=Handshake.CLIENT,
        authentication_information=authentication_information_client,
        allowed_authentication_methods=allowed_authentication_methods_client)
    authentication_information_server = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: derived_password,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: password_salt,
        }
    }
    server = Handshake(
        role=Handshake.SERVER,
        authentication_information=authentication_information_server,
        allowed_authentication_methods=allowed_authentication_methods_server)
    server.add_message(client.next_message())
    client.add_message(server.next_message())
    server.add_message(client.next_message())
    client.add_message(server.next_message())

    # When
    connection_failed_message = client.next_message()

    # Then
    assert (int.from_bytes(connection_failed_message.topic, "little")
            == Handshake.CONNECTION_FAILED_TOPIC)
def test_next_message_returns_correct_message_when_connection_request_begins_and_role_is_client():
    """A fresh client opens the handshake with a CONNECTION_REQUEST message."""
    # Given
    client = Handshake(role=Handshake.CLIENT)

    # When
    request = client.next_message()

    # Then
    assert int.from_bytes(request.msg_id, "little") == codes.HANDSHAKE
    assert (int.from_bytes(request.topic, "little")
            == Handshake.CONNECTION_REQUEST_TOPIC)
def test_connection_fail_if_server_and_client_have_not_a_common_protocol_version():
    """Disjoint protocol version lists make the server reply CONNECTION_FAILED."""
    # Given
    server = Handshake(role=Handshake.SERVER,
                       allowed_protocol_versions=['1.0'])
    client = Handshake(role=Handshake.CLIENT,
                       allowed_protocol_versions=['alpha'])
    server.add_message(client.next_message())

    # When
    failure_notice = server.next_message()
    client.add_message(failure_notice)

    # Then
    assert server.get_status() == Handshake.CONNECTION_STATUS_FAILED
    assert client.get_status() == Handshake.CONNECTION_STATUS_FAILED
    assert (int.from_bytes(failure_notice.topic, "little")
            == Handshake.CONNECTION_FAILED_TOPIC)
def test_next_message_return_connection_approved_message_when_connection_step_4_and_role_is_server_and_no_password():
    """With no authentication methods, the server's step-4 message is CONNECTION_APPROVED."""
    # Given
    methods = []
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods)
    reference = UDPMessage(code=codes.HANDSHAKE,
                           topic=Handshake.CONNECTION_APPROVED_TOPIC)
    sender, receiver = client, server
    for _ in range(3):
        receiver.add_message(sender.next_message())
        sender, receiver = receiver, sender

    # When
    reply = server.next_message()

    # Then
    assert reply.msg_id == reference.msg_id
    assert reply.topic == reference.topic
def test_connection_request_message_contains_a_list_of_available_protocols():
    """The connection request advertises every protocol version, starting with "alpha"."""
    # Given (the original built an unused expected UDPMessage here; removed)
    client = Handshake(role=Handshake.CLIENT)
    connection_request_message = client.next_message()

    # When
    result = json.loads(
        bytes.decode(connection_request_message.payload, "utf8"))

    # Then
    assert Handshake.PROTOCOL_VERSIONS_AVAILABLE_KEY_NAME in result
    versions = result[Handshake.PROTOCOL_VERSIONS_AVAILABLE_KEY_NAME]
    assert isinstance(versions, list)
    assert len(versions) > 0
    assert versions[0] == "alpha"
    assert versions == Handshake.PROTOCOL_VERSIONS_AVAILABLE
def test_next_message_return_connection_failed_msg_when_connection_step_6_and_role_is_client_and_password_incorrect():
    """A wrong client password makes the server answer with CONNECTION_FAILED.

    NOTE(review): the message under test is produced by the server, in
    response to the authentication message from a client holding the wrong
    password.
    """
    # Given
    salt = os.urandom(16)
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=b"test")
    methods = ["password"]
    client_info = {Handshake.PASSWORD_AUTH_METHOD_PASSWORD_KEY: "incorrect"}
    server_info = {
        "password": {
            Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
            Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
        }
    }
    server = Handshake(role=Handshake.SERVER,
                       allowed_authentication_methods=methods,
                       authentication_information=server_info)
    client = Handshake(role=Handshake.CLIENT,
                       allowed_authentication_methods=methods,
                       authentication_information=client_info)

    # Five hops, ending with the client's authentication message.
    sender, receiver = client, server
    for _ in range(5):
        receiver.add_message(sender.next_message())
        sender, receiver = receiver, sender

    # When
    verdict = server.next_message()

    # Then
    assert int.from_bytes(verdict.msg_id, "little") == codes.HANDSHAKE
    assert (int.from_bytes(verdict.topic, "little")
            == Handshake.CONNECTION_FAILED_TOPIC)
def test_get_status_return_approved_when_authentication_is_correct():
    """A correct password takes both peers to APPROVED after six messages."""
    # Given
    methods = ["password"]
    secret = b"test"
    salt = os.urandom(16)
    stored_hash = derive_password_scrypt(password_salt=salt,
                                         password_to_derive=secret)
    server = Handshake(
        role=Handshake.SERVER,
        authentication_information={
            "password": {
                Handshake.PASSWORD_AUTH_METHOD_DERIVED_PASSWORD_KEY: stored_hash,
                Handshake.PASSWORD_AUTH_METHOD_SALT_KEY: salt,
            }
        },
        allowed_authentication_methods=methods)
    client = Handshake(
        role=Handshake.CLIENT,
        authentication_information={Handshake.PASSWORD_AUTH_METHOD_PASSWORD_KEY: "test"},
        allowed_authentication_methods=methods)

    # Six hops: request, key shares, auth required, authentication, verdict.
    for origin, target in [(client, server), (server, client)] * 3:
        target.add_message(origin.next_message())

    # When
    client_state = client.get_status()
    server_state = server.get_status()

    # Then
    assert client_state == Handshake.CONNECTION_STATUS_APPROVED
    assert server_state == Handshake.CONNECTION_STATUS_APPROVED