示例#1
0
def dev_tuner(full_csv, split_type):
    """Tune the number of eigen-components (``n_lambdas``) on the dev set.

    Every candidate from 1 up to ``params["max_num_spkrs"]`` (inclusive) is
    tried in turn. For each one the whole dataset is diarized and the
    resulting RTTM is scored against ``fullref_ami_dev.rttm``. Note that
    scoring always uses this dev reference, whatever ``split_type`` is.
    The candidate with the lowest DER is returned; on ties the smallest
    candidate wins.

    Note: this is a very basic tuning routine for nn-based affinity and is
    work in progress until a better approach is found.

    Arguments
    ---------
    full_csv
        Dataset description, forwarded unchanged to ``diarize_dataset``.
    split_type
        Split name, forwarded unchanged to ``diarize_dataset``.

    Returns
    -------
    int
        The ``n_lambdas`` value that produced the minimum DER.
    """
    p_value = None  # pval is left unset for this search
    scores = []  # (candidate n_lambdas, DER) pairs, in search order

    for candidate in range(1, params["max_num_spkrs"] + 1):
        # Diarize the full dataset with this number of components.
        hyp_rttm = diarize_dataset(full_csv, split_type, candidate, p_value)

        reference = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
        _miss, _false_alarm, _spk_err, der_value = DER(
            reference,
            hyp_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        scores.append((candidate, der_value))

    # min() keeps the first minimal item, so ties go to the smallest candidate.
    best_candidate, _best_der = min(scores, key=lambda pair: pair[1])
    return best_candidate
示例#2
0
def dev_p_tuner(full_csv, split_type):
    """Tune the affinity-matrix p_value on the dev set.

    Candidates come from ``np.arange(0.002, 0.015, 0.001)``. For each
    candidate the whole dataset is diarized (with ``n_lambdas`` left as
    ``None``). The output is scored against ``fullref_ami_dev.rttm``, and
    the candidate with the minimum DER is returned. The earliest candidate
    wins on ties.
    """
    candidates = np.arange(0.002, 0.015, 0.001)
    n_lambdas = None  # n_lambdas is not fixed during this search
    ders = []

    for p_value in candidates:
        # Diarize the full dataset with the current p_value.
        hyp_rttm = diarize_dataset(full_csv, split_type, n_lambdas, p_value)

        reference = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
        scored = DER(
            reference,
            hyp_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        _miss, _false_alarm, _spk_err, der_value = scored
        ders.append(der_value)

    # Position of the first minimal DER maps back onto the candidate grid.
    best_position = min(range(len(ders)), key=ders.__getitem__)
    return candidates[best_position]
示例#3
0
def dev_nn_tuner(full_csv, split_type):
    """Tune n_neighbors on the dev set.

    The number of components is fixed at ``n_lambdas = 4``. The original
    author describes this as assuming the oracle number of speakers, to be
    fixed later. Values 5 through 14 are tried; each run is scored against
    ``fullref_ami_dev.rttm``, and the value with the lowest DER is returned.
    The sort is stable, so ties go to the smallest n_neighbors.
    """
    # TODO: fix later; for now the oracle number of speakers is assumed.
    n_lambdas = 4
    p_value = None
    results = []

    for neighbors in range(5, 15):
        # Diarize the whole dataset using this neighbourhood size.
        hyp_rttm = diarize_dataset(full_csv, split_type, n_lambdas, p_value,
                                   neighbors)

        reference = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
        _miss, _false_alarm, _spk_err, der_value = DER(
            reference,
            hyp_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        results.append([neighbors, der_value])

    ranked = sorted(results, key=lambda entry: entry[1])
    best_neighbors, _best_der = ranked[0]
    return best_neighbors
示例#4
0
def dev_nn_tuner(full_meta, split_type):
    """Tune n_neighbors on the dev set (used with nn-based affinity).

    The number of components is fixed at ``n_lambdas = 4``, i.e. the oracle
    number of speakers is assumed. Values 5 through 14 are tried. Each run
    is scored against ``fullref_ami_dev.rttm``, and the value with the
    lowest DER is returned. Ties go to the smallest value because the sort
    is stable.

    When ``params["oracle_n_spkrs"] is True`` and ``params["backend"]`` is
    ``"kmeans"``, only the first candidate (5) is evaluated.
    """
    # The oracle number of speakers is assumed for now.
    n_lambdas = 4
    p_value = None
    results = []

    for neighbors in range(5, 15):
        # Diarize the whole dataset using this neighbourhood size.
        hyp_rttm = diarize_dataset(full_meta, split_type, n_lambdas, p_value,
                                   neighbors)

        reference = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
        _miss, _false_alarm, _spk_err, der_value = DER(
            reference,
            hyp_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        results.append([neighbors, der_value])

        # No search is needed for oracle speaker count + k-means backend.
        if params["oracle_n_spkrs"] is True and params["backend"] == "kmeans":
            break

    ranked = sorted(results, key=lambda entry: entry[1])
    best_neighbors, _best_der = ranked[0]
    return best_neighbors
示例#5
0
def dev_ahc_threshold_tuner(full_meta, split_type):
    """Tune the affinity threshold on the dev set when AHC is the backend.

    For AHC the value passed as ``pval`` to ``diarize_dataset`` acts as the
    threshold. Candidates come from ``np.arange(0.0, 1.0, 0.1)``, and each
    run is scored against ``fullref_ami_dev.rttm``. The threshold with the
    minimum DER is returned, with the earliest candidate winning ties.

    When ``params["oracle_n_spkrs"] is True`` no search is done: only the
    first candidate is evaluated and returned.
    """
    thresholds = np.arange(0.0, 1.0, 0.1)
    n_lambdas = None  # kept as None; diarize_dataset uses it as a flag
    ders = []

    for threshold in thresholds:
        # Diarize the full dataset with the current threshold.
        hyp_rttm = diarize_dataset(full_meta, split_type, n_lambdas, threshold)

        reference = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
        scored = DER(
            reference,
            hyp_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        _miss, _false_alarm, _spk_err, der_value = scored
        ders.append(der_value)

        if params["oracle_n_spkrs"] is True:
            break  # threshold search is unnecessary with oracle speaker count

    # Position of the first minimal DER maps back onto the threshold grid.
    best_position = min(range(len(ders)), key=ders.__getitem__)
    return thresholds[best_position]
示例#6
0
def dev_pval_tuner(full_meta, split_type):
    """Tune the affinity-matrix p_value on the dev set.

    p_value controls what fraction of the values in each affinity row is
    kept. Candidates come from ``np.arange(0.002, 0.015, 0.001)``, and each
    run is scored against ``fullref_ami_dev.rttm``. The candidate with the
    minimum DER is returned, with the earliest candidate winning ties.

    When ``params["oracle_n_spkrs"] is True`` and ``params["backend"]`` is
    ``"kmeans"``, only the first candidate is evaluated.
    """
    candidates = np.arange(0.002, 0.015, 0.001)
    n_lambdas = None  # kept as None; diarize_dataset uses it as a flag
    ders = []

    for p_value in candidates:
        # Diarize the full dataset with the current p_value.
        hyp_rttm = diarize_dataset(full_meta, split_type, n_lambdas, p_value)

        reference = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
        scored = DER(
            reference,
            hyp_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
        )
        _miss, _false_alarm, _spk_err, der_value = scored
        ders.append(der_value)

        if params["oracle_n_spkrs"] is True and params["backend"] == "kmeans":
            # No p_value search is needed here. Spectral clustering needs
            # p_value for both oracle and estimated speaker counts. With the
            # kmeans backend it is only needed when oracle_n_spkrs is False.
            break

    # Position of the first minimal DER maps back onto the candidate grid.
    best_position = min(range(len(ders)), key=ders.__getitem__)
    return candidates[best_position]
示例#7
0
    out_boundaries = diarize_dataset(
        full_csv,
        "dev",
        n_lambdas=n_lambdas,
        pval=best_pval,
        n_neighbors=best_nn,
    )

    # Evaluating on DEV set
    logger.info("Evaluating for AMI Dev. set")
    ref_rttm_dev = os.path.join(params["ref_rttm_dir"], "fullref_ami_dev.rttm")
    sys_rttm_dev = out_boundaries
    [MS_dev, FA_dev, SER_dev, DER_dev] = DER(
        ref_rttm_dev,
        sys_rttm_dev,
        params["ignore_overlap"],
        params["forgiveness_collar"],
        individual_file_scores=True,
    )
    msg = "AMI Dev set: Diarization Error Rate = %s %%\n" % (str(
        round(DER_dev[-1], 2)))
    logger.info(msg)

    # AMI Eval Set
    full_csv = []
    with open(params["csv_diary_eval"], "r") as csv_file:
        reader = csv.reader(csv_file, delimiter=",")
        for row in reader:
            full_csv.append(row)

    out_boundaries = diarize_dataset(
示例#8
0
            split_type,
            n_lambdas=n_lambdas,
            pval=best_pval,
            n_neighbors=best_nn,
        )

        # Computing DER.
        msg = "Computing DERs for " + split_type + " set"
        logger.info(msg)
        ref_rttm = os.path.join(params["ref_rttm_dir"],
                                "fullref_ami_" + split_type + ".rttm")
        sys_rttm = out_boundaries
        [MS, FA, SER, DER_vals] = DER(
            ref_rttm,
            sys_rttm,
            params["ignore_overlap"],
            params["forgiveness_collar"],
            individual_file_scores=True,
        )

        # Writing DER values to a file. Append tag.
        der_file_name = split_type + "_DER_" + tag
        out_der_file = os.path.join(params["der_dir"], der_file_name)
        msg = "Writing DER file to: " + out_der_file
        logger.info(msg)
        diar.write_ders_file(ref_rttm, DER_vals, out_der_file)

        msg = ("AMI " + split_type + " set DER = %s %%\n" %
               (str(round(DER_vals[-1], 2))))
        logger.info(msg)
        final_DERs[split_type] = round(DER_vals[-1], 2)