import os
import sys
import time

import numpy as np
import sklearn.metrics

import localmodule

# Read command-line arguments:
#   AUG_KIND TEST_UNIT PREDICT_UNIT TRIAL_ID
args = sys.argv[1:]
if len(args) < 4:
    # Fail with a readable message instead of a bare IndexError below.
    sys.exit("Usage: " + sys.argv[0] +
             " AUG_KIND TEST_UNIT PREDICT_UNIT TRIAL_ID")
aug_kind_str = args[0]
test_unit_str = args[1]
predict_unit_str = args[2]
trial_id = int(args[3])

# Define constants.
data_dir = localmodule.get_data_dir()
dataset_name = localmodule.get_dataset_name()
folds = localmodule.fold_units()
models_dir = localmodule.get_models_dir()
units = localmodule.get_units()

# The model directory name gains an "aug-<kind>" suffix unless no
# augmentation was used.
model_name = "pcen-convnet"
if aug_kind_str != "none":
    model_name = "_".join([model_name, "aug-" + aug_kind_str])
model_dir = os.path.join(models_dir, model_name)

tolerance = 0.5  # in seconds
min_dist = 3  # 150 ms; presumably counted in frames of ~50 ms -- confirm

# Define thresholds: 1 minus a log-spaced grid, so they are densest near 1.0.
# 141 points over [1e-9, 1e-2] plus 80 points over (1e-2, 1] (the duplicated
# 1e-2 endpoint is dropped), i.e. 221 thresholds descending from 1 - 1e-9
# down to 0.
icassp_thresholds = 1.0 - np.concatenate(
    (np.logspace(-9, -2, 141), np.logspace(-2, 0, 81)[1:]))
n_thresholds = len(icassp_thresholds)
def _parse_augmentations(aug_kind_str, fold_units):
    """Return the list of augmentation names selected by ``aug_kind_str``.

    Noise augmentations are named per unit ("noise-<unit>"), so
    ``fold_units`` is only consulted for the "all" and "noise" kinds.

    Raises:
        ValueError: if ``aug_kind_str`` is not a recognized kind.
    """
    if aug_kind_str == "none":
        return ["original"]
    if aug_kind_str == "pitch":
        return ["original", "pitch"]
    if aug_kind_str == "stretch":
        return ["original", "stretch"]
    if aug_kind_str == "all-but-noise":
        return ["original", "pitch", "stretch"]
    noise_augs = ["noise-" + unit_str for unit_str in fold_units]
    if aug_kind_str == "all":
        return noise_augs + ["original", "pitch", "stretch"]
    if aug_kind_str == "noise":
        return noise_augs + ["original"]
    # Fail loudly rather than leaving the augmentation list undefined.
    raise ValueError(
        "Unknown augmentation kind {!r}; expected one of 'none', 'pitch', "
        "'stretch', 'all-but-noise', 'all', 'noise'.".format(aug_kind_str))


def multiplex_lms_with_background(aug_kind_str, fold_units, n_input_hops,
                                  batch_size):
    """Build a Keras-style batch generator over clips and their backgrounds.

    One pescador streamer is created per (augmentation instance, unit) pair.
    Each streamer is given the path of a clip log-mel-spectrogram HDF5 file,
    a bias value and the path of the unit's background-summary HDF5 file. All
    streamers are multiplexed with replacement and revival, then buffered
    into batches of ``batch_size``.

    Args:
        aug_kind_str: one of "none", "pitch", "stretch", "all-but-noise",
            "all", "noise"; selects which augmentations are streamed.
        fold_units: unit identifiers whose files are streamed (also used to
            name the per-unit noise augmentations).
        n_input_hops: forwarded unchanged to ``yield_lms_and_background``
            (presumably the number of frames per sample -- confirm there).
        batch_size: number of samples per buffered batch.

    Returns:
        The result of ``pescador.maps.keras_tuples`` with inputs
        ``["X_spec", "X_bg"]`` and outputs ``["y"]``.

    Raises:
        ValueError: if ``aug_kind_str`` is not a recognized kind.
    """
    # Validate first so a bad kind fails before any config/path lookups.
    augs = _parse_augmentations(aug_kind_str, fold_units)

    # Define constants.
    aug_dict = localmodule.get_augmentations()
    data_dir = localmodule.get_data_dir()
    dataset_name = localmodule.get_dataset_name()
    tfr_dir = os.path.join(
        data_dir, "_".join([dataset_name, "clip-logmelspec"]))
    bg_dir = os.path.join(
        data_dir, "_".join([dataset_name, "clip-logmelspec-backgrounds"]))
    # NOTE(review): `bg_duration` is neither a parameter nor one of the
    # visible module-level constants; it must be a module-level global
    # defined elsewhere in this file -- confirm. Likewise `pescador` and
    # `yield_lms_and_background` are not defined in the visible code.
    T_str = "T-" + str(bg_duration).zfill(4)
    T_dir = os.path.join(bg_dir, T_str)

    # The background file depends only on the unit, so resolve it once.
    bg_paths = {
        unit_str: os.path.join(T_dir, "_".join([
            dataset_name, "background_summaries", unit_str, T_str + ".hdf5"
        ]))
        for unit_str in fold_units
    }

    # One streamer per (augmentation, instance, unit), in that nesting order.
    streams = []
    for aug_str in augs:
        aug_dir = os.path.join(tfr_dir, aug_str)
        if aug_str == "original":
            instances = [aug_str]
        else:
            # aug_dict maps an augmentation name to its number of instances.
            instances = [
                "-".join([aug_str, str(instance_id)])
                for instance_id in range(aug_dict[aug_str])
            ]

        # Noise-augmented clips get a fixed bias of -17.0; its effect is
        # defined by yield_lms_and_background -- confirm there.
        if aug_str.startswith("noise"):
            bias = np.float32(-17.0)
        else:
            bias = np.float32(0.0)

        for instanced_aug_str in instances:
            for unit_str in fold_units:
                lms_name = "_".join(
                    [dataset_name, instanced_aug_str, unit_str])
                lms_path = os.path.join(aug_dir, lms_name + ".hdf5")
                stream = pescador.Streamer(yield_lms_and_background, lms_path,
                                           n_input_hops, bias,
                                           bg_paths[unit_str])
                streams.append(stream)

    # Multiplex streamers together, keeping every streamer active at once.
    mux = pescador.Mux(streams,
                       k=len(streams),
                       lam=None,
                       with_replacement=True,
                       revive=True)

    # Create buffered streamer with specified batch size.
    buffered_streamer = pescador.BufferedStreamer(mux, batch_size)

    return pescador.maps.keras_tuples(buffered_streamer,
                                      inputs=["X_spec", "X_bg"],
                                      outputs=["y"])