def test_handlers_called(mock_server, mocker):
    """Check that server-message handlers fire once per received message.

    A MagicMock is registered for every ServerMessageType (by name), plus one
    more under the special "all" keyword. After transcribing ch.wav against
    the mock server, each per-type mock's call count must equal the number of
    messages of that type the server sent, and the "all" mock's call count
    must equal the total number of messages sent.
    """
    client, config, settings = default_ws_client_setup(mock_server.url)

    # One spy per server message type, keyed by the message type's name.
    handlers = {
        kind.name: mocker.MagicMock()
        for kind in ServerMessageType
    }
    for name, spy in handlers.items():
        client.add_event_handler(name, spy)

    # The "all" keyword should deliver every message to this spy as well.
    catch_all = mocker.MagicMock()
    client.add_event_handler("all", catch_all)

    with open(path_to_test_resource("ch.wav"), "rb") as audio:
        client.run_synchronously(audio, config, settings)
    mock_server.wait_for_clean_disconnects()

    sent = mock_server.messages_sent
    expected_counts = Counter(message["message"] for message in sent)
    for name, expected in expected_counts.items():
        assert name and handlers[name].call_count == expected

    assert catch_all.call_count == len(sent)
def test_main_with_basic_options(mock_server):
    """Run the CLI ``transcribe`` command with a minimal set of flags.

    Afterwards exactly one client must have connected and disconnected,
    messages must have flowed in both directions, and the websocket request
    path seen by the mock server must be "/v2".
    """
    cli_args = ["-vv", "transcribe", "--ssl-mode=insecure"]
    cli_args += ["--url", mock_server.url]
    cli_args.append(path_to_test_resource("ch.wav"))

    parsed = cli.parse_args(cli_args)
    cli.main(vars(parsed))
    mock_server.wait_for_clean_disconnects()

    assert mock_server.clients_connected_count == 1
    assert mock_server.clients_disconnected_count == 1
    assert mock_server.messages_received
    assert mock_server.messages_sent
    assert mock_server.connection_request.path == "/v2"
def test_middlewares_called(mock_server, mocker):
    """Check that client-side middlewares fire once per outgoing message.

    A MagicMock middleware is registered for every ClientMessageType (by
    name), plus one under the special "all" keyword. An extra middleware on
    StartRecognition rewrites the outgoing language to "ja". This shows that
    middlewares can edit outgoing messages in place.
    """
    client, config, settings = default_ws_client_setup(mock_server.url)

    middlewares = {}
    for kind in ClientMessageType:
        spy = mocker.MagicMock()
        middlewares[kind.name] = spy
        client.add_middleware(kind.name, spy)

    # The "all" keyword should see every outgoing message.
    catch_all = mocker.MagicMock()
    client.add_middleware("all", catch_all)

    def force_japanese(msg, is_binary):  # pylint: disable=unused-argument
        msg["transcription_config"]["language"] = "ja"

    client.add_middleware(ClientMessageType.StartRecognition, force_japanese)

    with open(path_to_test_resource("ch.wav"), "rb") as audio:
        client.run_synchronously(audio, config, settings)
    mock_server.wait_for_clean_disconnects()

    received = mock_server.messages_received

    def message_name(message):
        # Non-dict entries recorded by the server are counted as AddAudio.
        if isinstance(message, dict):
            return message["message"]
        return "AddAudio"

    # Each middleware must run once per message of its type the client sent.
    for name, expected in Counter(map(message_name, received)).items():
        assert name and middlewares[name].call_count == expected

    # The language rewrite done by force_japanese must reach the server.
    start_msg = mock_server.find_start_recognition_message()
    assert start_msg["transcription_config"]["language"] == "ja"

    assert catch_all.call_count == len(received)
def test_client_stops_when_asked_and_sends_end_of_stream(mock_server):
    """Check that stop() from a handler ends the session with EndOfStream.

    On RecognitionStarted, a handler records how many client messages the
    server has received so far and then calls ``stop()`` on the client. Once
    the session is over, the server must have received exactly one more
    message, and that final message must be EndOfStream.
    """
    client, config, settings = default_ws_client_setup(mock_server.url)

    received_at_stop = 0

    def stop_client(msg):  # pylint: disable=unused-argument
        nonlocal received_at_stop
        received_at_stop = len(mock_server.messages_received)
        client.stop()

    client.add_event_handler(
        ServerMessageType.RecognitionStarted, stop_client)

    with open(path_to_test_resource("ch.wav"), "rb") as audio:
        client.run_synchronously(audio, config, settings)
    mock_server.wait_for_clean_disconnects()

    received = mock_server.messages_received
    assert len(received) == received_at_stop + 1
    assert received[-1]["message"] == "EndOfStream"
def test_update_transcription_config_sends_set_recognition_config(mock_server):
    """Check that a mid-session config update sends one SetRecognitionConfig.

    When RecognitionStarted arrives, a deep copy of the transcription config
    with its language set to "ja" is passed to
    ``update_transcription_config``. The mock server must then receive
    exactly one SetRecognitionConfig message carrying that language.
    """
    client, config, settings = default_ws_client_setup(mock_server.url)

    def switch_language(msg):  # pylint: disable=unused-argument
        updated = copy.deepcopy(config)
        updated.language = "ja"
        client.update_transcription_config(updated)

    client.add_event_handler(
        ServerMessageType.RecognitionStarted, switch_language)

    with open(path_to_test_resource("ch.wav"), "rb") as audio:
        client.run_synchronously(audio, config, settings)
    mock_server.wait_for_clean_disconnects()

    updates = mock_server.find_messages_by_type("SetRecognitionConfig")
    assert len(updates) == 1
    assert updates[0]["transcription_config"]["language"] == "ja"
def test_main_with_all_options(mock_server, tmp_path):
    """Run the CLI ``transcribe`` command with a broad set of options.

    Verifies three things:
    - the options are reflected in the StartRecognition message the mock
      server receives;
    - the server sent AddPartialTranscript messages when ``--enable-partials``
      is given;
    - ``--chunk-size`` controls roughly how many AddAudio messages the audio
      file is split into.
    """
    vocab_file = tmp_path / "vocab.json"
    vocab_file.write_text(
        '["jabberwock", {"content": "brillig", "sounds_like": ["brillick"]}]')

    chunk_size = 1024 * 8
    audio_path = path_to_test_resource("ch.wav")

    args = [
        "-v",
        "transcribe",
        "--ssl-mode=insecure",
        "--buffer-size=256",
        "--debug",
        "--url",
        # Use the fixture's URL (as test_main_with_basic_options does)
        # instead of a hard-coded address, so the test keeps working if the
        # mock server's host or port changes.
        mock_server.url,
        "--lang=en",
        "--output-locale=en-US",
        "--additional-vocab",
        "tumtum",
        "borogoves:boreohgofes,borrowgoafs",
        "--additional-vocab-file",
        str(vocab_file),
        "--enable-partials",
        "--punctuation-permitted-marks",
        "all",
        "--punctuation-sensitivity",
        "0.1",
        "--diarization",
        "none",
        "--speaker-change-sensitivity",
        "0.8",
        "--speaker-change-token",
        "--max-delay",
        "5.0",
        "--chunk-size",
        str(chunk_size),
        "--auth-token=xyz",
        audio_path,
    ]

    cli.main(vars(cli.parse_args(args)))
    mock_server.wait_for_clean_disconnects()

    assert mock_server.clients_connected_count == 1
    assert mock_server.clients_disconnected_count == 1
    assert mock_server.messages_received
    assert mock_server.messages_sent

    # Check that the StartRecognition message contains the correct fields.
    msg = mock_server.find_start_recognition_message()
    assert msg["audio_format"]["type"] == "file"
    assert len(msg["audio_format"]) == 1

    config = msg["transcription_config"]
    assert config["language"] == "en"
    assert config["output_locale"] == "en-US"
    # Vocab from the file comes first, followed by --additional-vocab
    # entries; "word:a,b" is expanded into a sounds_like entry.
    assert config["additional_vocab"] == [
        "jabberwock",
        {
            "content": "brillig",
            "sounds_like": ["brillick"]
        },
        "tumtum",
        {
            "content": "borogoves",
            "sounds_like": ["boreohgofes", "borrowgoafs"]
        },
    ]
    assert mock_server.find_sent_messages_by_type("AddPartialTranscript")
    assert config["punctuation_overrides"]["permitted_marks"] == ["all"]
    assert config["punctuation_overrides"]["sensitivity"] == 0.1
    assert config["diarization"] == "none"
    assert config["max_delay"] == 5.0
    assert config["speaker_change_sensitivity"] == 0.8

    # Check that the chunk size argument is respected: the number of
    # AddAudio messages should be within one of file_size / chunk_size.
    add_audio_messages = mock_server.find_add_audio_messages()
    size_of_audio_file = os.stat(audio_path).st_size
    expected_num_messages = size_of_audio_file / chunk_size
    assert -1 <= (len(add_audio_messages) - expected_num_messages) <= 1