Example #1

import os
import numpy as np

import cls_feature_class    # feature/label helper from the seld-dcase2019 repository
import evaluation_metrics   # official SELD metrics from the seld-dcase2019 repository
# Note: get_nb_files() groups output files by split/ov/ir; it is defined alongside
# this function in the original calculate_SELD_metrics.py script.

def calculate_SELD_metrics(gt_meta_dir, pred_meta_dir, score_type):
    '''Calculate metrics with the official evaluation tool. Adapted from:
    https://github.com/sharathadavanne/seld-dcase2019/blob/master/calculate_SELD_metrics.py
    
    Args:
      gt_meta_dir: ground truth meta directory. 
      pred_meta_dir: prediction meta directory.
      score_type: 'all', 'split', 'ov', 'ir'
      
    Returns:
      sed_scores: np.ndarray, [error_rate, f1_score] per group
      doa_er_metric: np.ndarray, [doa_error, doa_frame_recall] per group
      seld_metric: np.ndarray, SELD score per group
    '''

    # Load feature class
    feat_cls = cls_feature_class.FeatureClass()

    # collect gt files info
    # gt_meta_files = [fn for fn in os.listdir(gt_meta_dir) if fn.endswith('.csv') and not fn.startswith('.')]

    # collect pred files info
    pred_meta_files = [
        fn for fn in os.listdir(pred_meta_dir)
        if fn.endswith('.csv') and not fn.startswith('.')
    ]

    # Load evaluation metric class
    eval = evaluation_metrics.SELDMetrics(nb_frames_1s=feat_cls.nb_frames_1s(),
                                          data_gen=feat_cls)

    # Calculate scores for different splits, overlapping sound events, and impulse responses (reverberant scenes)
    # score_type = 'all', 'split', 'ov', 'ir'
    split_cnt_dict = get_nb_files(pred_meta_files, _group=score_type)
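    # split_cnt_dict presumably maps each group key (the split number, the number of
    # overlapping events, or the impulse-response id, depending on score_type) to the
    # list of prediction files belonging to that group.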

    sed_error_rate = []
    sed_f1_score = []
    doa_error = []
    doa_frame_recall = []
    seld_metric = []

    # Calculate scores across files for a given score_type
    for split_key in np.sort(list(split_cnt_dict)):
        eval.reset()  # Reset the evaluation metric parameters
        for pred_file in split_cnt_dict[split_key]:
            # Load predicted output format file
            pred_dict = evaluation_metrics.load_output_format_file(
                os.path.join(pred_meta_dir, pred_file))

            # Load reference description file
            gt_desc_file_dict = feat_cls.read_desc_file(
                os.path.join(gt_meta_dir, pred_file.replace('.npy', '.csv')))

            # Generate classification labels for SELD
            gt_labels = feat_cls.get_clas_labels_for_file(gt_desc_file_dict)
            pred_labels = evaluation_metrics.output_format_dict_to_classification_labels(
                pred_dict, feat_cls)

            # Calculate SED and DOA scores
            eval.update_sed_scores(pred_labels.max(2), gt_labels.max(2))
            eval.update_doa_scores(pred_labels, gt_labels)

        # Overall SED and DOA scores
        sed_er, sed_f1 = eval.compute_sed_scores()
        doa_err, doa_fr = eval.compute_doa_scores()
        seld_metr = evaluation_metrics.compute_seld_metric([sed_er, sed_f1],
                                                           [doa_err, doa_fr])
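        # compute_seld_metric combines the four error terms into one score; for
        # DCASE 2019 this is the average (ER + (1 - F1) + DOA_err/180 + (1 - FR)) / 4.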

        sed_error_rate.append(sed_er)
        sed_f1_score.append(sed_f1)
        doa_error.append(doa_err)
        doa_frame_recall.append(doa_fr)
        seld_metric.append(seld_metr)

    sed_scores = [sed_error_rate, sed_f1_score]
    doa_er_metric = [doa_error, doa_frame_recall]

    sed_scores = np.array(sed_scores).squeeze()
    doa_er_metric = np.array(doa_er_metric).squeeze()
    seld_metric = np.array(seld_metric).squeeze()

    return sed_scores, doa_er_metric, seld_metric
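
# A minimal usage sketch (not part of the original code): the directory names
# below are hypothetical and should point to folders of per-file csv metadata
# in the DCASE 2019 output format.
gt_meta_dir = 'metadata_dev'          # hypothetical ground-truth metadata directory
pred_meta_dir = 'results/pred_dev'    # hypothetical prediction directory
sed_scores, doa_er_metric, seld_metric = calculate_SELD_metrics(
    gt_meta_dir, pred_meta_dir, score_type='all')
# With score_type='all' there is a single group, so the squeezed arrays hold
# [error_rate, f1_score], [doa_error, doa_frame_recall] and a scalar SELD score.
print('SED ER: {:.3f}, F1: {:.3f}'.format(sed_scores[0], sed_scores[1]))
print('DOA error: {:.3f}, frame recall: {:.3f}'.format(doa_er_metric[0], doa_er_metric[1]))
print('SELD score: {:.3f}'.format(seld_metric))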

Example #2

# Script-level evaluation, following the original calculate_SELD_metrics.py
# (uses the same imports as Example #1). ref_desc_files and
# pred_output_format_files are the reference-metadata and predicted-output
# directories and must be set before running this block.
feat_cls = cls_feature_class.FeatureClass()
azi_list, ele_list = feat_cls.get_azi_ele_list()

# collect reference files info
ref_files = os.listdir(ref_desc_files)
nb_ref_files = len(ref_files)

# collect predicted files info
pred_files = os.listdir(pred_output_format_files)
nb_pred_files = len(pred_files)

if nb_ref_files != nb_pred_files:
    print('ERROR: Mismatch. Reference has {} and prediction has {} files'.format(nb_ref_files, nb_pred_files))
    exit()

# Load evaluation metric class
eval = evaluation_metrics.SELDMetrics(nb_frames_1s=feat_cls.nb_frames_1s(), data_gen=feat_cls)

# Calculate scores for different splits, overlapping sound events, and impulse responses (reverberant scenes)
score_type_list = ['all', 'split', 'ov', 'ir']

print('\nCalculating {} scores for {}'.format(score_type_list, os.path.basename(pred_output_format_files)))

for score_type in score_type_list:
    print('\n\n---------------------------------------------------------------------------------------------------')
    print('------------------------------------  {}   ---------------------------------------------'.format('Total score' if score_type=='all' else 'score per {}'.format(score_type)))
    print('---------------------------------------------------------------------------------------------------')

    split_cnt_dict = get_nb_files(pred_files, _group=score_type) # collect files corresponding to score_type

    # Calculate scores across files for a given score_type
    for split_key in np.sort(list(split_cnt_dict)):