def evaluate_files_multiclass(
    files_to_evaluate: List[os.PathLike],
    target_id: int,
    model: tf.keras.Model,
    model_settings: Dict,
):
    correct_confidences = []
    incorrect_confidences = []

    specs = [
        input_data.file2spec(model_settings, f) for f in files_to_evaluate
    ]
    specs = np.array(specs)
    preds = model.predict(np.expand_dims(specs, -1))

    # softmaxes = np.max(preds,axis=1)
    # unknown_other_words_confidences.extend(softmaxes.tolist())
    cols = np.argmax(preds, axis=1)
    # figure out how to fancy-index this later
    for row, col in enumerate(cols):
        confidence = preds[row][col]
        if col == target_id:
            correct_confidences.append(confidence)
        else:
            incorrect_confidences.append(confidence)
    return dict(correct=correct_confidences, incorrect=incorrect_confidences)
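# A vectorized sketch of the confidence-splitting loop above, addressing the
# fancy-indexing TODO; the helper name is hypothetical and unused elsewhere.
def split_confidences_vectorized(preds: np.ndarray, target_id: int) -> Dict:
    cols = np.argmax(preds, axis=1)
    # max softmax per row, selected via integer fancy-indexing
    confidences = preds[np.arange(preds.shape[0]), cols]
    mask = cols == target_id
    return dict(correct=confidences[mask].tolist(),
                incorrect=confidences[~mask].tolist())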
def evaluate_files_single_target(
    files_to_evaluate: List[os.PathLike],
    target_id: int,
    model: tf.keras.Model,
    model_settings: Dict,
):
    specs = [
        input_data.file2spec(model_settings, f) for f in files_to_evaluate
    ]
    specs = np.array(specs)
    preds = model.predict(np.expand_dims(specs, -1))
    return preds[:, target_id], preds
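# Example usage sketch (the file list, target_id, and 0.8 threshold are
# hypothetical):
#   target_probs, _ = evaluate_files_single_target(
#       files_to_evaluate, target_id=2, model=model, model_settings=model_settings)
#   print("fraction above threshold:", np.mean(target_probs > 0.8))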
def sweep_run(sd: SweepData, q):

    # load embedding model
    traindir = Path(f"/home/mark/tinyspeech_harvard/multilingual_embedding_wc")
    with open(traindir / "unknown_files.txt", "r") as fh:
        unknown_files = fh.read().splitlines()

    base_model_path = traindir / "models" / "multilingual_context_73_0.8011"

    model_settings = input_data.standard_microspeech_model_settings(3)
    name, model, details = transfer_learning.transfer_learn(
        target=sd.target,
        train_files=sd.train_files,
        val_files=sd.val_files,
        unknown_files=unknown_files,
        num_epochs=sd.n_epochs,
        num_batches=sd.n_batches,
        batch_size=sd.batch_size,
        primary_lr=sd.primary_lr,
        backprop_into_embedding=sd.backprop_into_embedding,
        embedding_lr=sd.embedding_lr,
        model_settings=model_settings,
        base_model_path=base_model_path,
        base_model_output="dense_2",
        csvlog_dest=sd.model_dest_dir / "log.csv",
    )
    print("saving", name)
    modelpath = sd.model_dest_dir / name
    # skip saving model for now, slow
    # model.save(modelpath)

    specs = np.array(
        [input_data.file2spec(model_settings, f) for f in sd.val_files])
    specs = np.expand_dims(specs, -1)
    preds = model.predict(specs)
    amx = np.argmax(preds, axis=1)
    # print(amx)
    # index 2 is assumed to be the target label in the 3-label settings above
    val_accuracy = amx[amx == 2].shape[0] / preds.shape[0]
    # this should maybe be thresholded also
    print("VAL ACCURACY", val_accuracy)

    start = datetime.datetime.now()
    # streaming evaluation (streamtarget is assumed to be defined at module scope)
    sa.eval_stream_test(streamtarget, live_model=model)
    end = datetime.datetime.now()
    print("time elampsed (for all thresholds)", end - start)

    q.put(val_accuracy)
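# A minimal launch sketch for sweep_run: run each sweep in a separate process
# so TensorFlow resources are released between runs (sd is a hypothetical
# SweepData instance; requires the multiprocessing module):
#   q = multiprocessing.Queue()
#   p = multiprocessing.Process(target=sweep_run, args=(sd, q))
#   p.start()
#   p.join()
#   val_accuracy = q.get()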
def evaluate_and_track(
    words_to_evaluate: List[str],
    target_id: int,
    data_dir: os.PathLike,
    utterances_per_word: int,
    model: tf.keras.Model,
    model_settings: Dict,
):
    # TODO(mmaz) rewrite and combine with evaluate_fast
    raise ValueError(
        "this only works for multiclass, see other evaluation functions")

    correct_confidences = []
    incorrect_confidences = []
    track_correct = {}
    track_incorrect = {}

    for word in words_to_evaluate:
        wav_glob = os.path.join(data_dir, word, "*.wav")
        fs = np.random.choice(glob.glob(wav_glob),
                              utterances_per_word,
                              replace=False)

        track_correct[word] = []
        track_incorrect[word] = []

        specs = np.array([input_data.file2spec(model_settings, f) for f in fs])
        preds = model.predict(np.expand_dims(specs, -1))

        # softmaxes = np.max(preds,axis=1)
        # unknown_other_words_confidences.extend(softmaxes.tolist())
        cols = np.argmax(preds, axis=1)
        # figure out how to fancy-index this later
        for row, col in enumerate(cols):
            confidence = preds[row][col]
            if col == target_id:
                correct_confidences.append(confidence)
                track_correct[word].append(confidence)
            else:
                incorrect_confidences.append(confidence)
                track_incorrect[word].append(confidence)
    return {
        "correct": correct_confidences,
        "incorrect": incorrect_confidences,
        "track_correct": track_correct,
        "track_incorrect": track_incorrect,
    }
def evaluate_fast_single_target(
    words_to_evaluate: List[str],
    target_id: int,
    data_dir: os.PathLike,
    utterances_per_word: int,
    model: tf.keras.Model,
    model_settings: Dict,
):
    specs = []
    for word in words_to_evaluate:
        wavs = glob.glob(os.path.join(data_dir, word, "*.wav"))
        if len(wavs) > utterances_per_word:
            fs = np.random.choice(wavs, utterances_per_word, replace=False)
        else:
            print("using all wavs for", word)
            fs = wavs
        specs.extend([input_data.file2spec(model_settings, f) for f in fs])
    specs = np.array(specs)
    preds = model.predict(np.expand_dims(specs, -1))
    return preds[:, target_id], preds
def evaluate_fast_multiclass(
    words_to_evaluate: List[str],
    target_id: int,
    data_dir: os.PathLike,
    utterances_per_word: int,
    model: tf.keras.Model,
    model_settings: Dict,
):
    correct_confidences = []
    incorrect_confidences = []

    specs = []
    for word in words_to_evaluate:
        wavs = glob.glob(os.path.join(data_dir, word, "*.wav"))
        if len(wavs) > utterances_per_word:
            fs = np.random.choice(wavs, utterances_per_word, replace=False)
        else:
            print("using all wavs for", word)
            fs = wavs
        specs.extend([input_data.file2spec(model_settings, f) for f in fs])
    specs = np.array(specs)
    preds = model.predict(np.expand_dims(specs, -1))

    # softmaxes = np.max(preds,axis=1)
    # unknown_other_words_confidences.extend(softmaxes.tolist())
    cols = np.argmax(preds, axis=1)
    # figure out how to fancy-index this later
    for row, col in enumerate(cols):
        confidence = preds[row][col]
        if col == target_id:
            correct_confidences.append(confidence)
        else:
            incorrect_confidences.append(confidence)
    return {
        "correct": correct_confidences,
        "incorrect": incorrect_confidences,
    }
# %%
# transfer-learn a model for the target keyword
name, model, details = transfer_learning.transfer_learn(
    target=target,
    train_files=train_files,
    val_files=val_files,
    unknown_files=unknown_files,
    num_epochs=4,  # 9
    num_batches=1,  # 3
    batch_size=64,
    model_settings=model_settings,
    base_model_path=base_model_path,
    base_model_output="dense_2",
)
    print("saving", name)
    model.save(model_dest_dir / name)
    print(":::::::: saved to", model_dest_dir / name)

# %%
# sanity check model outputs
specs = np.array([input_data.file2spec(model_settings, f) for f in val_files])
specs = np.expand_dims(specs, -1)
print(specs.shape)
preds = model.predict(specs)
amx = np.argmax(preds, axis=1)
print(amx)
print("VAL ACCURACY", amx[amx == 2].shape[0] / preds.shape[0])
print("--")

with np.printoptions(precision=3, suppress=True):
    print(preds)

# %%
iso2val_files = {k: [] for k in iso2lang.keys()}
for ix, f in enumerate(val_files):
    lang = f[45:47]  # language isocode at a fixed character offset in the filepath
    # word = f.split("/")[-2]
    iso2val_files[lang].append(f)
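# a sketch of a less brittle alternative to the fixed-offset slice above,
# assuming paths end in .../<lang_isocode>/<word>/<clip>.wav:
#   lang = Path(f).parts[-3]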
# %%
# calculate val accuracy per language
model_settings = input_data.standard_microspeech_model_settings(
    label_count=len(commands) + 1)  # add silence
for lang_isocode, vfs in iso2val_files.items():
    val_audio = []
    val_labels = []

    print(len(vfs))
    for ix, f in enumerate(vfs):
        val_audio.append(input_data.file2spec(model_settings, f))

        word = f.split("/")[-2]
        val_labels.append(commands.index(word) + 1)  # add silence
        if ix % 2000 == 0:
            print(ix)
    val_audio = np.array(val_audio)
    val_labels = np.array(val_labels)

    y_pred = np.argmax(model.predict(val_audio), axis=1)
    y_true = val_labels

    val_acc = sum(y_pred == y_true) / len(y_true)
    print(f'{lang_isocode} accuracy: {val_acc:.0%}')
    iso2valacc[lang_isocode] = val_acc
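# %%
# summarize per-language accuracy, worst first (a minimal sketch over the
# iso2valacc dict filled above)
for isocode, acc in sorted(iso2valacc.items(), key=lambda kv: kv[1]):
    print(f"{isocode}: {acc:.0%}")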