Example #1
File: inference.py  Project: fgnt/pb_sed
def config():
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')

    strong_label_crnn_hyper_params_dir = ''
    assert len(strong_label_crnn_hyper_params_dir) > 0, \
        'Set strong_label_crnn_hyper_params_dir on the command line.'
    strong_label_crnn_tuning_config = load_json(
        Path(strong_label_crnn_hyper_params_dir) / '1' / 'config.json')
    strong_label_crnn_dirs = strong_label_crnn_tuning_config[
        'strong_label_crnn_dirs']
    assert len(strong_label_crnn_dirs) > 0, \
        'strong_label_crnn_dirs must not be empty.'
    strong_label_crnn_checkpoints = strong_label_crnn_tuning_config[
        'strong_label_crnn_checkpoints']
    data_provider = strong_label_crnn_tuning_config['data_provider']
    database_name = strong_label_crnn_tuning_config['database_name']
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name /
                      'inference' / timestamp)
    assert not Path(storage_dir).exists()

    weak_label_crnn_hyper_params_dir = strong_label_crnn_tuning_config[
        'weak_label_crnn_hyper_params_dir']
    assert len(weak_label_crnn_hyper_params_dir) > 0, \
        'Set weak_label_crnn_hyper_params_dir on the command line.'
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(weak_label_crnn_dirs) > 0, \
        'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']

    del strong_label_crnn_tuning_config
    del weak_label_crnn_tuning_config

    sed_hyper_params_name = ['f', 'psds1', 'psds2']

    device = 0

    dataset_name = 'eval_public'
    ground_truth_filepath = None

    max_segment_length = None
    if max_segment_length is None:
        segment_overlap = None
    else:
        segment_overlap = 100
    save_scores = False
    save_detections = False

    weak_pseudo_labeling = False
    strong_pseudo_labeling = False
    pseudo_labelled_dataset_name = dataset_name

    pseudo_widening = .0

    ex.observers.append(FileStorageObserver.create(storage_dir))
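
These `config` functions are Sacred config scopes: the empty-string default together with the assert forces `strong_label_crnn_hyper_params_dir` to be supplied as a command-line config update. A minimal sketch of driving the experiment from Python instead; the path is a placeholder and `ex` is assumed to be the Sacred Experiment defined in inference.py:

if __name__ == '__main__':
    # Equivalent to: python -m ... with strong_label_crnn_hyper_params_dir=...
    ex.run(config_updates={
        'strong_label_crnn_hyper_params_dir':
            '/path/to/strong_label_crnn/hyper_params/<timestamp>',  # placeholder
    })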
Example #2
File: tuning.py  Project: fgnt/pb_sed
def config():
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')

    weak_label_crnn_hyper_params_dir = ''
    assert len(weak_label_crnn_hyper_params_dir) > 0, \
        'Set weak_label_crnn_hyper_params_dir on the command line.'
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(weak_label_crnn_dirs) > 0, \
        'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']
    del weak_label_crnn_tuning_config

    strong_label_crnn_group_dir = ''
    if isinstance(strong_label_crnn_group_dir, list):
        strong_label_crnn_dirs = sorted([
            str(d) for g in strong_label_crnn_group_dir
            for d in Path(g).glob('202*') if d.is_dir()
        ])
    else:
        strong_label_crnn_dirs = sorted([
            str(d) for d in Path(strong_label_crnn_group_dir).glob('202*')
            if d.is_dir()
        ])
    assert len(strong_label_crnn_dirs) > 0, \
        'strong_label_crnn_dirs must not be empty.'
    strong_label_crnn_checkpoints = 'ckpt_best_macro_fscore_strong.pth'
    strong_crnn_config = load_json(
        Path(strong_label_crnn_dirs[0]) / '1' / 'config.json')
    data_provider = strong_crnn_config['data_provider']
    database_name = strong_crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name /
                      'hyper_params' / timestamp)
    assert not Path(storage_dir).exists()
    del strong_crnn_config
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None

    device = 0

    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None

    medfilt_lengths = [31] if debug else [
        301, 251, 201, 151, 101, 81, 61, 51, 41, 31, 21, 11
    ]

    ex.observers.append(FileStorageObserver.create(storage_dir))
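
For context, `medfilt_lengths` is the grid of median-filter kernel sizes that the tuning searches over to smooth the frame-wise detection scores. A minimal illustration of the smoothing itself, with scipy as a stand-in (pb_sed uses its own filtering; this is only for intuition):

import numpy as np
from scipy.signal import medfilt

frame_scores = np.random.rand(500)                # frame-wise scores, one event class
smoothed = medfilt(frame_scores, kernel_size=31)  # one candidate from medfilt_lengths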
Example #3
def config():
    debug = False
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')

    group_dir = ''
    if isinstance(group_dir, list):
        crnn_dirs = sorted([
            str(d) for g in group_dir for d in Path(g).glob('202*')
            if d.is_dir()
        ])
    else:
        crnn_dirs = sorted(
            [str(d) for d in Path(group_dir).glob('202*') if d.is_dir()])
    assert len(crnn_dirs) > 0, 'crnn_dirs must not be empty.'
    crnn_checkpoints = 'ckpt_best_macro_fscore_weak.pth'
    crnn_config = load_json(Path(crnn_dirs[0]) / '1' / 'config.json')
    data_provider = crnn_config['data_provider']
    database_name = crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'weak_label_crnn' / database_name /
                      'hyper_params' / timestamp)
    assert not Path(storage_dir).exists()
    del crnn_config
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None

    device = 0

    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None

    boundaries_filter_lengths = [20] if debug else [
        100, 80, 60, 50, 40, 30, 20, 10, 0
    ]

    tune_detection_scenario_1 = True
    detection_window_lengths_scenario_1 = [11] if debug else [
        51, 41, 31, 21, 11
    ]
    detection_window_shift_scenario_1 = 1
    detection_medfilt_lengths_scenario_1 = [11] if debug else [
        101, 81, 61, 51, 41, 31, 21, 11
    ]

    tune_detection_scenario_2 = True
    detection_window_lengths_scenario_2 = [250]
    detection_window_shift_scenario_2 = 250
    detection_medfilt_lengths_scenario_2 = [1]

    ex.observers.append(FileStorageObserver.create(storage_dir))
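
The two tuning scenarios appear to correspond to the two DCASE PSDS evaluation scenarios: scenario 1 searches fine-grained window and median-filter lengths, while scenario 2 fixes a single non-overlapping 250-frame window with no median filtering. Assuming the full cross-product of the scenario-1 grids is evaluated (an assumption, not stated in the code), the search size per event class is:

n_candidates = len([51, 41, 31, 21, 11]) * len([101, 81, 61, 51, 41, 31, 21, 11])
print(n_candidates)  # 40 (window_length, medfilt_length) pairs per class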
Example #4
File: inference.py  Project: fgnt/pb_sed
def tagging(
        crnns,
        dataset,
        device,
        timestamps,
        event_classes,
        hyper_params_dir,
        ground_truth,
        audio_durations,
        psds_params=(),
        max_segment_length=None,
        segment_overlap=None,
):
    print()
    print('Tagging')
    hyper_params = load_json(
        Path(hyper_params_dir) / 'tagging_hyper_params_f.json')
    thresholds = {
        event_class: hyper_params[event_class]['threshold']
        for event_class in hyper_params
    }
    tagging_scores = base.tagging(
        crnns,
        dataset,
        device,
        max_segment_length=max_segment_length,
        segment_overlap=segment_overlap,
        merge_score_segments=False,
    )
    results = {}
    if ground_truth is not None:
        tagging_scores_merged = merge_segments(tagging_scores,
                                               segment_overlap=0)
        tagging_scores_df = base.scores_to_dataframes(
            tagging_scores_merged,
            timestamps=timestamps,
            event_classes=event_classes,
        )
        if ground_truth:
            f, p, r, stats = clip_based.fscore(tagging_scores_df,
                                               ground_truth,
                                               thresholds,
                                               num_jobs=8)
            print('f', f)
            print('p', p)
            print('r', r)
            for key in f:
                results.update({
                    f'{key}_f': f[key],
                    f'{key}_p': p[key],
                    f'{key}_r': r[key],
                })
            for j in range(len(psds_params)):
                psds, psd_roc, classwise_rocs = intersection_based.psds(
                    tagging_scores_df,
                    ground_truth,
                    audio_durations,
                    **psds_params[j],
                    num_jobs=8,
                )
                print(f'psds[{j}]', psds)
                results[f'psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[f'{event_class}_auc[{j}]'] = staircase_auc(
                        tpr, efpr, psds_params[j].get('max_efpr', 100))
                psds, psd_roc, classwise_rocs = intersection_based.reference.approximate_psds(
                    tagging_scores_df,
                    ground_truth,
                    audio_durations,
                    **psds_params[j],
                    thresholds=np.linspace(.01, .99, 50),
                )
                print(f'approx_psds[{j}]', psds)
                results[f'approx_psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[f'{event_class}_approx_auc[{j}]'] = staircase_auc(
                        tpr, efpr, psds_params[j].get('max_efpr', 100))

    thresholds = np.array(
        [thresholds[event_class] for event_class in event_classes])
    tagging_scores = {
        audio_id: tagging_scores[audio_id][0]
        for audio_id in tagging_scores.keys()
    }
    tags = {
        audio_id: tagging_scores[audio_id] > thresholds
        for audio_id in tagging_scores.keys()
    }
    return tags, tagging_scores, results
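
A minimal call sketch for `tagging`; every value below is a placeholder, and the real `crnns`, `dataset` and `timestamps` come from pb_sed's model loading and data provider:

tags, tagging_scores, results = tagging(
    crnns, dataset, device=0,
    timestamps=timestamps,
    event_classes=event_classes,
    hyper_params_dir='/path/to/weak_label_crnn/hyper_params/<timestamp>',
    ground_truth=None,     # without ground truth the metric block is skipped
    audio_durations=None,  # only needed for the PSDS metrics
)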
Example #5
File: inference.py  Project: fgnt/pb_sed
def sound_event_detection(
    crnns,
    dataset,
    device,
    timestamps,
    event_classes,
    tags,
    hyper_params_dir,
    hyper_params_name,
    ground_truth,
    audio_durations,
    collar_based_params=(),
    psds_params=(),
    max_segment_length=None,
    segment_overlap=None,
    pseudo_widening=.0,
    score_storage_dir=None,
    detection_storage_dir=None,
):
    print()
    print('Sound Event Detection')
    if isinstance(hyper_params_name, (str, Path)):
        hyper_params_name = [hyper_params_name]
    assert isinstance(hyper_params_name, (list, tuple))
    hyper_params = [
        load_json(Path(hyper_params_dir) / f'sed_hyper_params_{name}.json')
        for name in hyper_params_name
    ]

    if isinstance(score_storage_dir, (list, tuple)):
        assert len(score_storage_dir) == len(hyper_params), (
            len(score_storage_dir), len(hyper_params))
    elif isinstance(score_storage_dir, (str, Path)):
        score_storage_dir = [
            Path(score_storage_dir) / name for name in hyper_params_name
        ]
    elif score_storage_dir is not None:
        raise ValueError('score_storage_dir must be list, str, Path or None.')

    if isinstance(detection_storage_dir, (list, tuple)):
        assert len(detection_storage_dir) == len(hyper_params), (
            len(detection_storage_dir), len(hyper_params))
    elif isinstance(detection_storage_dir, (str, Path)):
        detection_storage_dir = [
            Path(detection_storage_dir) / name for name in hyper_params_name
        ]
    elif detection_storage_dir is not None:
        raise ValueError(
            'detection_storage_dir must be list, str, Path or None.')

    window_lengths = np.zeros((len(hyper_params), len(event_classes)))
    medfilt_lengths = np.zeros((len(hyper_params), len(event_classes)))
    tag_masked = np.zeros((len(hyper_params), len(event_classes)))
    window_shift = set()
    for i, hyper_params_i in enumerate(hyper_params):
        for j, event_class in enumerate(event_classes):
            window_lengths[i, j] = hyper_params_i[event_class]['window_length']
            medfilt_lengths[i, j] = \
                hyper_params_i[event_class]['medfilt_length']
            tag_masked[i, j] = hyper_params_i[event_class]['tag_masked']
            window_shift.add(hyper_params_i[event_class]['window_shift'])
    if len(window_shift) != 1:
        raise ValueError(
            'Inference with multiple window shifts is not supported.')
    window_shift = list(window_shift)[0]
    if max_segment_length is not None:
        assert max_segment_length % window_shift == 0, \
            (max_segment_length, window_shift)
        assert (segment_overlap // 2) % window_shift == 0, \
            (segment_overlap, window_shift)
    detection_scores = base.sound_event_detection(
        crnns,
        dataset,
        device,
        model_kwargs={
            'window_length': window_lengths,
            'window_shift': window_shift
        },
        medfilt_length=medfilt_lengths,
        apply_mask=tag_masked,
        masks=tags,
        timestamps=timestamps[::window_shift],
        event_classes=event_classes,
        max_segment_length=max_segment_length,
        segment_overlap=segment_overlap,
        merge_score_segments=True,
        score_segment_overlap=segment_overlap // window_shift,
        score_storage_dir=score_storage_dir,
    )
    event_detections = []
    results = []
    for i, name in enumerate(hyper_params_name):
        if ground_truth:
            print()
            print(name)
        results.append({})
        if detection_storage_dir and detection_storage_dir[i]:
            io.write_detections_for_multiple_thresholds(
                detection_scores[i],
                thresholds=np.linspace(.01, .99, 50),
                dir_path=detection_storage_dir[i],
            )
        if 'threshold' in hyper_params[i][event_classes[0]]:
            thresholds = {
                event_class: hyper_params[i][event_class]['threshold']
                for event_class in event_classes
            }
            event_detections.append(
                scores_to_event_list(
                    detection_scores[i],
                    thresholds,
                    event_classes=event_classes,
                ))
            if detection_storage_dir and detection_storage_dir[i]:
                io.write_detection(
                    detection_scores[i],
                    thresholds,
                    Path(detection_storage_dir[i]) / 'cbf.tsv',
                )
            if ground_truth and collar_based_params:
                f, p, r, stats = collar_based.fscore(
                    detection_scores[i],
                    ground_truth,
                    thresholds,
                    **collar_based_params,
                    return_onset_offset_dist_sum=True,
                    num_jobs=8,
                )
                print('f', f)
                print('p', p)
                print('r', r)
                for key in f:
                    results[-1].update({
                        f'{key}_f': f[key],
                        f'{key}_p': p[key],
                        f'{key}_r': r[key],
                    })
                    if key in stats:
                        results[-1].update({
                            f'{key}_onset_bias':
                            stats[key]['onset_dist_sum'] /
                            max(stats[key]['tps'], 1),
                            f'{key}_offset_bias':
                            stats[key]['offset_dist_sum'] /
                            max(stats[key]['tps'], 1),
                        })

            for clip_id in event_detections[-1]:
                events_in_clip = []
                for onset, offset, event_label in \
                        event_detections[-1][clip_id]:
                    onset = max(
                        onset - pseudo_widening
                        - hyper_params[i][event_label].get('onset_bias', 0),
                        0)
                    offset = (
                        offset + pseudo_widening
                        - hyper_params[i][event_label].get('offset_bias', 0))
                    if offset > onset:
                        events_in_clip.append((onset, offset, event_label))
                event_detections[-1][clip_id] = events_in_clip
        else:
            event_detections.append(None)
        if ground_truth:
            if not isinstance(psds_params, (tuple, list)):
                psds_params = [psds_params]
            for j in range(len(psds_params)):
                psds, psd_roc, classwise_rocs = intersection_based.psds(
                    detection_scores[i],
                    ground_truth,
                    audio_durations,
                    **psds_params[j],
                    num_jobs=8,
                )
                print(f'psds[{j}]', psds)
                results[-1][f'psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[-1][f'{event_class}_auc[{j}]'] = staircase_auc(
                        tpr, efpr, psds_params[j].get('max_efpr', 100))
                if score_storage_dir and score_storage_dir[i] is not None:
                    psds, psd_roc, classwise_rocs = intersection_based.psds(
                        score_storage_dir[i],
                        ground_truth,
                        audio_durations,
                        **psds_params[j],
                        num_jobs=8,
                    )
                    print(f'psds[{j}] (from files)', psds)
                psds, psd_roc, classwise_rocs = intersection_based.reference.approximate_psds(
                    detection_scores[i],
                    ground_truth,
                    audio_durations,
                    **psds_params[j],
                    thresholds=np.linspace(.01, .99, 50),
                )
                print(f'approx_psds[{j}]', psds)
                results[-1][f'approx_psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[-1][
                        f'{event_class}_approx_auc[{j}]'] = staircase_auc(
                            tpr, efpr, psds_params[j].get('max_efpr', 100))
                if (detection_storage_dir
                        and detection_storage_dir[i] is not None):
                    psds, psd_roc, classwise_rocs = intersection_based.reference.approximate_psds_from_detections_dir(
                        detection_storage_dir[i],
                        ground_truth,
                        audio_durations,
                        **psds_params[j],
                        thresholds=np.linspace(.01, .99, 50),
                    )
                    print(f'approx_psds[{j}] (from files)', psds)
    return event_detections, results
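
Segmented inference above imposes two divisibility constraints tied to the (single) tuned window shift; a small numeric illustration with made-up values:

window_shift = 5
max_segment_length = 1000   # must be a multiple of window_shift
segment_overlap = 100       # half the overlap must also be a multiple
assert max_segment_length % window_shift == 0
assert (segment_overlap // 2) % window_shift == 0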
Example #6
File: inference.py  Project: fgnt/pb_sed
def boundaries_detection(
    crnns,
    dataset,
    device,
    timestamps,
    event_classes,
    tags,
    hyper_params_dir,
    ground_truth,
    collar_based_params,
    max_segment_length=None,
    segment_overlap=None,
    pseudo_widening=.0,
):
    print()
    print('Boundaries Detection')
    hyper_params = load_json(
        Path(hyper_params_dir) / 'boundaries_detection_hyper_params_f.json')
    stepfilt_length = np.array([
        hyper_params[event_class]['stepfilt_length']
        for event_class in event_classes
    ])
    thresholds = {
        event_class: hyper_params[event_class]['threshold']
        for event_class in event_classes
    }
    boundary_scores = base.boundaries_detection(
        crnns,
        dataset,
        device,
        stepfilt_length=stepfilt_length,
        apply_mask=True,
        masks=tags,
        max_segment_length=max_segment_length,
        segment_overlap=segment_overlap,
        merge_score_segments=True,
        timestamps=timestamps,
        event_classes=event_classes,
    )
    results = {}
    if ground_truth:
        boundary_ground_truth = base.tuning.boundaries_from_events(
            ground_truth)
        f, p, r, stats = collar_based.fscore(
            boundary_scores,
            boundary_ground_truth,
            thresholds,
            **collar_based_params,
            return_onset_offset_dist_sum=True,
            num_jobs=8,
        )
        print('f', f)
        print('p', p)
        print('r', r)
        for key in f:
            results.update({
                f'{key}_f': f[key],
                f'{key}_p': p[key],
                f'{key}_r': r[key],
            })
            if key in stats:
                results.update({
                    f'{key}_onset_bias':
                    stats[key]['onset_dist_sum'] / max(stats[key]['tps'], 1),
                    f'{key}_offset_bias':
                    stats[key]['offset_dist_sum'] / max(stats[key]['tps'], 1),
                })

    boundaries_detection = scores_to_event_list(boundary_scores,
                                                thresholds,
                                                event_classes=event_classes)

    for clip_id in boundaries_detection:
        boundaries_in_clip = []
        for onset, offset, event_label in boundaries_detection[clip_id]:
            onset = max(
                np.round(
                    onset - pseudo_widening -
                    hyper_params[event_label]['onset_bias'], 3), 0)
            offset = np.round(
                offset + pseudo_widening -
                hyper_params[event_label]['offset_bias'], 3)
            if offset > onset:
                boundaries_in_clip.append((onset, offset, event_label))
        boundaries_detection[clip_id] = boundaries_in_clip
    return boundaries_detection, results
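
A worked instance of the onset/offset correction above, with illustrative numbers: pseudo widening expands a detected boundary on both sides, while the tuned biases shift it to compensate for systematic onset/offset errors.

onset, offset = 1.20, 2.40
pseudo_widening, onset_bias, offset_bias = 0.1, 0.05, -0.05
onset = max(round(onset - pseudo_widening - onset_bias, 3), 0)  # 1.05
offset = round(offset + pseudo_widening - offset_bias, 3)       # 2.55
assert offset > onset  # the boundary is kept, slightly widened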