def test_cli_arg_parse(args, values):
    """Parse ``args`` and check that each expected value was produced.

    The arguments ``--url=example`` and ``file`` are appended to ``args``
    before parsing. Every key in ``values`` must then map to the same value
    in the parsed result.
    """
    # NOTE(review): args/values are presumably supplied by a parametrize
    # decorator above this function -- confirm against the full file.
    parsed = vars(cli.parse_args(args=args + ["--url=example", "file"]))

    for name, expected in values.items():
        assert parsed[name] == expected
def test_main_with_basic_options(mock_server):
    """Run a ``transcribe`` command with a minimal option set.

    Checks that exactly one client connected and disconnected, that
    messages were recorded in both directions, and that the connection
    request was made to the ``/v2`` path.
    """
    server_url = mock_server.url
    audio_path = path_to_test_resource("ch.wav")
    cli_args = [
        "-vv",
        "transcribe",
        "--ssl-mode=insecure",
        "--url",
        server_url,
        audio_path,
    ]

    parsed = cli.parse_args(cli_args)
    cli.main(vars(parsed))
    mock_server.wait_for_clean_disconnects()

    # Exactly one session should have been opened and closed.
    assert mock_server.clients_connected_count == 1
    assert mock_server.clients_disconnected_count == 1

    # Both directions of the conversation must have recorded something.
    assert mock_server.messages_received
    assert mock_server.messages_sent

    assert mock_server.connection_request.path == "/v2"
def test_main_with_all_options(mock_server, tmp_path):
    """Run ``transcribe`` with a broad set of options and inspect the result.

    A vocabulary file is written under ``tmp_path``, the CLI is invoked, and
    the StartRecognition message recorded by the mock server is compared
    against the options passed on the command line. The number of AddAudio
    messages is also checked against the requested chunk size.
    """
    audio_path = path_to_test_resource("ch.wav")
    chunk_size = 1024 * 8

    vocab_file = tmp_path / "vocab.json"
    vocab_file.write_text(
        '["jabberwock", {"content": "brillig", "sounds_like": ["brillick"]}]')

    # The argument order is significant to the parser, so the groups below
    # are concatenated in exactly the order they should appear.
    general_opts = [
        "-v",
        "transcribe",
        "--ssl-mode=insecure",
        "--buffer-size=256",
        "--debug",
        "--url",
        "wss://127.0.0.1:8765/v2",
    ]
    language_opts = [
        "--lang=en",
        "--output-locale=en-US",
    ]
    vocab_opts = [
        "--additional-vocab",
        "tumtum",
        "borogoves:boreohgofes,borrowgoafs",
        "--additional-vocab-file",
        str(vocab_file),
    ]
    transcript_opts = [
        "--enable-partials",
        "--punctuation-permitted-marks",
        "all",
        "--punctuation-sensitivity",
        "0.1",
        "--diarization",
        "none",
        "--speaker-change-sensitivity",
        "0.8",
        "--speaker-change-token",
        "--max-delay",
        "5.0",
    ]
    audio_opts = [
        "--chunk-size",
        str(chunk_size),
        "--auth-token=xyz",
        audio_path,
    ]
    cli_args = (general_opts + language_opts + vocab_opts + transcript_opts
                + audio_opts)

    parsed = cli.parse_args(cli_args)
    cli.main(vars(parsed))
    mock_server.wait_for_clean_disconnects()

    assert mock_server.clients_connected_count == 1
    assert mock_server.clients_disconnected_count == 1
    assert mock_server.messages_received
    assert mock_server.messages_sent

    # Check that the StartRecognition message contains the correct fields
    msg = mock_server.find_start_recognition_message()
    print(msg)

    audio_format = msg["audio_format"]
    assert audio_format["type"] == "file"
    assert len(audio_format) == 1

    config = msg["transcription_config"]
    assert config["language"] == "en"
    assert config["output_locale"] == "en-US"

    # Entries from the vocab file are expected first, followed by the ones
    # given directly on the command line.
    expected_vocab = [
        "jabberwock",
        {"content": "brillig", "sounds_like": ["brillick"]},
        "tumtum",
        {"content": "borogoves", "sounds_like": ["boreohgofes", "borrowgoafs"]},
    ]
    assert config["additional_vocab"] == expected_vocab

    assert mock_server.find_sent_messages_by_type("AddPartialTranscript")

    overrides = config["punctuation_overrides"]
    assert overrides["permitted_marks"] == ["all"]
    assert overrides["sensitivity"] == 0.1

    assert config["diarization"] == "none"
    assert config["max_delay"] == 5.0
    assert config["speaker_change_sensitivity"] == 0.8

    # Check that the chunk size argument is respected. The expected count is
    # a fractional quotient, so the actual count may differ from it by up
    # to one.
    add_audio_messages = mock_server.find_add_audio_messages()
    file_size = os.stat(audio_path).st_size
    expected_count = file_size / chunk_size
    assert abs(len(add_audio_messages) - expected_count) <= 1