Пример #1
0
def tune_tagging(
    tagging_scores,
    medfilt_length_candidates,
    metrics,
    minimize=False,
    storage_dir=None,
):
    """Tune the median-filter length applied to tagging scores.

    For every candidate length the scores are median filtered along axis 0
    (lengths ``<= 1`` leave the scores unfiltered) and evaluated with every
    metric in ``metrics``. ``update_leaderboard`` collects the results per
    metric.

    Args:
        tagging_scores: dict mapping audio ids to score dataframes; they are
            checked with ``validate_score_dataframe`` before filtering.
        medfilt_length_candidates: iterable of median-filter lengths to try.
        metrics: dict mapping metric names to callables that take the
            (filtered) scores and return ``(metric_values, other_values)``,
            where ``metric_values`` maps event classes (and ``*_average``
            entries) to values and ``other_values`` maps event classes to
            dicts of additional hyper-parameters.
        minimize: forwarded to ``update_leaderboard``.
        storage_dir: if not None, the per-class hyper-parameters stored in
            the leaderboard are dumped to
            ``<storage_dir>/tagging_hyper_params_<metric_name>.json``.

    Returns:
        The leaderboard returned by ``update_leaderboard``, keyed by metric
        name. Entry ``[0]`` holds the metric values printed as "best" and
        entry ``[1]`` the per-class hyper-parameters, each annotated with
        its metric value under the key ``<metric_name>`` (independently of
        ``storage_dir``).
    """
    leaderboard = {}
    audio_ids = sorted(tagging_scores.keys())
    event_classes = None
    for medfilt_len in medfilt_length_candidates:
        if medfilt_len > 1:
            scores_filtered = deepcopy(tagging_scores)
            for audio_id in audio_ids:
                _, event_classes = validate_score_dataframe(
                    tagging_scores[audio_id], event_classes=event_classes)
                scores = tagging_scores[audio_id][event_classes].to_numpy()
                scores_filtered[audio_id][event_classes] = medfilt(
                    scores, medfilt_len, axis=0)
        else:
            scores_filtered = tagging_scores
        for metric_name, metric_fn in metrics.items():
            metric_values, other_values = metric_fn(scores_filtered)
            print()
            print(f'{metric_name}(medfilt_length={medfilt_len})')
            print(metric_values)
            # Averages are reported by the metric but are not tunable per
            # class, hence they get no hyper-parameter entry.
            hyper_params_and_other_values = {
                event_class: {
                    'medfilt_length': medfilt_len,
                    **other_values.get(event_class, {}),
                }
                for event_class in metric_values
                if not event_class.endswith('_average')
            }
            leaderboard = update_leaderboard(leaderboard,
                                             metric_name,
                                             metric_values,
                                             hyper_params_and_other_values,
                                             scores_filtered,
                                             minimize=minimize)
    for metric_name in leaderboard:
        metric_values = leaderboard[metric_name][0]
        hyper_params_and_other_values = leaderboard[metric_name][1]
        # Annotate the hyper-parameters with the value they achieved. This is
        # done unconditionally so the returned leaderboard does not depend on
        # whether results are also written to disk.
        for event_class in hyper_params_and_other_values:
            hyper_params_and_other_values[event_class][
                metric_name] = metric_values[event_class]
        if storage_dir is not None:
            dump_json(
                hyper_params_and_other_values,
                Path(storage_dir) / f'tagging_hyper_params_{metric_name}.json')
    print()
    print('best:')
    for metric_name in metrics:
        # A metric is missing if it was never evaluated (e.g. no candidates).
        if metric_name not in leaderboard:
            continue
        print()
        print(metric_name, leaderboard[metric_name][0])
    return leaderboard
Пример #2
0
def create_json(database_path: Path, json_path: Path, indent=4):
    """Build the database description and write it to a JSON file.

    Args:
        database_path: Database root handed to ``construct_json``.
        json_path: Target file for ``dump_json``; ``create_path=True`` is
            passed (presumably creating missing parent directories).
        indent: JSON indentation forwarded to ``dump_json``.
    """
    dump_kwargs = dict(create_path=True, indent=indent, ensure_ascii=False)
    dump_json(construct_json(database_path), json_path, **dump_kwargs)
Пример #3
0
def create_jsons(database_path: Path,
                 json_path: Path,
                 indent=4,
                 pseudo_labels_dir_without_external=None,
                 pseudo_labels_dir_with_external=None):
    """Write the DESED database JSON and two pseudo-labeled variants of it.

    Three files are written into the directory ``json_path``:

    * ``desed.json``: the database as returned by ``construct_json``;
    * ``desed_pseudo_labeled_without_external.json`` and
      ``desed_pseudo_labeled_with_external.json``: copies of that database
      in which the ``train_weak`` and ``train_unlabel_in_domain`` datasets
      receive strong labels read from ``train_weak_pseudo_labeled.tsv`` and
      ``train_unlabel_in_domain_pseudo_labeled.tsv`` of the respective
      pseudo-label directory.

    Args:
        database_path: Root directory of the database (``str`` or ``Path``).
        json_path: Output directory (``str`` or ``Path``).
        indent: JSON indentation forwarded to ``dump_json``.
        pseudo_labels_dir_without_external: Directory with the pseudo labels
            of the "without external" variant. Defaults to
            ``pb_sed_root/exp/strong_label_crnn_inference/2022-05-04-09-05-53``.
        pseudo_labels_dir_with_external: Directory with the pseudo labels of
            the "with external" variant. Defaults to
            ``pb_sed_root/exp/strong_label_crnn_inference/2022-06-24-10-06-21``.

    Raises:
        AssertionError: if ``database_path`` is not a directory.
    """
    # Accept plain strings as well; the code below relies on Path methods.
    database_path = Path(database_path)
    json_path = Path(json_path)
    assert database_path.is_dir(), (
        f'Path "{str(database_path.absolute())}" is not a directory.')

    def dump_database(database_to_dump, filename):
        # Write ``database_to_dump`` to ``json_path / filename`` and report it.
        filepath = json_path / filename
        dump_json(
            database_to_dump,
            filepath,
            create_path=True,
            indent=indent,
            ensure_ascii=False,
        )
        print(f'Dumped json {filepath}')

    def add_pseudo_labels(pseudo_labels_dir):
        # Return a copy of ``database`` whose 'train_weak' and
        # 'train_unlabel_in_domain' datasets carry the strong pseudo labels
        # stored in ``pseudo_labels_dir``. The original stays untouched.
        pseudo_labels_dir = Path(pseudo_labels_dir)
        database_pseudo_labeled = deepcopy(database)
        for dataset_name in ('train_weak', 'train_unlabel_in_domain'):
            add_strong_labels(
                database_pseudo_labeled['datasets'][dataset_name],
                read_ground_truth_file(
                    pseudo_labels_dir / f'{dataset_name}_pseudo_labeled.tsv'))
        return database_pseudo_labeled

    database = construct_json(database_path)
    dump_database(database, 'desed.json')

    if pseudo_labels_dir_without_external is None:
        pseudo_labels_dir_without_external = (
            pb_sed_root / 'exp' / 'strong_label_crnn_inference' /
            '2022-05-04-09-05-53')
    dump_database(add_pseudo_labels(pseudo_labels_dir_without_external),
                  'desed_pseudo_labeled_without_external.json')

    if pseudo_labels_dir_with_external is None:
        pseudo_labels_dir_with_external = (
            pb_sed_root / 'exp' / 'strong_label_crnn_inference' /
            '2022-06-24-10-06-21')
    dump_database(add_pseudo_labels(pseudo_labels_dir_with_external),
                  'desed_pseudo_labeled_with_external.json')
Пример #4
0
def main(_run, storage_dir, debug, weak_label_crnn_hyper_params_dir,
         weak_label_crnn_dirs, weak_label_crnn_checkpoints,
         strong_label_crnn_dirs, strong_label_crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, medfilt_lengths, device):
    """Tune the sound event detection hyper-parameters of strong-label CRNNs.

    The weak-label CRNN ensemble tags the validation set, using the tagging
    hyper-parameters in ``weak_label_crnn_hyper_params_dir``. The tags are
    attached to every example as ``tag_condition``. The strong-label CRNN
    ensemble is then tuned via ``strong_label.crnn.tune_sound_event_detection``
    over ``medfilt_lengths`` for three metrics:

    * 'f': collar-based F-score;
    * 'auc1': ``base.psd_auc`` with the scenario-1 parameters;
    * 'auc2': ``base.psd_auc`` with the scenario-2 parameters.

    For 'auc1'/'auc2' the per-class thresholds returned by
    ``collar_based.best_fscore`` are added. The results are written to
    ``storage_dir`` as ``sed_hyper_params_f.json``,
    ``sed_hyper_params_psds1.json`` and ``sed_hyper_params_psds2.json``.

    ``storage_dir`` is then symlinked into ``<crnn_dir>/hyper_params`` of
    every strong-label CRNN. If ``eval_set_name`` is set, ``evaluation`` is
    run afterwards with these hyper-parameters.

    If no ground truth path is given for 'validation' or 'eval_public', it
    is derived from the database layout.
    """
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    storage_dir = Path(storage_dir)

    # A single checkpoint name is shared by all weak-label CRNNs.
    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints,
                          str), weak_label_crnn_checkpoints
        weak_label_crnn_checkpoints = len(weak_label_crnn_dirs) * [
            weak_label_crnn_checkpoints
        ]
    weak_label_crnns = [
        weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                         config_name='1/config.json',
                                         checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(weak_label_crnn_dirs,
                                             weak_label_crnn_checkpoints)
    ]
    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    # STFT frame shift divided by the sample rate (seconds if the shift is
    # given in samples).
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    # Derive the ground truth file from the database layout: the database
    # root is four directory levels above the audio files.
    if validation_set_name == 'validation' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('validation')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'validation' / 'validation.tsv'
    elif validation_set_name == 'eval_public' and not validation_ground_truth_filepath:
        database_root = Path(
            data_provider.get_raw('eval_public')[0]
            ['audio_path']).parent.parent.parent.parent
        validation_ground_truth_filepath = database_root / 'metadata' / 'eval' / 'public.tsv'
    assert isinstance(
        validation_ground_truth_filepath,
        (str, Path)) and Path(validation_ground_truth_filepath).exists(
        ), validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        example['example_id']: example['audio_length']
        for example in data_provider.db.get_dataset(validation_set_name)
    }

    # Tagging uses a single segment [0, duration] per audio clip.
    timestamps = {
        audio_id: np.array([0., audio_durations[audio_id]])
        for audio_id in audio_durations
    }
    tags, tagging_scores, _ = tagging(
        weak_label_crnns,
        dataset,
        device,
        timestamps,
        event_classes,
        weak_label_crnn_hyper_params_dir,
        None,
        None,
    )

    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }
    metrics = {
        'f':
        partial(
            base.f_collar,
            ground_truth=validation_ground_truth_filepath,
            return_onset_offset_bias=True,
            num_jobs=8,
            **collar_based_params,
        ),
        'auc1':
        partial(
            base.psd_auc,
            ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations,
            num_jobs=8,
            **psds_scenario_1,
        ),
        'auc2':
        partial(
            base.psd_auc,
            ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations,
            num_jobs=8,
            **psds_scenario_2,
        )
    }

    # A single checkpoint name is shared by all strong-label CRNNs.
    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints,
                          str), strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = len(strong_label_crnn_dirs) * [
            strong_label_crnn_checkpoints
        ]
    strong_label_crnns = [
        strong_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                           config_name='1/config.json',
                                           checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(strong_label_crnn_dirs,
                                             strong_label_crnn_checkpoints)
    ]

    def add_tag_condition(example):
        # Attach the weak-label tags of every example in the batch.
        example["tag_condition"] = np.array(
            [tags[example_id] for example_id in example["example_id"]])
        return example

    timestamps = np.arange(0, 10000) * frame_shift
    leaderboard = strong_label.crnn.tune_sound_event_detection(
        strong_label_crnns,
        dataset.map(add_tag_condition),
        device,
        timestamps,
        event_classes,
        tags,
        metrics,
        tag_masking={
            'f': True,
            'auc1': '?',
            'auc2': '?'
        },
        medfilt_lengths=medfilt_lengths,
    )
    dump_json(leaderboard['f'][1], storage_dir / 'sed_hyper_params_f.json')
    # For the PSD AUC metrics additionally store the per-class thresholds
    # that maximize the collar-based F-score on the leaderboard's scores.
    for metric_name, hyper_params_name in [('auc1', 'psds1'),
                                           ('auc2', 'psds2')]:
        _, _, _, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard[metric_name][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params,
            num_jobs=8)
        for event_class in thresholds:
            leaderboard[metric_name][1][event_class]['threshold'] = thresholds[
                event_class]
        dump_json(leaderboard[metric_name][1],
                  storage_dir / f'sed_hyper_params_{hyper_params_name}.json')
    # Link this tuning run into every strong-label CRNN directory. The target
    # is made absolute because a relative symlink target is resolved relative
    # to the link's own directory and would dangle.
    for crnn_dir in strong_label_crnn_dirs:
        tuning_dir = Path(crnn_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir.absolute())
    print(storage_dir)

    if eval_set_name:
        evaluation.run(config_updates={
            'debug': debug,
            'strong_label_crnn_hyper_params_dir': str(storage_dir),
            'dataset_name': eval_set_name,
            'ground_truth_filepath': eval_ground_truth_filepath,
        }, )
Пример #5
0
def main(
    _run,
    storage_dir,
    hyper_params_dir,
    sed_hyper_params_name,
    crnn_dirs,
    crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    weak_pseudo_labeling,
    boundary_pseudo_labeling,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labeled_dataset_name,
):
    """Run CRNN-ensemble inference and optional pseudo labeling.

    For every dataset in ``dataset_name`` (a name or a list of names):

    * tagging is always run; non-empty results are dumped to
      ``tagging_results_<dataset>.json``;
    * boundary detection is run if a ground truth is available or boundary
      pseudo labeling is requested;
    * sound event detection is run with every ``sed_hyper_params_name`` if a
      ground truth is available, strong pseudo labeling is requested, or
      scores/detections are to be saved;
    * ``base.pseudo_label`` stores the result under
      ``pseudo_labeled_dataset_name[i]`` in a copy of the database.

    Argument forms:

    * ``ground_truth_filepath``: None (no ground truth for any dataset), a
      single path, or a sequence with one entry per dataset. For
      'validation'/'eval_public', missing paths are derived from the
      database layout.
    * Pseudo-labeling flags: a scalar applied to all datasets, or lists.
    * ``pseudo_labeled_dataset_name``: a single name (only valid for a
      single dataset) or a list.

    If any pseudo labeling is enabled, the database is dumped into
    ``storage_dir``. Finally ``storage_dir`` is symlinked into
    ``<hyper_params_dir>/inference``. An ``EmissionsTracker`` runs for the
    duration of the inference.
    """
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    # Stop the tracker even if inference raises, so it does not keep running.
    try:
        storage_dir = Path(storage_dir)

        boundary_collar_based_params = {
            'onset_collar': .5,
            'offset_collar': .5,
            'offset_collar_rate': .0,
        }
        collar_based_params = {
            'onset_collar': .2,
            'offset_collar': .2,
            'offset_collar_rate': .2,
        }
        psds_scenario_1 = {
            'dtc_threshold': 0.7,
            'gtc_threshold': 0.7,
            'cttc_threshold': None,
            'alpha_ct': .0,
            'alpha_st': 1.,
        }
        psds_scenario_2 = {
            'dtc_threshold': 0.1,
            'gtc_threshold': 0.1,
            'cttc_threshold': 0.3,
            'alpha_ct': .5,
            'alpha_st': 1.,
        }

        # A single checkpoint name is shared by all CRNNs.
        if not isinstance(crnn_checkpoints, list):
            assert isinstance(crnn_checkpoints, str), crnn_checkpoints
            crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
        crnns = [
            CRNN.from_storage_dir(storage_dir=crnn_dir,
                                  config_name='1/config.json',
                                  checkpoint_name=crnn_checkpoint)
            for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
        ]
        print('Params',
              sum([p.numel() for crnn in crnns for p in crnn.parameters()]))
        print(
            'CNN2d Params',
            sum([
                p.numel() for crnn in crnns
                for p in crnn.cnn.cnn_2d.parameters()
            ]))
        data_provider = DataProvider.from_config(data_provider)
        data_provider.test_transform.label_encoder.initialize_labels()
        event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
        event_classes = [event_classes[i] for i in range(len(event_classes))]
        # STFT frame shift divided by the sample rate (seconds if the shift
        # is given in samples).
        frame_shift = data_provider.test_transform.stft.shift
        frame_shift /= data_provider.audio_reader.target_sample_rate

        # Bring all per-dataset options into list form (one entry each).
        if not isinstance(dataset_name, list):
            dataset_name = [dataset_name]
        if ground_truth_filepath is None:
            ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
        elif isinstance(ground_truth_filepath, (str, Path)):
            ground_truth_filepath = [ground_truth_filepath]
        else:
            # Copy: default paths are filled in below, which must neither
            # modify the caller's sequence nor fail for immutable ones.
            ground_truth_filepath = list(ground_truth_filepath)
        assert len(ground_truth_filepath) == len(dataset_name)
        if not isinstance(weak_pseudo_labeling, list):
            weak_pseudo_labeling = len(dataset_name) * [weak_pseudo_labeling]
        assert len(weak_pseudo_labeling) == len(dataset_name)
        if not isinstance(boundary_pseudo_labeling, list):
            boundary_pseudo_labeling = len(dataset_name) * [
                boundary_pseudo_labeling
            ]
        assert len(boundary_pseudo_labeling) == len(dataset_name)
        if not isinstance(strong_pseudo_labeling, list):
            strong_pseudo_labeling = len(dataset_name) * [
                strong_pseudo_labeling
            ]
        assert len(strong_pseudo_labeling) == len(dataset_name)
        if not isinstance(pseudo_labeled_dataset_name, list):
            pseudo_labeled_dataset_name = [pseudo_labeled_dataset_name]
        assert len(pseudo_labeled_dataset_name) == len(dataset_name)
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]

        database = deepcopy(data_provider.db.data)
        for i in range(len(dataset_name)):
            print()
            print(dataset_name[i])

            # Derive missing ground truth files from the database layout:
            # the database root is four levels above the audio files.
            if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
                database_root = Path(
                    data_provider.get_raw('eval_public')[0]
                    ['audio_path']).parent.parent.parent.parent
                ground_truth_filepath[
                    i] = database_root / 'metadata' / 'eval' / 'public.tsv'
            elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
                database_root = Path(
                    data_provider.get_raw('validation')[0]
                    ['audio_path']).parent.parent.parent.parent
                ground_truth_filepath[
                    i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

            dataset = data_provider.get_dataset(dataset_name[i])
            audio_durations = {
                example['example_id']: example['audio_length']
                for example in data_provider.db.get_dataset(dataset_name[i])
            }

            score_storage_dir = storage_dir / 'scores' / dataset_name[i]
            detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

            # Tagging segments: the whole clip, or overlapping segments of
            # at most ``max_segment_length`` frames.
            if max_segment_length is None:
                timestamps = {
                    audio_id: np.array([0., audio_durations[audio_id]])
                    for audio_id in audio_durations
                }
            else:
                timestamps = {}
                for audio_id in audio_durations:
                    ts = np.arange(0, audio_durations[audio_id],
                                   (max_segment_length - segment_overlap) *
                                   frame_shift)
                    timestamps[audio_id] = np.concatenate(
                        (ts, [audio_durations[audio_id]]))
            tags, tagging_scores, tagging_results = tagging(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                hyper_params_dir,
                ground_truth_filepath[i],
                audio_durations,
                [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
            )
            if tagging_results:
                dump_json(
                    tagging_results,
                    storage_dir / f'tagging_results_{dataset_name[i]}.json')

            timestamps = np.round(np.arange(0, 100000) * frame_shift,
                                  decimals=6)
            if ground_truth_filepath[i] is not None or boundary_pseudo_labeling[i]:
                boundaries, boundaries_detection_results = boundaries_detection(
                    crnns,
                    dataset,
                    device,
                    timestamps,
                    event_classes,
                    tags,
                    hyper_params_dir,
                    ground_truth_filepath[i],
                    boundary_collar_based_params,
                    max_segment_length=max_segment_length,
                    segment_overlap=segment_overlap,
                    pseudo_widening=pseudo_widening,
                )
                if boundaries_detection_results:
                    dump_json(
                        boundaries_detection_results, storage_dir /
                        f'boundaries_detection_results_{dataset_name[i]}.json')
            else:
                boundaries = {}
            if (ground_truth_filepath[i] is not None
                ) or strong_pseudo_labeling[i] or save_scores or save_detections:
                events, sed_results = sound_event_detection(
                    crnns,
                    dataset,
                    device,
                    timestamps,
                    event_classes,
                    tags,
                    hyper_params_dir,
                    sed_hyper_params_name,
                    ground_truth_filepath[i],
                    audio_durations,
                    collar_based_params,
                    [psds_scenario_1, psds_scenario_2],
                    max_segment_length=max_segment_length,
                    segment_overlap=segment_overlap,
                    pseudo_widening=pseudo_widening,
                    score_storage_dir=[
                        score_storage_dir / name
                        for name in sed_hyper_params_name
                    ] if save_scores else None,
                    detection_storage_dir=[
                        detection_storage_dir / name
                        for name in sed_hyper_params_name
                    ] if save_detections else None,
                )
                for j, sed_results_j in enumerate(sed_results):
                    if sed_results_j:
                        dump_json(
                            sed_results_j, storage_dir /
                            f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                        )
            else:
                events = [{}]
            # Pseudo labels are derived from the events of the first SED
            # hyper-parameter set.
            database['datasets'][
                pseudo_labeled_dataset_name[i]] = base.pseudo_label(
                    database['datasets'][dataset_name[i]],
                    event_classes,
                    weak_pseudo_labeling[i],
                    boundary_pseudo_labeling[i],
                    strong_pseudo_labeling[i],
                    tags,
                    boundaries,
                    events[0],
                )

        if any(weak_pseudo_labeling) or any(boundary_pseudo_labeling) or any(
                strong_pseudo_labeling):
            dump_json(
                database,
                storage_dir / Path(data_provider.json_path).name,
                create_path=True,
                indent=4,
                ensure_ascii=False,
            )
        # The symlink target is made absolute because a relative target is
        # resolved relative to the link's own directory and would dangle.
        inference_dir = Path(hyper_params_dir) / 'inference'
        os.makedirs(str(inference_dir), exist_ok=True)
        (inference_dir / storage_dir.name).symlink_to(storage_dir.absolute())
    finally:
        emissions_tracker.stop()
    print(storage_dir)
Пример #6
0
def tune_sound_event_detection(
    detection_scores,
    medfilt_length_candidates,
    tags,
    metrics,
    minimize=False,
    tag_masking=None,
    storage_dir=None,
):
    """Tune median-filter length and tag masking of detection scores.

    For every candidate length the scores are median filtered along axis 0
    (lengths ``<= 1`` leave them unfiltered). They are then evaluated with
    every metric, either as they are or after multiplying each clip's scores
    by its ``tags`` entry ("tag masking"). Results are collected per metric
    with ``update_leaderboard``.

    Args:
        detection_scores: dict mapping audio ids to score dataframes; they
            are checked with ``validate_score_dataframe``.
        medfilt_length_candidates: iterable of median-filter lengths to try.
        tags: dict mapping audio ids to values the per-class scores are
            multiplied with (presumably binary tag vectors — confirm with
            callers).
        metrics: dict mapping metric names to callables that take scores and
            return ``(metric_values, other_values)``.
        minimize: forwarded to ``update_leaderboard``.
        tag_masking: True, False, '?' (try both), or a dict with one of these
            values per metric name. ``None`` (the default) means no masking.
        storage_dir: if not None, the per-class hyper-parameters of each
            metric are dumped to
            ``<storage_dir>/sed_hyper_params_<metric_name>.json``.

    Returns:
        The leaderboard returned by ``update_leaderboard``, keyed by metric
        name. The hyper-parameters in entry ``[1]`` are annotated with their
        metric value under the key ``<metric_name>``.
    """
    if tag_masking is None:
        tag_masking = False
    if tag_masking in [True, False, '?']:
        tag_masking = {key: tag_masking for key in metrics.keys()}
    assert isinstance(tag_masking, dict), tag_masking
    assert tag_masking.keys() == metrics.keys(), (tag_masking.keys(),
                                                  metrics.keys())
    assert all([val in [True, False, '?'] for val in tag_masking.values()])
    # Masked copies are only needed if some metric may use them ('?' is
    # truthy as well); otherwise skip the per-iteration deepcopy.
    masking_required = any(tag_masking.values())
    leaderboard = {}
    audio_ids = sorted(detection_scores.keys())
    event_classes = None
    for medfilt_len in medfilt_length_candidates:
        if medfilt_len > 1:
            scores_filtered = deepcopy(detection_scores)
            for audio_id in audio_ids:
                _, event_classes = validate_score_dataframe(
                    detection_scores[audio_id], event_classes=event_classes)
                scores = detection_scores[audio_id][event_classes].to_numpy()
                scores_filtered[audio_id][event_classes] = medfilt(
                    scores, medfilt_len, axis=0)
        else:
            scores_filtered = detection_scores
        if masking_required:
            scores_masked = deepcopy(scores_filtered)
            for audio_id in audio_ids:
                _, event_classes = validate_score_dataframe(
                    scores_masked[audio_id], event_classes=event_classes)
                scores_masked[audio_id][event_classes] *= tags[audio_id]
        else:
            scores_masked = None  # never selected: all tag_masked are falsy
        for metric_name, metric_fn in metrics.items():
            if tag_masking[metric_name] == '?':
                tag_masking_candidates = [False, True]
            else:
                tag_masking_candidates = [tag_masking[metric_name]]
            for tag_masked in tag_masking_candidates:
                scores = scores_masked if tag_masked else scores_filtered
                metric_values, other_values = metric_fn(scores)
                print()
                print(
                    f'{metric_name}(medfilt_length={medfilt_len},tag_masked={tag_masked}):'
                )
                print(metric_values)
                # '*_average' entries are not per-class, so no hyper-params.
                hyper_params_and_other_values = {
                    event_class: {
                        'medfilt_length': medfilt_len,
                        'tag_masked': tag_masked,
                        **other_values.get(event_class, {}),
                    }
                    for event_class in metric_values
                    if not event_class.endswith('_average')
                }
                leaderboard = update_leaderboard(leaderboard,
                                                 metric_name,
                                                 metric_values,
                                                 hyper_params_and_other_values,
                                                 scores,
                                                 minimize=minimize)
    for metric_name in leaderboard:
        metric_values = leaderboard[metric_name][0]
        hyper_params_and_other_values = leaderboard[metric_name][1]
        for event_class in hyper_params_and_other_values:
            hyper_params_and_other_values[event_class][
                metric_name] = metric_values[event_class]
        if storage_dir is not None:
            dump_json(
                hyper_params_and_other_values,
                Path(storage_dir) / f'sed_hyper_params_{metric_name}.json')
    print()
    print('best:')
    for metric_name in metrics:
        # A metric is missing if it was never evaluated (e.g. no candidates).
        if metric_name not in leaderboard:
            continue
        print()
        print(metric_name, ':')
        print(leaderboard[metric_name][0])
    return leaderboard
Пример #7
0
def main(_run, storage_dir, debug, crnn_dirs, crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, boundaries_filter_lengths,
         tune_detection_scenario_1, detection_window_lengths_scenario_1,
         detection_window_shift_scenario_1,
         detection_medfilt_lengths_scenario_1, tune_detection_scenario_2,
         detection_window_lengths_scenario_2,
         detection_window_shift_scenario_2,
         detection_medfilt_lengths_scenario_2, device):
    """Tune tagging, boundary detection and SED hyper-parameters.

    The tuning runs on the validation set for a weak-label CRNN ensemble:

    1. Tagging is tuned for ``base.f_tag``. The tuned thresholds turn the
       tagging scores into tags.
    2. Boundary detection is tuned for the collar-based F-score on
       boundaries derived from the ground truth (with tag masking).
    3. Optionally, sound event detection is tuned:
       * PSDS scenario 1: metrics 'f' and 'auc'; written to
         ``sed_hyper_params_f.json`` and ``sed_hyper_params_psds1.json``.
       * PSDS scenario 2: metric 'auc'; written to
         ``sed_hyper_params_psds2.json``.
       For the 'auc' metrics, the per-class thresholds returned by
       ``collar_based.best_fscore`` are added.

    ``storage_dir`` is symlinked into ``<crnn_dir>/hyper_params`` of every
    CRNN. An ``EmissionsTracker`` runs during tuning. If ``eval_set_name`` is
    set, ``evaluation`` is run afterwards for each tuned scenario.
    """
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    # Stop the tracker even if tuning raises, so it does not keep running.
    try:
        storage_dir = Path(storage_dir)

        boundaries_collar_based_params = {
            'onset_collar': .5,
            'offset_collar': .5,
            'offset_collar_rate': .0,
            'min_precision': .8,
        }
        collar_based_params = {
            'onset_collar': .2,
            'offset_collar': .2,
            'offset_collar_rate': .2,
        }
        psds_scenario_1 = {
            'dtc_threshold': 0.7,
            'gtc_threshold': 0.7,
            'cttc_threshold': None,
            'alpha_ct': .0,
            'alpha_st': 1.,
        }
        psds_scenario_2 = {
            'dtc_threshold': 0.1,
            'gtc_threshold': 0.1,
            'cttc_threshold': 0.3,
            'alpha_ct': .5,
            'alpha_st': 1.,
        }

        # A single checkpoint name is shared by all CRNNs.
        if not isinstance(crnn_checkpoints, list):
            assert isinstance(crnn_checkpoints, str), crnn_checkpoints
            crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
        crnns = [
            weak_label.CRNN.from_storage_dir(storage_dir=crnn_dir,
                                             config_name='1/config.json',
                                             checkpoint_name=crnn_checkpoint)
            for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
        ]
        data_provider = DataProvider.from_config(data_provider)
        data_provider.test_transform.label_encoder.initialize_labels()
        event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
        event_classes = [event_classes[i] for i in range(len(event_classes))]
        # STFT frame shift divided by the sample rate (seconds if the shift
        # is given in samples).
        frame_shift = data_provider.test_transform.stft.shift
        frame_shift /= data_provider.audio_reader.target_sample_rate

        # Derive the ground truth file from the database layout: the
        # database root is four directory levels above the audio files.
        if validation_set_name == 'validation' and not validation_ground_truth_filepath:
            database_root = Path(
                data_provider.get_raw('validation')[0]
                ['audio_path']).parent.parent.parent.parent
            validation_ground_truth_filepath = database_root / 'metadata' / 'validation' / 'validation.tsv'
        elif validation_set_name == 'eval_public' and not validation_ground_truth_filepath:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]
                ['audio_path']).parent.parent.parent.parent
            validation_ground_truth_filepath = database_root / 'metadata' / 'eval' / 'public.tsv'
        assert isinstance(
            validation_ground_truth_filepath,
            (str, Path)) and Path(validation_ground_truth_filepath).exists(
            ), validation_ground_truth_filepath

        dataset = data_provider.get_dataset(validation_set_name)
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(validation_set_name)
        }

        def add_thresholds_and_dump(hyper_params, scores, filename):
            # Add the per-class thresholds returned by best_fscore for
            # ``scores`` to ``hyper_params`` and dump them to storage_dir.
            _, _, _, thresholds, _ = collar_based.best_fscore(
                scores=scores,
                ground_truth=validation_ground_truth_filepath,
                **collar_based_params,
                num_jobs=8)
            for event_class in thresholds:
                hyper_params[event_class]['threshold'] = thresholds[
                    event_class]
            dump_json(hyper_params, storage_dir / filename)

        # Tagging uses a single segment [0, duration] per audio clip.
        timestamps = {
            audio_id: np.array([0., audio_durations[audio_id]])
            for audio_id in audio_durations
        }
        metrics = {
            'f':
            partial(base.f_tag,
                    ground_truth=validation_ground_truth_filepath,
                    num_jobs=8)
        }
        leaderboard = weak_label.crnn.tune_tagging(crnns,
                                                   dataset,
                                                   device,
                                                   timestamps,
                                                   event_classes,
                                                   metrics,
                                                   storage_dir=storage_dir)
        # Binarize the tagging scores with the tuned per-class thresholds.
        _, hyper_params, tagging_scores = leaderboard['f']
        tagging_thresholds = np.array([
            hyper_params[event_class]['threshold']
            for event_class in event_classes
        ])
        tags = {
            audio_id:
            tagging_scores[audio_id][event_classes].to_numpy() > tagging_thresholds
            for audio_id in tagging_scores
        }

        boundaries_ground_truth = base.boundaries_from_events(
            validation_ground_truth_filepath)
        timestamps = np.arange(0, 10000) * frame_shift
        metrics = {
            'f':
            partial(
                base.f_collar,
                ground_truth=boundaries_ground_truth,
                return_onset_offset_bias=True,
                num_jobs=8,
                **boundaries_collar_based_params,
            ),
        }
        weak_label.crnn.tune_boundary_detection(
            crnns,
            dataset,
            device,
            timestamps,
            event_classes,
            tags,
            metrics,
            tag_masking=True,
            stepfilt_lengths=boundaries_filter_lengths,
            storage_dir=storage_dir)

        if tune_detection_scenario_1:
            metrics = {
                'f':
                partial(
                    base.f_collar,
                    ground_truth=validation_ground_truth_filepath,
                    return_onset_offset_bias=True,
                    num_jobs=8,
                    **collar_based_params,
                ),
                'auc':
                partial(
                    base.psd_auc,
                    ground_truth=validation_ground_truth_filepath,
                    audio_durations=audio_durations,
                    num_jobs=8,
                    **psds_scenario_1,
                ),
            }
            leaderboard = weak_label.crnn.tune_sound_event_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                metrics,
                tag_masking={
                    'f': True,
                    'auc': '?'
                },
                window_lengths=detection_window_lengths_scenario_1,
                window_shift=detection_window_shift_scenario_1,
                medfilt_lengths=detection_medfilt_lengths_scenario_1,
            )
            dump_json(leaderboard['f'][1],
                      storage_dir / 'sed_hyper_params_f.json')
            add_thresholds_and_dump(leaderboard['auc'][1],
                                    leaderboard['auc'][2],
                                    'sed_hyper_params_psds1.json')
        if tune_detection_scenario_2:
            metrics = {
                'auc':
                partial(
                    base.psd_auc,
                    ground_truth=validation_ground_truth_filepath,
                    audio_durations=audio_durations,
                    num_jobs=8,
                    **psds_scenario_2,
                )
            }
            leaderboard = weak_label.crnn.tune_sound_event_detection(
                crnns,
                dataset,
                device,
                timestamps,
                event_classes,
                tags,
                metrics,
                tag_masking=False,
                window_lengths=detection_window_lengths_scenario_2,
                window_shift=detection_window_shift_scenario_2,
                medfilt_lengths=detection_medfilt_lengths_scenario_2,
            )
            add_thresholds_and_dump(leaderboard['auc'][1],
                                    leaderboard['auc'][2],
                                    'sed_hyper_params_psds2.json')
        # Link this tuning run into every CRNN directory. The target is made
        # absolute because a relative symlink target is resolved relative to
        # the link's own directory and would dangle.
        for crnn_dir in crnn_dirs:
            tuning_dir = Path(crnn_dir) / 'hyper_params'
            os.makedirs(str(tuning_dir), exist_ok=True)
            (tuning_dir / storage_dir.name).symlink_to(storage_dir.absolute())
    finally:
        emissions_tracker.stop()
    print(storage_dir)

    if eval_set_name:
        if tune_detection_scenario_1:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
            }, )
        if tune_detection_scenario_2:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
                'sed_hyper_params_name': 'psds2',
            }, )
Пример #8
0
def _load_crnns(crnn_cls, crnn_dirs, crnn_checkpoints, label):
    """Load one CRNN per storage directory and print its parameter counts.

    Args:
        crnn_cls: Model class exposing ``from_storage_dir`` (e.g. the weak- or
            strong-label CRNN class).
        crnn_dirs: Training storage directories, one per model.
        crnn_checkpoints: Either a single checkpoint name (``str``) used for
            every directory, or a list with one checkpoint name per directory.
        label: Prefix used for the printed parameter counts.

    Returns:
        List of loaded models, in the order of ``crnn_dirs``.
    """
    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        crnn_cls.from_storage_dir(storage_dir=crnn_dir,
                                  config_name='1/config.json',
                                  checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    print(f'{label} CRNN Params',
          sum(p.numel() for crnn in crnns for p in crnn.parameters()))
    print(
        f'{label} CNN2d Params',
        sum(p.numel() for crnn in crnns
            for p in crnn.cnn.cnn_2d.parameters()))
    return crnns


def _default_ground_truth_filepath(data_provider, dataset_name):
    """Return the metadata TSV path for ``eval_public``/``validation``.

    The database root is derived from the first raw example's ``audio_path``
    (four directory levels up). Returns ``None`` for any other dataset name,
    without touching ``data_provider``.
    """
    relative_paths = {
        'eval_public': ('eval', 'public.tsv'),
        'validation': ('validation', 'validation.tsv'),
    }
    if dataset_name not in relative_paths:
        return None
    audio_path = Path(data_provider.get_raw(dataset_name)[0]['audio_path'])
    database_root = audio_path.parent.parent.parent.parent
    return database_root.joinpath('metadata', *relative_paths[dataset_name])


def _segment_timestamps(audio_durations, max_segment_length, segment_overlap,
                        frame_shift):
    """Return per-audio segment boundary times (in the units of the durations).

    Without ``max_segment_length`` each audio is one segment ``[0, duration]``.
    Otherwise boundaries are spaced ``(max_segment_length - segment_overlap)``
    frames apart, shifted back by half the overlap, and framed by 0 and the
    audio duration.
    """
    if max_segment_length is None:
        return {
            audio_id: np.array([0., duration])
            for audio_id, duration in audio_durations.items()
        }
    timestamps = {}
    for audio_id, duration in audio_durations.items():
        ts = np.arange((2 + max_segment_length) * frame_shift, duration,
                       (max_segment_length - segment_overlap) * frame_shift)
        timestamps[audio_id] = np.concatenate(
            ([0.], ts - segment_overlap / 2 * frame_shift, [duration]))
    return timestamps


def _add_tag_condition(example, tags):
    """Store ``tags[example_id]`` for every id of ``example['example_id']``
    as ``example['tag_condition']`` and return the example."""
    example["tag_condition"] = np.array(
        [tags[example_id] for example_id in example["example_id"]])
    return example


def _write_pseudo_labels_tsv(filepath, events):
    """Write ``events`` (``{key: [(onset, offset, label), ...]}``) as a TSV.

    A key without any event is still written as one row with empty
    onset/offset/label columns, so that the file lists every audio clip.
    """
    with Path(filepath).open('w') as fid:
        fid.write('filename\tonset\toffset\tevent_label\n')
        for key, event_list in events.items():
            if len(event_list) == 0:
                fid.write(f'{key}.wav\t\t\t\n')
            for t_on, t_off, event_label in event_list:
                fid.write(f'{key}.wav\t{t_on}\t{t_off}\t{event_label}\n')


def main(
    _run,
    storage_dir,
    strong_label_crnn_hyper_params_dir,
    sed_hyper_params_name,
    strong_label_crnn_dirs,
    strong_label_crnn_checkpoints,
    weak_label_crnn_hyper_params_dir,
    weak_label_crnn_dirs,
    weak_label_crnn_checkpoints,
    device,
    data_provider,
    dataset_name,
    ground_truth_filepath,
    save_scores,
    save_detections,
    max_segment_length,
    segment_overlap,
    strong_pseudo_labeling,
    pseudo_widening,
    pseudo_labelled_dataset_name,
):
    """Run tagging + sound event detection inference on one or more datasets.

    For each entry of ``dataset_name``:

    1. Weak-label CRNNs produce tags. If ``max_segment_length`` is set, the
       input is segmented first.
    2. The tags condition the strong-label CRNNs.
    3. ``sound_event_detection`` runs with every hyper-parameter set in
       ``sed_hyper_params_name``.
    4. Non-empty results are dumped to
       ``storage_dir/sed_<params>_results_<dataset>.json``.
    5. If ``strong_pseudo_labeling`` is set for the dataset, the detected
       events are stored in a copy of the database under
       ``pseudo_labelled_dataset_name`` and written to
       ``storage_dir/<dataset>_pseudo_labeled.tsv``.

    Per-dataset arguments (``ground_truth_filepath``,
    ``strong_pseudo_labeling``, ``pseudo_labelled_dataset_name``) may be
    given as a single value or as a list with one entry per dataset.
    Finally, ``storage_dir`` is symlinked into
    ``<strong_label_crnn_hyper_params_dir>/inference``. Energy consumption is
    tracked with ``EmissionsTracker`` writing to ``storage_dir``.
    """
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(output_dir=storage_dir,
                                         on_csv_write="update",
                                         log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)
    # Ensure the tracker is stopped (and its output flushed) even if
    # inference fails part-way.
    try:
        collar_based_params = {
            'onset_collar': .2,
            'offset_collar': .2,
            'offset_collar_rate': .2,
        }
        psds_scenario_1 = {
            'dtc_threshold': 0.7,
            'gtc_threshold': 0.7,
            'cttc_threshold': None,
            'alpha_ct': .0,
            'alpha_st': 1.,
        }
        psds_scenario_2 = {
            'dtc_threshold': 0.1,
            'gtc_threshold': 0.1,
            'cttc_threshold': 0.3,
            'alpha_ct': .5,
            'alpha_st': 1.,
        }

        weak_label_crnns = _load_crnns(weak_label.CRNN, weak_label_crnn_dirs,
                                       weak_label_crnn_checkpoints,
                                       'Weak Label')
        strong_label_crnns = _load_crnns(strong_label.CRNN,
                                         strong_label_crnn_dirs,
                                         strong_label_crnn_checkpoints,
                                         'Strong Label')

        data_provider = DESEDProvider.from_config(data_provider)
        label_encoder = data_provider.test_transform.label_encoder
        label_encoder.initialize_labels()
        event_classes = label_encoder.inverse_label_mapping
        event_classes = [event_classes[i] for i in range(len(event_classes))]
        frame_shift = data_provider.test_transform.stft.shift
        frame_shift /= data_provider.audio_reader.target_sample_rate

        # Normalize per-dataset arguments to lists of equal length.
        if not isinstance(dataset_name, list):
            dataset_name = [dataset_name]
        num_datasets = len(dataset_name)
        if ground_truth_filepath is None:
            ground_truth_filepath = num_datasets * [None]
        elif isinstance(ground_truth_filepath, (str, Path)):
            ground_truth_filepath = [ground_truth_filepath]
        assert len(ground_truth_filepath) == num_datasets
        if not isinstance(strong_pseudo_labeling, list):
            strong_pseudo_labeling = num_datasets * [strong_pseudo_labeling]
        assert len(strong_pseudo_labeling) == num_datasets
        # Broadcast a missing name like ground_truth_filepath; previously
        # [None] failed the length check whenever several datasets were given.
        if pseudo_labelled_dataset_name is None:
            pseudo_labelled_dataset_name = num_datasets * [None]
        elif not isinstance(pseudo_labelled_dataset_name, list):
            pseudo_labelled_dataset_name = [pseudo_labelled_dataset_name]
        assert len(pseudo_labelled_dataset_name) == num_datasets
        if not isinstance(sed_hyper_params_name, (list, tuple)):
            sed_hyper_params_name = [sed_hyper_params_name]

        database = deepcopy(data_provider.db.data)
        for name, gt_filepath, pseudo_labeling, pseudo_name in zip(
                dataset_name, ground_truth_filepath, strong_pseudo_labeling,
                pseudo_labelled_dataset_name):
            print()
            print(name)
            # Resolve locally instead of writing into the caller's list.
            if not gt_filepath:
                gt_filepath = _default_ground_truth_filepath(
                    data_provider, name) or gt_filepath

            dataset = data_provider.get_dataset(name)
            audio_durations = {
                example['example_id']: example['audio_length']
                for example in data_provider.db.get_dataset(name)
            }
            score_storage_dir = storage_dir / 'scores' / name
            detection_storage_dir = storage_dir / 'detections' / name

            segment_timestamps = _segment_timestamps(audio_durations,
                                                     max_segment_length,
                                                     segment_overlap,
                                                     frame_shift)
            if max_segment_length is not None:
                dataset = dataset.map(
                    partial(segment_batch,
                            max_length=max_segment_length,
                            overlap=segment_overlap)).unbatch()
            tags, _, _ = tagging(
                weak_label_crnns,
                dataset,
                device,
                segment_timestamps,
                event_classes,
                weak_label_crnn_hyper_params_dir,
                None,
                None,
            )
            dataset = dataset.map(partial(_add_tag_condition, tags=tags))

            # Frame-level time grid handed to the strong-label SED stage.
            sed_timestamps = np.round(np.arange(0, 100000) * frame_shift,
                                      decimals=6)
            events, sed_results = sound_event_detection(
                strong_label_crnns,
                dataset,
                device,
                sed_timestamps,
                event_classes,
                tags,
                strong_label_crnn_hyper_params_dir,
                sed_hyper_params_name,
                gt_filepath,
                audio_durations,
                collar_based_params,
                [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
                score_storage_dir=[
                    score_storage_dir / params_name
                    for params_name in sed_hyper_params_name
                ] if save_scores else None,
                detection_storage_dir=[
                    detection_storage_dir / params_name
                    for params_name in sed_hyper_params_name
                ] if save_detections else None,
            )
            for j, sed_results_j in enumerate(sed_results):
                if sed_results_j:
                    dump_json(
                        sed_results_j, storage_dir /
                        f'sed_{sed_hyper_params_name[j]}_results_{name}.json')
            if pseudo_labeling:
                database['datasets'][pseudo_name] = base.pseudo_label(
                    database['datasets'][name],
                    event_classes,
                    False,
                    False,
                    pseudo_labeling,
                    None,
                    None,
                    events[0],
                )
                _write_pseudo_labels_tsv(
                    storage_dir / f'{name}_pseudo_labeled.tsv', events[0])

        if any(strong_pseudo_labeling):
            dump_json(
                database,
                storage_dir / Path(data_provider.json_path).name,
                create_path=True,
                indent=4,
                ensure_ascii=False,
            )
        inference_dir = Path(strong_label_crnn_hyper_params_dir) / 'inference'
        os.makedirs(str(inference_dir), exist_ok=True)
        # A relative target would be resolved relative to inference_dir and
        # dangle, so always link to the absolute storage directory.
        (inference_dir / storage_dir.name).symlink_to(storage_dir.absolute())
    finally:
        emissions_tracker.stop()
    print(storage_dir)