def sweep_run(sd: SweepData, q):

    # load embedding model
    traindir = Path("/home/mark/tinyspeech_harvard/multilingual_embedding_wc")
    with open(traindir / "unknown_files.txt", "r") as fh:
        unknown_files = fh.read().splitlines()

    base_model_path = traindir / "models" / "multilingual_context_73_0.8011"

    model_settings = input_data.standard_microspeech_model_settings(3)
    name, model, details = transfer_learning.transfer_learn(
        target=sd.target,
        train_files=sd.train_files,
        val_files=sd.val_files,
        unknown_files=unknown_files,
        num_epochs=sd.n_epochs,
        num_batches=sd.n_batches,
        batch_size=sd.batch_size,
        primary_lr=sd.primary_lr,
        backprop_into_embedding=sd.backprop_into_embedding,
        embedding_lr=sd.embedding_lr,
        model_settings=model_settings,
        base_model_path=base_model_path,
        base_model_output="dense_2",
        csvlog_dest=sd.model_dest_dir / "log.csv",
    )
    print("saving", name)
    modelpath = sd.model_dest_dir / name
    # skip saving model for now, slow
    # model.save(modelpath)

    specs = [input_data.file2spec(model_settings, f) for f in sd.val_files]
    specs = np.expand_dims(specs, -1)
    preds = model.predict(specs)
    amx = np.argmax(preds, axis=1)
    # print(amx)
    val_accuracy = amx[amx == 2].shape[0] / preds.shape[0]
    # this should maybe be thresholded also
    print("VAL ACCURACY", val_accuracy)

    start = datetime.datetime.now()
    sa.eval_stream_test(streamtarget, live_model=model)
    end = datetime.datetime.now()
    print("time elampsed (for all thresholds)", end - start)

    q.put(val_accuracy)
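

# Usage sketch: the sweep driver that constructs SweepData and dispatches
# sweep_run is not part of this file; _run_one_sweep_point is a hypothetical
# helper illustrating why sweep_run reports through a Queue - running each
# configuration in its own process lets TensorFlow/GPU state be reclaimed
# between runs.
def _run_one_sweep_point(sd: SweepData) -> float:
    import multiprocessing

    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=sweep_run, args=(sd, q))
    p.start()
    val_accuracy = q.get()
    p.join()
    return val_accuracy
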
def eval_stream_test(st: StreamTarget, live_model=None):
    if live_model is not None:
        model = live_model
    else:
        tf.get_logger().setLevel(logging.ERROR)
        model = tf.keras.models.load_model(st.model_path)
        tf.get_logger().setLevel(logging.INFO)

    model_settings = input_data.standard_microspeech_model_settings(label_count=3)

    if os.path.isfile(st.destination_result_pkl):
        print("results already present", st.destination_result_pkl, flush=True)
        return
    print("SAVING results TO\n", st.destination_result_pkl)
    inferences_exist = False
    if os.path.isfile(st.destination_result_inferences):
        print("inferences already present", flush=True)
        # load the cached per-clip inferences (saved as .npy below), not the results pkl
        loaded_inferences = np.load(st.destination_result_inferences)
        inferences_exist = True
    else:
        print("SAVING inferences TO\n", st.destination_result_inferences, flush=True)

    results = {}
    if inferences_exist:
        results[st.target_word], _ = calculate_streaming_accuracy(
            model, model_settings, st.stream_flags, loaded_inferences
        )
    else:
        results[st.target_word], inferences = calculate_streaming_accuracy(
            model, model_settings, st.stream_flags
        )

    with open(st.destination_result_pkl, "wb") as fh:
        pickle.dump(results, fh)
    if not inferences_exist:
        np.save(st.destination_result_inferences, inferences)

    # https://keras.io/api/utils/backend_utils/
    tf.keras.backend.clear_session()
        print("target was used as an unknown word")
        continue
    # assert target not in other_words, "target is present in mega_unknown_files"

    base_dir = sse / target
    assert os.path.isdir(base_dir), f"{base_dir} not present"
    assert os.path.isdir(base_dir / "n_shots"), f"shots in {base_dir} not present - generate first"
    assert os.path.isdir(base_dir / "val"), f"val in {base_dir} not present - generate first"
    model_dest_dir = base_dir / "model"
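    # exist_ok=False is deliberate: fail fast rather than silently reuse or
    # overwrite a model directory left over from an earlier run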
    os.makedirs(model_dest_dir, exist_ok=False)
    print("model dest dir", model_dest_dir)

    model_settings = input_data.standard_microspeech_model_settings(3)
    target_n_shots = os.listdir(base_dir / "n_shots")
    # N_SHOTS = 5
    # target_n_train = target_n_shots[:N_SHOTS]
    # target_n_val = target_n_shots[N_SHOTS:]
    #val_files = [str(base_dir / "n_shots" / w) for w in target_n_val]
    val_names = os.listdir(base_dir / "val")

    train_files = [str(base_dir / "n_shots" / w) for w in target_n_shots]
    val_files = [str(base_dir / "val" / w) for w in val_names]
    print("---TRAIN---", len(train_files))
    print("\n".join(train_files))
    print("----VAL--", len(val_files))
    print("\n".join(val_files))

    if val_files == []:
        # no validation clips for this target: skip it rather than train without val data
        continue

    num_targets = len(os.listdir(data_dir / target))
    print("n targets", num_targets)

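    # assemble a small negative-example pool: sample from the commands / unknown /
    # oov word lists, flatten and deduplicate, and drop the target word itself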
    some_commands = np.random.choice(commands, 30, replace=False)
    some_unknown = np.random.choice(unknown_words, 30, replace=False)
    some_oov = np.random.choice(oov_words, 30, replace=False)
    flat = set([w for l in [some_commands, some_unknown, some_oov] for w in l])
    unknown_sample = list(flat.difference({target}))
    print(unknown_sample, len(unknown_sample))

    # load model
    tf.get_logger().setLevel(logging.ERROR)
    model = tf.keras.models.load_model(model_path)
    tf.get_logger().setLevel(logging.INFO)

    model_settings = input_data.standard_microspeech_model_settings(label_count=3)

    target_results, all_preds_target = transfer_learning.evaluate_fast_single_target(
        [target], 2, str(data_dir) + "/", num_targets, model, model_settings
    )
    # note: passing in the _TARGET_ category ID (2) for negative examples too:
    # we ignore other categories altogether
    unknown_results, all_preds_unknown = transfer_learning.evaluate_fast_single_target(
        unknown_sample, 2, str(data_dir) + "/", 300, model, model_settings
    )

    # tprs, fprs, thresh_labels = roc_sc(target_results, unknown_results)
    results = dict(
        target_results=target_results,
        unknown_results=unknown_results,
        all_predictions_targets=all_preds_target,
        all_predictions_unknown=all_preds_unknown,
    )

print(
    df.to_latex(
        header=True,
        index=False,
        float_format="%.2f",
        label="tab:embacc",
    ))
# %%
iso2val_files = {k: [] for k in iso2lang.keys()}
for ix, f in enumerate(val_files):
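    # (note) the fixed [45:47] slice assumes a specific absolute-path layout and
    # pulls out the 2-letter ISO language code from each file path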
    lang = f[45:47]
    #word = f.split("/")[-2]
    iso2val_files[lang].append(f)
# %%
# calculate val accuracy per language
model_settings = input_data.standard_microspeech_model_settings(
    label_count=len(commands) + 1)  # add silence
for lang_isocode, vfs in iso2val_files.items():
    val_audio = []
    val_labels = []

    print(len(vfs))
    for ix, f in enumerate(vfs):
        val_audio.append(input_data.file2spec(model_settings, f))

        word = f.split("/")[-2]
        val_labels.append(commands.index(word) + 1)  # add silence
        if ix % 2000 == 0:
            print(ix)
    val_audio = np.array(val_audio)
    val_labels = np.array(val_labels)
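    # (sketch) per-language accuracy from the arrays above - an assumed next step,
    # since this cell fragment stops before the evaluation; `model` would be the
    # embedding classifier loaded in an earlier cell (not shown here)
    preds = model.predict(np.expand_dims(val_audio, -1))
    acc = np.mean(np.argmax(preds, axis=1) == val_labels)
    print(lang_isocode, "val accuracy", acc)

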
def run_transfer_learning(td: TargetData):
    import tensorflow as tf
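    # tensorflow is imported inside the function, presumably so TF/CUDA state is
    # initialized in whichever worker process runs this job rather than in the parent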

    model_settings = input_data.standard_microspeech_model_settings(label_count=3)

    # fmt: off
    result_pkl = Path(td.dest_dir) / "results" / f"{td.lang_ix:02d}_{td.target_ix:03d}_{td.target_lang}_{td.target_word}.pkl"
    csvlog_dest = str(Path(td.dest_dir) / "models" / f"{td.target_word}_trainlog.csv")
    # fmt: on

    assert not os.path.isfile(csvlog_dest), f"{csvlog_dest} csvlog already exists"
    assert not os.path.exists(result_pkl), f"{result_pkl} exists"

    # TODO(mmaz): use keras class weights?
    name, model, details = transfer_learning.transfer_learn(
        target=td.target_word,
        train_files=td.train_files,
        val_files=td.val_files,
        unknown_files=td.unknown_files,
        num_epochs=4,
        num_batches=1,
        batch_size=64,
        model_settings=model_settings,
        csvlog_dest=csvlog_dest,
        base_model_path=td.base_model_path,
        base_model_output=td.base_model_output,
    )
    # fmt: off
    save_dest = Path(td.dest_dir) / "models" / f"{td.lang_ix:02d}_{td.target_ix:03d}_{td.target_lang}_{td.target_word}__{name}"
    # fmt: on
    print("SAVING", save_dest, flush=True)
    model.save(save_dest)

    target_results, all_preds_target = transfer_learning.evaluate_files_single_target(
        td.target_wavs, 2, model, model_settings
    )
    # note: passing in the _TARGET_ category ID (2) for negative examples too:
    # we ignore other categories altogether
    unknown_results, all_preds_unknown = transfer_learning.evaluate_files_single_target(
        td.unknown_sample, 2, model, model_settings
    )

    results = dict(
        target_results=target_results,
        unknown_results=unknown_results,
        all_predictions_targets=all_preds_target,
        all_predictions_unknown=all_preds_unknown,
        details=details,
        target_word=td.target_word,
        target_lang=td.target_lang,
        train_files=td.train_files,
        oov_words=oov_lang_words,
        commands=commands,
        target_data=asdict(td),
    )

    with open(result_pkl, "wb",) as fh:
        pickle.dump(results, fh)

    # https://keras.io/api/utils/backend_utils/
    tf.keras.backend.clear_session()

    return