def tune_tagging(
        tagging_scores, medfilt_length_candidates, metrics,
        minimize=False, storage_dir=None,
):
    """Search median-filter lengths for tagging scores, tracking the best per metric.

    For every candidate length the scores are median filtered along axis 0
    (lengths ``<= 1`` use the unfiltered scores), each metric is evaluated on
    the result and the outcome is handed to ``update_leaderboard``.

    Args:
        tagging_scores: mapping from audio id to a score data frame. Each
            frame is checked with ``validate_score_dataframe`` and indexed
            with the event-class columns that helper returns.
        medfilt_length_candidates: iterable of median-filter lengths to try.
        metrics: mapping from metric name to a callable that takes the score
            mapping and returns ``(metric_values, other_values)``, both keyed
            by event class.
        minimize: forwarded to ``update_leaderboard``.
        storage_dir: if not None, the hyper-parameters of each metric are
            written to ``tagging_hyper_params_<metric>.json`` in this
            directory.

    Returns:
        The leaderboard built by ``update_leaderboard``. For each metric
        name, item ``[0]`` holds the metric values and item ``[1]`` the
        hyper-parameters per event class. The latter are always extended with
        the achieved metric value, independent of ``storage_dir`` (matching
        ``tune_sound_event_detection``).
    """
    leaderboard = {}
    audio_ids = sorted(tagging_scores.keys())
    event_classes = None
    for medfilt_len in medfilt_length_candidates:
        if medfilt_len > 1:
            # Filter a copy so the caller's scores stay untouched.
            scores_filtered = deepcopy(tagging_scores)
            for audio_id in audio_ids:
                _, event_classes = validate_score_dataframe(
                    tagging_scores[audio_id], event_classes=event_classes)
                scores = tagging_scores[audio_id][event_classes].to_numpy()
                scores_filtered[audio_id][event_classes] = medfilt(
                    scores, medfilt_len, axis=0)
        else:
            scores_filtered = tagging_scores
        for metric_name, metric_fn in metrics.items():
            metric_values, other_values = metric_fn(scores_filtered)
            print()
            print(f'{metric_name}(medfilt_length={medfilt_len})')
            print(metric_values)
            hyper_params_and_other_values = {}
            for event_class in metric_values:
                # Averages over classes are not tunable entries themselves.
                if event_class.endswith('_average'):
                    continue
                hyper_params_and_other_values[event_class] = {
                    'medfilt_length': medfilt_len,
                    **other_values.get(event_class, {})
                }
            leaderboard = update_leaderboard(
                leaderboard, metric_name, metric_values,
                hyper_params_and_other_values, scores_filtered,
                minimize=minimize)

    for metric_name in leaderboard:
        metric_values = leaderboard[metric_name][0]
        hyper_params_and_other_values = leaderboard[metric_name][1]
        # Record the achieved metric value next to the hyper-parameters.
        # This is done regardless of storage_dir so that the returned
        # leaderboard does not depend on whether results are written to disk.
        for event_class in hyper_params_and_other_values:
            hyper_params_and_other_values[event_class][
                metric_name] = metric_values[event_class]
        if storage_dir is not None:
            dump_json(
                hyper_params_and_other_values,
                Path(storage_dir) / f'tagging_hyper_params_{metric_name}.json')

    print()
    print('best:')
    for metric_name in metrics:
        print()
        print(metric_name, leaderboard[metric_name][0])
    return leaderboard
def create_json(database_path: Path, json_path: Path, indent=4):
    """Build the database description and write it as JSON.

    Args:
        database_path: passed to ``construct_json`` to build the database
            structure.
        json_path: file the structure is dumped to; ``dump_json`` is asked to
            create the path (``create_path=True``).
        indent: indentation forwarded to ``dump_json``.
    """
    dump_json(
        construct_json(database_path),
        json_path,
        create_path=True,
        indent=indent,
        ensure_ascii=False,
    )
def create_jsons(database_path: Path, json_path: Path, indent=4):
    """Write the plain database JSON and two pseudo-labeled variants.

    Three files are written into ``json_path``:

    * ``desed.json``: the result of ``construct_json(database_path)``,
    * ``desed_pseudo_labeled_without_external.json`` and
      ``desed_pseudo_labeled_with_external.json``: deep copies of that
      database whose ``train_weak`` and ``train_unlabel_in_domain`` data sets
      were passed to ``add_strong_labels`` together with the pseudo labels
      read from a fixed inference run below ``pb_sed_root``.

    Args:
        database_path: directory of the database; must be a directory.
        json_path: directory the JSON files are written to.
        indent: indentation forwarded to ``dump_json``.
    """
    assert database_path.is_dir(), (
        f'Path "{str(database_path.absolute())}" is not a directory.')

    def write(content, file_name):
        # Dump one database variant and report where it went.
        dump_json(
            content,
            json_path / file_name,
            create_path=True,
            indent=indent,
            ensure_ascii=False,
        )
        print(f'Dumped json {json_path / file_name}')

    database = construct_json(database_path)
    write(database, 'desed.json')

    # (inference run providing the pseudo labels, output file name)
    variants = (
        ('2022-05-04-09-05-53', 'desed_pseudo_labeled_without_external.json'),
        ('2022-06-24-10-06-21', 'desed_pseudo_labeled_with_external.json'),
    )
    # (data set to label, file with its pseudo labels)
    labeled_subsets = (
        ('train_weak', 'train_weak_pseudo_labeled.tsv'),
        ('train_unlabel_in_domain',
         'train_unlabel_in_domain_pseudo_labeled.tsv'),
    )
    for run_name, file_name in variants:
        # Every variant starts from a fresh copy of the unlabeled database.
        database_pseudo_labeled = deepcopy(database)
        pseudo_labels_dir = (
            pb_sed_root / 'exp' / 'strong_label_crnn_inference' / run_name)
        for subset_name, labels_file in labeled_subsets:
            add_strong_labels(
                database_pseudo_labeled['datasets'][subset_name],
                read_ground_truth_file(pseudo_labels_dir / labels_file))
        write(database_pseudo_labeled, file_name)
def main(_run, storage_dir, debug, weak_label_crnn_hyper_params_dir,
         weak_label_crnn_dirs, weak_label_crnn_checkpoints,
         strong_label_crnn_dirs, strong_label_crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, medfilt_lengths, device):
    """Tune sound event detection hyper-parameters of strong-label CRNNs.

    Steps, in order:

    1. Load the weak-label CRNNs and compute tags for the validation set
       with ``tagging``.
    2. Load the strong-label CRNNs and call
       ``strong_label.crnn.tune_sound_event_detection`` with the tags added
       to each example as ``"tag_condition"``; metrics are ``'f'``,
       ``'auc1'`` and ``'auc2'``.
    3. Dump ``sed_hyper_params_f.json``, ``sed_hyper_params_psds1.json`` and
       ``sed_hyper_params_psds2.json`` into ``storage_dir``. The two PSDS
       files additionally receive per-class thresholds from
       ``collar_based.best_fscore``.
    4. Symlink ``storage_dir`` into ``<crnn_dir>/hyper_params`` of every
       strong-label CRNN directory.
    5. If ``eval_set_name`` is truthy, start ``evaluation.run`` on it.

    Args:
        _run: run object, only handed to ``print_config``.
        storage_dir: output directory of this tuning run.
        debug: forwarded to the evaluation run.
        weak_label_crnn_hyper_params_dir: forwarded to ``tagging``.
        weak_label_crnn_dirs / weak_label_crnn_checkpoints: model directories
            and checkpoint name(s) of the tagging models; a single string
            checkpoint is used for all directories.
        strong_label_crnn_dirs / strong_label_crnn_checkpoints: same for the
            detection models.
        data_provider: config handed to ``DESEDProvider.from_config``.
        validation_set_name: data set used for tuning.
        validation_ground_truth_filepath: ground truth file; when falsy and
            the set is ``'validation'`` or ``'eval_public'`` it is derived
            from the audio path of the first raw example.
        eval_set_name / eval_ground_truth_filepath: optional evaluation set.
        medfilt_lengths: forwarded to the tuning routine.
        device: forwarded to ``tagging`` and the tuning routine.
    """
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    storage_dir = Path(storage_dir)

    # --- tagging models -----------------------------------------------
    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints,
                          str), weak_label_crnn_checkpoints
        # One checkpoint name shared by all model directories.
        weak_label_crnn_checkpoints = (
            [weak_label_crnn_checkpoints] * len(weak_label_crnn_dirs))
    weak_label_crnns = []
    for model_dir, checkpoint in zip(weak_label_crnn_dirs,
                                     weak_label_crnn_checkpoints):
        weak_label_crnns.append(
            weak_label.CRNN.from_storage_dir(
                storage_dir=model_dir, config_name='1/config.json',
                checkpoint_name=checkpoint))

    # --- data ------------------------------------------------------------
    data_provider = DESEDProvider.from_config(data_provider)
    label_encoder = data_provider.test_transform.label_encoder
    label_encoder.initialize_labels()
    index_to_label = label_encoder.inverse_label_mapping
    # Event classes ordered by their label index.
    event_classes = [
        index_to_label[idx] for idx in range(len(index_to_label))
    ]
    # STFT shift divided by the sample rate (presumably seconds per frame).
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if not validation_ground_truth_filepath:
        # NOTE(review): the four-level `.parent` hop treats the directory
        # four levels above an audio file as the database root — confirm
        # against the database layout.
        if validation_set_name == 'validation':
            database_root = Path(
                data_provider.get_raw('validation')[0]['audio_path']
            ).parent.parent.parent.parent
            validation_ground_truth_filepath = (
                database_root / 'metadata' / 'validation' / 'validation.tsv')
        elif validation_set_name == 'eval_public':
            database_root = Path(
                data_provider.get_raw('eval_public')[0]['audio_path']
            ).parent.parent.parent.parent
            validation_ground_truth_filepath = (
                database_root / 'metadata' / 'eval' / 'public.tsv')
    assert isinstance(
        validation_ground_truth_filepath, (str, Path)
    ) and Path(validation_ground_truth_filepath).exists(
    ), validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        item['example_id']: item['audio_length']
        for item in data_provider.db.get_dataset(validation_set_name)
    }
    # For tagging each clip is covered by one span [0, duration].
    timestamps = {
        audio_id: np.array([0., duration])
        for audio_id, duration in audio_durations.items()
    }
    tags, _, _ = tagging(
        weak_label_crnns, dataset, device, timestamps, event_classes,
        weak_label_crnn_hyper_params_dir, None, None,
    )

    # --- metrics ---------------------------------------------------------
    collar_based_params = dict(
        onset_collar=.2, offset_collar=.2, offset_collar_rate=.2)
    psds_scenario_1 = dict(
        dtc_threshold=0.7, gtc_threshold=0.7, cttc_threshold=None,
        alpha_ct=.0, alpha_st=1.)
    psds_scenario_2 = dict(
        dtc_threshold=0.1, gtc_threshold=0.1, cttc_threshold=0.3,
        alpha_ct=.5, alpha_st=1.)
    metrics = {
        'f': partial(
            base.f_collar, ground_truth=validation_ground_truth_filepath,
            return_onset_offset_bias=True, num_jobs=8,
            **collar_based_params,
        ),
        'auc1': partial(
            base.psd_auc, ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations, num_jobs=8, **psds_scenario_1,
        ),
        'auc2': partial(
            base.psd_auc, ground_truth=validation_ground_truth_filepath,
            audio_durations=audio_durations, num_jobs=8, **psds_scenario_2,
        ),
    }

    # --- detection models --------------------------------------------------
    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints,
                          str), strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = (
            [strong_label_crnn_checkpoints] * len(strong_label_crnn_dirs))
    strong_label_crnns = []
    for model_dir, checkpoint in zip(strong_label_crnn_dirs,
                                     strong_label_crnn_checkpoints):
        strong_label_crnns.append(
            strong_label.CRNN.from_storage_dir(
                storage_dir=model_dir, config_name='1/config.json',
                checkpoint_name=checkpoint))

    def add_tag_condition(example):
        # Attach the tags of every example id in the batch.
        example["tag_condition"] = np.array(
            [tags[example_id] for example_id in example["example_id"]])
        return example

    timestamps = np.arange(0, 10000) * frame_shift
    leaderboard = strong_label.crnn.tune_sound_event_detection(
        strong_label_crnns, dataset.map(add_tag_condition), device,
        timestamps, event_classes, tags, metrics,
        tag_masking={'f': True, 'auc1': '?', 'auc2': '?'},
        medfilt_lengths=medfilt_lengths,
    )
    dump_json(leaderboard['f'][1], storage_dir / 'sed_hyper_params_f.json')

    # The PSDS-tuned settings additionally get a collar-based decision
    # threshold per event class before they are stored.
    psds_outputs = (
        ('auc1', 'sed_hyper_params_psds1.json'),
        ('auc2', 'sed_hyper_params_psds2.json'),
    )
    for metric_name, file_name in psds_outputs:
        _, _, _, thresholds, _ = collar_based.best_fscore(
            scores=leaderboard[metric_name][2],
            ground_truth=validation_ground_truth_filepath,
            **collar_based_params, num_jobs=8)
        for event_class in thresholds:
            leaderboard[metric_name][1][event_class]['threshold'] = (
                thresholds[event_class])
        dump_json(leaderboard[metric_name][1], storage_dir / file_name)

    # Make the tuning result discoverable from every tuned model directory.
    for model_dir in strong_label_crnn_dirs:
        tuning_dir = Path(model_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)
    print(storage_dir)

    if eval_set_name:
        evaluation.run(config_updates={
            'debug': debug,
            'strong_label_crnn_hyper_params_dir': str(storage_dir),
            'dataset_name': eval_set_name,
            'ground_truth_filepath': eval_ground_truth_filepath,
        }, )
def main(
        _run, storage_dir, hyper_params_dir, sed_hyper_params_name,
        crnn_dirs, crnn_checkpoints, device, data_provider, dataset_name,
        ground_truth_filepath, save_scores, save_detections,
        max_segment_length, segment_overlap, weak_pseudo_labeling,
        boundary_pseudo_labeling, strong_pseudo_labeling, pseudo_widening,
        pseudo_labeled_dataset_name,
):
    """Run tagging, boundary detection and sound event detection inference.

    For every data set in ``dataset_name`` the CRNNs are applied via
    ``tagging``, ``boundaries_detection`` and ``sound_event_detection``;
    non-empty results are dumped as JSON into ``storage_dir``. The outputs
    are passed to ``base.pseudo_label`` and stored in a copy of the database
    under ``pseudo_labeled_dataset_name[i]``. That copy is written to
    ``storage_dir`` if any pseudo labeling flag is truthy. Finally
    ``storage_dir`` is symlinked into ``<hyper_params_dir>/inference``.

    Args:
        _run: run object, only handed to ``print_config``.
        storage_dir: output directory of this run.
        hyper_params_dir: directory with tuned hyper-parameters, forwarded to
            the inference helpers.
        sed_hyper_params_name: name or list/tuple of names of the detection
            hyper-parameter sets to apply.
        crnn_dirs / crnn_checkpoints: model directories and checkpoint
            name(s); a single string checkpoint is used for all directories.
        device: forwarded to the inference helpers.
        data_provider: config handed to ``DataProvider.from_config``.
        dataset_name: data set name or list of names.
        ground_truth_filepath: None, a single path or one entry per data
            set. Falsy entries of ``'eval_public'`` / ``'validation'`` are
            replaced by a default path derived from the first raw example.
            A passed sequence is copied and never modified.
        save_scores / save_detections: whether scores / detections are
            stored below ``storage_dir``.
        max_segment_length / segment_overlap: segmentation settings;
            ``max_segment_length=None`` processes each clip as one span.
        weak_pseudo_labeling / boundary_pseudo_labeling /
        strong_pseudo_labeling: flag or list of flags per data set.
        pseudo_widening: forwarded to boundary and event detection.
        pseudo_labeled_dataset_name: name or list of names under which the
            pseudo labeled data sets are stored (one per data set).
    """
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(
        output_dir=storage_dir, on_csv_write="update", log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundary_collar_based_params = {
        'onset_collar': .5,
        'offset_collar': .5,
        'offset_collar_rate': .0,
    }
    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        # One checkpoint name shared by all model directories.
        crnn_checkpoints = len(crnn_dirs) * [crnn_checkpoints]
    crnns = [
        CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(crnn_dirs, crnn_checkpoints)
    ]
    print('Params',
          sum([p.numel() for crnn in crnns for p in crnn.parameters()]))
    print(
        'CNN2d Params',
        sum([
            p.numel() for crnn in crnns
            for p in crnn.cnn.cnn_2d.parameters()
        ]))

    data_provider = DataProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    # Event classes ordered by their label index.
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    # STFT shift divided by the sample rate (presumably seconds per frame).
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    # Normalise all per-data-set options to lists of equal length.
    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    else:
        # Entries may be filled in below; work on a private copy so that
        # the caller's sequence is never modified (and tuples or read-only
        # sequences are accepted as well).
        ground_truth_filepath = list(ground_truth_filepath)
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(weak_pseudo_labeling, list):
        weak_pseudo_labeling = len(dataset_name) * [weak_pseudo_labeling]
    assert len(weak_pseudo_labeling) == len(dataset_name)
    if not isinstance(boundary_pseudo_labeling, list):
        boundary_pseudo_labeling = len(dataset_name) * [
            boundary_pseudo_labeling
        ]
    assert len(boundary_pseudo_labeling) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    # A single target name is NOT replicated: with several data sets one
    # name per data set has to be given.
    if not isinstance(pseudo_labeled_dataset_name, list):
        pseudo_labeled_dataset_name = [pseudo_labeled_dataset_name]
    assert len(pseudo_labeled_dataset_name) == len(dataset_name)
    # Loop invariant, so normalised once up front.
    if not isinstance(sed_hyper_params_name, (list, tuple)):
        sed_hyper_params_name = [sed_hyper_params_name]

    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])

        # NOTE(review): the four-level `.parent` hop treats the directory
        # four levels above an audio file as the database root — confirm
        # against the database layout.
        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }
        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        if max_segment_length is None:
            # One span [0, duration] per clip.
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            # Segment starts spaced by the hop between segments, closed by
            # the clip duration.
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(
                    0, audio_durations[audio_id],
                    (max_segment_length - segment_overlap) * frame_shift)
                timestamps[audio_id] = np.concatenate(
                    (ts, [audio_durations[audio_id]]))

        tags, tagging_scores, tagging_results = tagging(
            crnns, dataset, device, timestamps, event_classes,
            hyper_params_dir, ground_truth_filepath[i], audio_durations,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
        )
        if tagging_results:
            dump_json(
                tagging_results,
                storage_dir / f'tagging_results_{dataset_name[i]}.json')

        # Frame-wise time grid for boundary and event detection.
        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)

        if ground_truth_filepath[i] is not None or boundary_pseudo_labeling[i]:
            boundaries, boundaries_detection_results = boundaries_detection(
                crnns, dataset, device, timestamps, event_classes, tags,
                hyper_params_dir, ground_truth_filepath[i],
                boundary_collar_based_params,
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
            )
            if boundaries_detection_results:
                dump_json(
                    boundaries_detection_results, storage_dir /
                    f'boundaries_detection_results_{dataset_name[i]}.json')
        else:
            boundaries = {}

        if (ground_truth_filepath[i] is not None
            ) or strong_pseudo_labeling[i] or save_scores or save_detections:
            events, sed_results = sound_event_detection(
                crnns, dataset, device, timestamps, event_classes, tags,
                hyper_params_dir, sed_hyper_params_name,
                ground_truth_filepath[i], audio_durations,
                collar_based_params, [psds_scenario_1, psds_scenario_2],
                max_segment_length=max_segment_length,
                segment_overlap=segment_overlap,
                pseudo_widening=pseudo_widening,
                score_storage_dir=[
                    score_storage_dir / name
                    for name in sed_hyper_params_name
                ] if save_scores else None,
                detection_storage_dir=[
                    detection_storage_dir / name
                    for name in sed_hyper_params_name
                ] if save_detections else None,
            )
            for j, sed_results_j in enumerate(sed_results):
                if sed_results_j:
                    dump_json(
                        sed_results_j, storage_dir /
                        f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                    )
        else:
            events = [{}]

        # Only the events of the first hyper-parameter set are used for
        # pseudo labeling.
        database['datasets'][
            pseudo_labeled_dataset_name[i]] = base.pseudo_label(
                database['datasets'][dataset_name[i]], event_classes,
                weak_pseudo_labeling[i], boundary_pseudo_labeling[i],
                strong_pseudo_labeling[i], tags, boundaries, events[0],
            )

    if any(weak_pseudo_labeling) or any(boundary_pseudo_labeling) or any(
            strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )

    # Make this run discoverable from the hyper-parameter directory.
    inference_dir = Path(hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)
def tune_sound_event_detection(
        detection_scores, medfilt_length_candidates, tags, metrics,
        minimize=False, tag_masking=None, storage_dir=None,
):
    """Search median-filter lengths and tag masking for detection scores.

    For every candidate filter length the scores are median filtered along
    axis 0 (lengths ``<= 1`` use the unfiltered scores). A second, "masked"
    variant is obtained by multiplying the scores of each audio with
    ``tags[audio_id]``. Each metric is evaluated on the variant(s) selected
    by ``tag_masking`` and reported to ``update_leaderboard``.

    Args:
        detection_scores: mapping from audio id to a score data frame. Each
            frame is checked with ``validate_score_dataframe`` and indexed
            with the event-class columns that helper returns.
        medfilt_length_candidates: iterable of median-filter lengths to try.
        tags: mapping from audio id to the factor the scores of that audio
            are multiplied with for masking.
        metrics: mapping from metric name to a callable that takes the score
            mapping and returns ``(metric_values, other_values)``, both keyed
            by event class.
        minimize: forwarded to ``update_leaderboard``.
        tag_masking: ``True`` (masked scores only), ``False`` (unmasked
            only), ``'?'`` (try both), or a dict with one such value per
            metric name. Any other value, including the default ``None``,
            fails the assertion below.
        storage_dir: if not None, the hyper-parameters of each metric are
            written to ``sed_hyper_params_<metric>.json`` in this directory.

    Returns:
        The leaderboard built by ``update_leaderboard``. For each metric
        name, item ``[0]`` holds the metric values and item ``[1]`` the
        hyper-parameters per event class, extended by the achieved metric
        value.
    """
    if tag_masking in (True, False, '?'):
        # A single setting applies to every metric.
        tag_masking = dict.fromkeys(metrics.keys(), tag_masking)
    assert isinstance(tag_masking, dict), tag_masking
    assert tag_masking.keys() == metrics.keys(), (tag_masking.keys(),
                                                  metrics.keys())
    assert all(
        setting in (True, False, '?') for setting in tag_masking.values())

    leaderboard = {}
    audio_ids = sorted(detection_scores.keys())
    event_classes = None
    for medfilt_len in medfilt_length_candidates:
        if medfilt_len <= 1:
            scores_filtered = detection_scores
        else:
            # Filter a copy so the caller's scores stay untouched.
            scores_filtered = deepcopy(detection_scores)
            for audio_id in audio_ids:
                _, event_classes = validate_score_dataframe(
                    detection_scores[audio_id], event_classes=event_classes)
                raw_scores = detection_scores[audio_id][
                    event_classes].to_numpy()
                scores_filtered[audio_id][event_classes] = medfilt(
                    raw_scores, medfilt_len, axis=0)

        # Tag-masked variant of the (filtered) scores.
        scores_masked = deepcopy(scores_filtered)
        for audio_id in audio_ids:
            _, event_classes = validate_score_dataframe(
                scores_masked[audio_id], event_classes=event_classes)
            scores_masked[audio_id][event_classes] *= tags[audio_id]

        for metric_name, metric_fn in metrics.items():
            masking_setting = tag_masking[metric_name]
            tag_masking_candidates = (
                [False, True] if masking_setting == '?'
                else [masking_setting])
            for tag_masked in tag_masking_candidates:
                scores = scores_masked if tag_masked else scores_filtered
                metric_values, other_values = metric_fn(scores)
                print()
                print(
                    f'{metric_name}(medfilt_length={medfilt_len},tag_masked={tag_masked}):'
                )
                print(metric_values)
                # Averages over classes are not tunable entries themselves.
                hyper_params_and_other_values = {
                    event_class: {
                        'medfilt_length': medfilt_len,
                        'tag_masked': tag_masked,
                        **other_values.get(event_class, {})
                    }
                    for event_class in metric_values
                    if not event_class.endswith('_average')
                }
                leaderboard = update_leaderboard(
                    leaderboard, metric_name, metric_values,
                    hyper_params_and_other_values, scores,
                    minimize=minimize)

    for metric_name, entry in leaderboard.items():
        best_values = entry[0]
        best_params = entry[1]
        # Record the achieved metric value next to the hyper-parameters.
        for event_class in best_params:
            best_params[event_class][metric_name] = best_values[event_class]
        if storage_dir is not None:
            dump_json(
                best_params,
                Path(storage_dir) / f'sed_hyper_params_{metric_name}.json')

    print()
    print('best:')
    for metric_name in metrics:
        print()
        print(metric_name, ':')
        print(leaderboard[metric_name][0])
    return leaderboard
def main(_run, storage_dir, debug, crnn_dirs, crnn_checkpoints, data_provider,
         validation_set_name, validation_ground_truth_filepath, eval_set_name,
         eval_ground_truth_filepath, boundaries_filter_lengths,
         tune_detection_scenario_1, detection_window_lengths_scenario_1,
         detection_window_shift_scenario_1,
         detection_medfilt_lengths_scenario_1, tune_detection_scenario_2,
         detection_window_lengths_scenario_2,
         detection_window_shift_scenario_2,
         detection_medfilt_lengths_scenario_2, device):
    """Tune tagging, boundary and event detection of weak-label CRNNs.

    Steps, in order:

    1. ``weak_label.crnn.tune_tagging`` with an ``'f'`` metric; tags are
       derived by comparing the best tagging scores with the tuned
       per-class thresholds.
    2. ``weak_label.crnn.tune_boundary_detection`` against boundaries
       obtained with ``base.boundaries_from_events``.
    3. Optionally ``weak_label.crnn.tune_sound_event_detection`` for
       scenario 1 (writes ``sed_hyper_params_f.json`` and
       ``sed_hyper_params_psds1.json``) and scenario 2 (writes
       ``sed_hyper_params_psds2.json``). The PSDS files additionally receive
       per-class thresholds from ``collar_based.best_fscore``.
    4. Symlink ``storage_dir`` into ``<crnn_dir>/hyper_params`` of every
       model directory.
    5. If ``eval_set_name`` is truthy, start one ``evaluation.run`` per
       tuned scenario.

    Args:
        _run: run object, only handed to ``print_config``.
        storage_dir: output directory of this tuning run.
        debug: forwarded to the evaluation runs.
        crnn_dirs / crnn_checkpoints: model directories and checkpoint
            name(s); a single string checkpoint is used for all directories.
        data_provider: config handed to ``DataProvider.from_config``.
        validation_set_name: data set used for tuning.
        validation_ground_truth_filepath: ground truth file; when falsy and
            the set is ``'validation'`` or ``'eval_public'`` it is derived
            from the audio path of the first raw example.
        eval_set_name / eval_ground_truth_filepath: optional evaluation set.
        boundaries_filter_lengths: forwarded as ``stepfilt_lengths``.
        tune_detection_scenario_1 / tune_detection_scenario_2: enable event
            detection tuning for the respective PSDS scenario.
        detection_window_lengths_scenario_*, detection_window_shift_scenario_*,
        detection_medfilt_lengths_scenario_*: forwarded to the event
            detection tuning of the respective scenario.
        device: forwarded to the tuning routines.
    """
    print()
    print('##### Tuning #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(
        output_dir=storage_dir, on_csv_write="update", log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    boundaries_collar_based_params = dict(
        onset_collar=.5, offset_collar=.5, offset_collar_rate=.0,
        min_precision=.8)
    collar_based_params = dict(
        onset_collar=.2, offset_collar=.2, offset_collar_rate=.2)
    psds_scenario_1 = dict(
        dtc_threshold=0.7, gtc_threshold=0.7, cttc_threshold=None,
        alpha_ct=.0, alpha_st=1.)
    psds_scenario_2 = dict(
        dtc_threshold=0.1, gtc_threshold=0.1, cttc_threshold=0.3,
        alpha_ct=.5, alpha_st=1.)

    # --- models ----------------------------------------------------------
    if not isinstance(crnn_checkpoints, list):
        assert isinstance(crnn_checkpoints, str), crnn_checkpoints
        # One checkpoint name shared by all model directories.
        crnn_checkpoints = [crnn_checkpoints] * len(crnn_dirs)
    crnns = []
    for model_dir, checkpoint in zip(crnn_dirs, crnn_checkpoints):
        crnns.append(
            weak_label.CRNN.from_storage_dir(
                storage_dir=model_dir, config_name='1/config.json',
                checkpoint_name=checkpoint))

    # --- data ------------------------------------------------------------
    data_provider = DataProvider.from_config(data_provider)
    label_encoder = data_provider.test_transform.label_encoder
    label_encoder.initialize_labels()
    index_to_label = label_encoder.inverse_label_mapping
    # Event classes ordered by their label index.
    event_classes = [
        index_to_label[idx] for idx in range(len(index_to_label))
    ]
    # STFT shift divided by the sample rate (presumably seconds per frame).
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    if not validation_ground_truth_filepath:
        # NOTE(review): the four-level `.parent` hop treats the directory
        # four levels above an audio file as the database root — confirm
        # against the database layout.
        if validation_set_name == 'validation':
            database_root = Path(
                data_provider.get_raw('validation')[0]['audio_path']
            ).parent.parent.parent.parent
            validation_ground_truth_filepath = (
                database_root / 'metadata' / 'validation' / 'validation.tsv')
        elif validation_set_name == 'eval_public':
            database_root = Path(
                data_provider.get_raw('eval_public')[0]['audio_path']
            ).parent.parent.parent.parent
            validation_ground_truth_filepath = (
                database_root / 'metadata' / 'eval' / 'public.tsv')
    assert isinstance(
        validation_ground_truth_filepath, (str, Path)
    ) and Path(validation_ground_truth_filepath).exists(
    ), validation_ground_truth_filepath

    dataset = data_provider.get_dataset(validation_set_name)
    audio_durations = {
        item['example_id']: item['audio_length']
        for item in data_provider.db.get_dataset(validation_set_name)
    }
    # For tagging each clip is covered by one span [0, duration].
    timestamps = {
        audio_id: np.array([0., duration])
        for audio_id, duration in audio_durations.items()
    }

    # --- tagging ---------------------------------------------------------
    metrics = {
        'f': partial(base.f_tag,
                     ground_truth=validation_ground_truth_filepath,
                     num_jobs=8)
    }
    leaderboard = weak_label.crnn.tune_tagging(
        crnns, dataset, device, timestamps, event_classes, metrics,
        storage_dir=storage_dir)
    _, hyper_params, tagging_scores = leaderboard['f']
    tagging_thresholds = np.array(
        [hyper_params[label]['threshold'] for label in event_classes])
    # Binary tags: best tagging scores compared with the tuned thresholds.
    tags = {
        audio_id:
        tagging_scores[audio_id][event_classes].to_numpy() > tagging_thresholds
        for audio_id in tagging_scores
    }

    # --- boundary detection --------------------------------------------
    boundaries_ground_truth = base.boundaries_from_events(
        validation_ground_truth_filepath)
    timestamps = np.arange(0, 10000) * frame_shift
    metrics = {
        'f': partial(
            base.f_collar, ground_truth=boundaries_ground_truth,
            return_onset_offset_bias=True, num_jobs=8,
            **boundaries_collar_based_params,
        ),
    }
    weak_label.crnn.tune_boundary_detection(
        crnns, dataset, device, timestamps, event_classes, tags, metrics,
        tag_masking=True, stepfilt_lengths=boundaries_filter_lengths,
        storage_dir=storage_dir)

    def add_thresholds_and_dump(entry, file_name):
        # Tune a collar-based decision threshold per event class on the
        # best scores (entry[2]), store it with the hyper-parameters
        # (entry[1]) and write those to file_name in storage_dir.
        _, _, _, thresholds, _ = collar_based.best_fscore(
            scores=entry[2], ground_truth=validation_ground_truth_filepath,
            **collar_based_params, num_jobs=8)
        for event_class in thresholds:
            entry[1][event_class]['threshold'] = thresholds[event_class]
        dump_json(entry[1], storage_dir / file_name)

    # --- sound event detection -----------------------------------------
    if tune_detection_scenario_1:
        metrics = {
            'f': partial(
                base.f_collar, ground_truth=validation_ground_truth_filepath,
                return_onset_offset_bias=True, num_jobs=8,
                **collar_based_params,
            ),
            'auc': partial(
                base.psd_auc, ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations, num_jobs=8,
                **psds_scenario_1,
            ),
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns, dataset, device, timestamps, event_classes, tags, metrics,
            tag_masking={'f': True, 'auc': '?'},
            window_lengths=detection_window_lengths_scenario_1,
            window_shift=detection_window_shift_scenario_1,
            medfilt_lengths=detection_medfilt_lengths_scenario_1,
        )
        dump_json(leaderboard['f'][1],
                  storage_dir / 'sed_hyper_params_f.json')
        add_thresholds_and_dump(
            leaderboard['auc'], 'sed_hyper_params_psds1.json')

    if tune_detection_scenario_2:
        metrics = {
            'auc': partial(
                base.psd_auc, ground_truth=validation_ground_truth_filepath,
                audio_durations=audio_durations, num_jobs=8,
                **psds_scenario_2,
            )
        }
        leaderboard = weak_label.crnn.tune_sound_event_detection(
            crnns, dataset, device, timestamps, event_classes, tags, metrics,
            tag_masking=False,
            window_lengths=detection_window_lengths_scenario_2,
            window_shift=detection_window_shift_scenario_2,
            medfilt_lengths=detection_medfilt_lengths_scenario_2,
        )
        add_thresholds_and_dump(
            leaderboard['auc'], 'sed_hyper_params_psds2.json')

    # Make the tuning result discoverable from every tuned model directory.
    for model_dir in crnn_dirs:
        tuning_dir = Path(model_dir) / 'hyper_params'
        os.makedirs(str(tuning_dir), exist_ok=True)
        (tuning_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)

    if eval_set_name:
        if tune_detection_scenario_1:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
            }, )
        if tune_detection_scenario_2:
            evaluation.run(config_updates={
                'debug': debug,
                'hyper_params_dir': str(storage_dir),
                'dataset_name': eval_set_name,
                'ground_truth_filepath': eval_ground_truth_filepath,
                'sed_hyper_params_name': 'psds2',
            }, )
def main(
        _run, storage_dir, strong_label_crnn_hyper_params_dir,
        sed_hyper_params_name, strong_label_crnn_dirs,
        strong_label_crnn_checkpoints, weak_label_crnn_hyper_params_dir,
        weak_label_crnn_dirs, weak_label_crnn_checkpoints, device,
        data_provider, dataset_name, ground_truth_filepath, save_scores,
        save_detections, max_segment_length, segment_overlap,
        strong_pseudo_labeling, pseudo_widening,
        pseudo_labelled_dataset_name,
):
    """Run tag-conditioned sound event detection with strong-label CRNNs.

    For every data set in ``dataset_name`` tags are computed with the
    weak-label CRNNs (``tagging``), attached to each example as
    ``"tag_condition"`` and the strong-label CRNNs are applied via
    ``sound_event_detection``. Non-empty results are dumped as JSON into
    ``storage_dir``. Where strong pseudo labeling is requested, the events
    of the first hyper-parameter set are stored in a copy of the database
    and written to ``<dataset>_pseudo_labeled.tsv``. The database copy is
    written to ``storage_dir`` if any pseudo labeling flag is truthy.
    Finally ``storage_dir`` is symlinked into
    ``<strong_label_crnn_hyper_params_dir>/inference``.

    Args:
        _run: run object, only handed to ``print_config``.
        storage_dir: output directory of this run.
        strong_label_crnn_hyper_params_dir: directory with tuned detection
            hyper-parameters, forwarded to ``sound_event_detection``.
        sed_hyper_params_name: name or list/tuple of names of the detection
            hyper-parameter sets to apply.
        strong_label_crnn_dirs / strong_label_crnn_checkpoints: model
            directories and checkpoint name(s) of the detection models; a
            single string checkpoint is used for all directories.
        weak_label_crnn_hyper_params_dir: forwarded to ``tagging``.
        weak_label_crnn_dirs / weak_label_crnn_checkpoints: same for the
            tagging models.
        device: forwarded to the inference helpers.
        data_provider: config handed to ``DESEDProvider.from_config``.
        dataset_name: data set name or list of names.
        ground_truth_filepath: None, a single path or one entry per data
            set. Falsy entries of ``'eval_public'`` / ``'validation'`` are
            replaced by a default path derived from the first raw example.
            A passed sequence is copied and never modified.
        save_scores / save_detections: whether scores / detections are
            stored below ``storage_dir``.
        max_segment_length / segment_overlap: segmentation settings;
            ``max_segment_length=None`` processes each clip as one span.
        strong_pseudo_labeling: flag or list of flags per data set.
        pseudo_widening: forwarded to ``sound_event_detection``.
        pseudo_labelled_dataset_name: name or list of names under which the
            pseudo labeled data sets are stored (one per data set).
    """
    print()
    print('##### Inference #####')
    print()
    print_config(_run)
    print(storage_dir)
    emissions_tracker = EmissionsTracker(
        output_dir=storage_dir, on_csv_write="update", log_level='error')
    emissions_tracker.start()
    storage_dir = Path(storage_dir)

    collar_based_params = {
        'onset_collar': .2,
        'offset_collar': .2,
        'offset_collar_rate': .2,
    }
    psds_scenario_1 = {
        'dtc_threshold': 0.7,
        'gtc_threshold': 0.7,
        'cttc_threshold': None,
        'alpha_ct': .0,
        'alpha_st': 1.,
    }
    psds_scenario_2 = {
        'dtc_threshold': 0.1,
        'gtc_threshold': 0.1,
        'cttc_threshold': 0.3,
        'alpha_ct': .5,
        'alpha_st': 1.,
    }

    # --- tagging models --------------------------------------------------
    if not isinstance(weak_label_crnn_checkpoints, list):
        assert isinstance(weak_label_crnn_checkpoints,
                          str), weak_label_crnn_checkpoints
        # One checkpoint name shared by all model directories.
        weak_label_crnn_checkpoints = len(weak_label_crnn_dirs) * [
            weak_label_crnn_checkpoints
        ]
    weak_label_crnns = [
        weak_label.CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(weak_label_crnn_dirs,
                                             weak_label_crnn_checkpoints)
    ]
    print(
        'Weak Label CRNN Params',
        sum([
            p.numel() for crnn in weak_label_crnns
            for p in crnn.parameters()
        ]))
    print(
        'Weak Label CNN2d Params',
        sum([
            p.numel() for crnn in weak_label_crnns
            for p in crnn.cnn.cnn_2d.parameters()
        ]))

    # --- detection models --------------------------------------------------
    if not isinstance(strong_label_crnn_checkpoints, list):
        assert isinstance(strong_label_crnn_checkpoints,
                          str), strong_label_crnn_checkpoints
        strong_label_crnn_checkpoints = len(strong_label_crnn_dirs) * [
            strong_label_crnn_checkpoints
        ]
    strong_label_crnns = [
        strong_label.CRNN.from_storage_dir(
            storage_dir=crnn_dir, config_name='1/config.json',
            checkpoint_name=crnn_checkpoint)
        for crnn_dir, crnn_checkpoint in zip(strong_label_crnn_dirs,
                                             strong_label_crnn_checkpoints)
    ]
    print(
        'Strong Label CRNN Params',
        sum([
            p.numel() for crnn in strong_label_crnns
            for p in crnn.parameters()
        ]))
    print(
        'Strong Label CNN2d Params',
        sum([
            p.numel() for crnn in strong_label_crnns
            for p in crnn.cnn.cnn_2d.parameters()
        ]))

    # --- data ------------------------------------------------------------
    data_provider = DESEDProvider.from_config(data_provider)
    data_provider.test_transform.label_encoder.initialize_labels()
    event_classes = data_provider.test_transform.label_encoder.inverse_label_mapping
    # Event classes ordered by their label index.
    event_classes = [event_classes[i] for i in range(len(event_classes))]
    # STFT shift divided by the sample rate (presumably seconds per frame).
    frame_shift = data_provider.test_transform.stft.shift
    frame_shift /= data_provider.audio_reader.target_sample_rate

    # Normalise all per-data-set options to lists of equal length.
    if not isinstance(dataset_name, list):
        dataset_name = [dataset_name]
    if ground_truth_filepath is None:
        ground_truth_filepath = len(dataset_name) * [ground_truth_filepath]
    elif isinstance(ground_truth_filepath, (str, Path)):
        ground_truth_filepath = [ground_truth_filepath]
    else:
        # Entries may be filled in below; work on a private copy so that
        # the caller's sequence is never modified (and tuples or read-only
        # sequences are accepted as well).
        ground_truth_filepath = list(ground_truth_filepath)
    assert len(ground_truth_filepath) == len(dataset_name)
    if not isinstance(strong_pseudo_labeling, list):
        strong_pseudo_labeling = len(dataset_name) * [strong_pseudo_labeling]
    assert len(strong_pseudo_labeling) == len(dataset_name)
    # A single target name is NOT replicated: with several data sets one
    # name per data set has to be given.
    if not isinstance(pseudo_labelled_dataset_name, list):
        pseudo_labelled_dataset_name = [pseudo_labelled_dataset_name]
    assert len(pseudo_labelled_dataset_name) == len(dataset_name)
    # Loop invariant, so normalised once up front.
    if not isinstance(sed_hyper_params_name, (list, tuple)):
        sed_hyper_params_name = [sed_hyper_params_name]

    database = deepcopy(data_provider.db.data)
    for i in range(len(dataset_name)):
        print()
        print(dataset_name[i])

        # NOTE(review): the four-level `.parent` hop treats the directory
        # four levels above an audio file as the database root — confirm
        # against the database layout.
        if dataset_name[i] == 'eval_public' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('eval_public')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'eval' / 'public.tsv'
        elif dataset_name[i] == 'validation' and not ground_truth_filepath[i]:
            database_root = Path(
                data_provider.get_raw('validation')[0]['audio_path']
            ).parent.parent.parent.parent
            ground_truth_filepath[
                i] = database_root / 'metadata' / 'validation' / 'validation.tsv'

        dataset = data_provider.get_dataset(dataset_name[i])
        audio_durations = {
            example['example_id']: example['audio_length']
            for example in data_provider.db.get_dataset(dataset_name[i])
        }
        score_storage_dir = storage_dir / 'scores' / dataset_name[i]
        detection_storage_dir = storage_dir / 'detections' / dataset_name[i]

        if max_segment_length is None:
            # One span [0, duration] per clip.
            timestamps = {
                audio_id: np.array([0., audio_durations[audio_id]])
                for audio_id in audio_durations
            }
        else:
            # Segment boundaries: 0, then points spaced by the hop between
            # segments and shifted back by half the overlap, closed by the
            # clip duration.
            timestamps = {}
            for audio_id in audio_durations:
                ts = np.arange(
                    (2 + max_segment_length) * frame_shift,
                    audio_durations[audio_id],
                    (max_segment_length - segment_overlap) * frame_shift)
                timestamps[audio_id] = np.concatenate(
                    ([0.], ts - segment_overlap / 2 * frame_shift,
                     [audio_durations[audio_id]]))
        if max_segment_length is not None:
            dataset = dataset.map(
                partial(segment_batch, max_length=max_segment_length,
                        overlap=segment_overlap)).unbatch()

        tags, tagging_scores, _ = tagging(
            weak_label_crnns, dataset, device, timestamps, event_classes,
            weak_label_crnn_hyper_params_dir, None, None,
        )

        def add_tag_condition(example):
            # Attach the tags of every example id in the batch.
            example["tag_condition"] = np.array(
                [tags[example_id] for example_id in example["example_id"]])
            return example

        dataset = dataset.map(add_tag_condition)

        # Frame-wise time grid for event detection.
        timestamps = np.round(np.arange(0, 100000) * frame_shift, decimals=6)
        events, sed_results = sound_event_detection(
            strong_label_crnns, dataset, device, timestamps, event_classes,
            tags, strong_label_crnn_hyper_params_dir, sed_hyper_params_name,
            ground_truth_filepath[i], audio_durations, collar_based_params,
            [psds_scenario_1, psds_scenario_2],
            max_segment_length=max_segment_length,
            segment_overlap=segment_overlap,
            pseudo_widening=pseudo_widening,
            score_storage_dir=[
                score_storage_dir / name for name in sed_hyper_params_name
            ] if save_scores else None,
            detection_storage_dir=[
                detection_storage_dir / name
                for name in sed_hyper_params_name
            ] if save_detections else None,
        )
        for j, sed_results_j in enumerate(sed_results):
            if sed_results_j:
                dump_json(
                    sed_results_j, storage_dir /
                    f'sed_{sed_hyper_params_name[j]}_results_{dataset_name[i]}.json'
                )

        if strong_pseudo_labeling[i]:
            # Only the events of the first hyper-parameter set are used for
            # pseudo labeling.
            database['datasets'][
                pseudo_labelled_dataset_name[i]] = base.pseudo_label(
                    database['datasets'][dataset_name[i]], event_classes,
                    False, False, strong_pseudo_labeling[i],
                    None, None, events[0],
                )
            with (storage_dir /
                  f'{dataset_name[i]}_pseudo_labeled.tsv').open('w') as fid:
                fid.write('filename\tonset\toffset\tevent_label\n')
                for key, event_list in events[0].items():
                    # Clips without events still get a row with empty
                    # onset/offset/label fields.
                    if len(event_list) == 0:
                        fid.write(f'{key}.wav\t\t\t\n')
                    for t_on, t_off, event_label in event_list:
                        fid.write(
                            f'{key}.wav\t{t_on}\t{t_off}\t{event_label}\n')

    if any(strong_pseudo_labeling):
        dump_json(
            database,
            storage_dir / Path(data_provider.json_path).name,
            create_path=True,
            indent=4,
            ensure_ascii=False,
        )

    # Make this run discoverable from the hyper-parameter directory.
    inference_dir = Path(strong_label_crnn_hyper_params_dir) / 'inference'
    os.makedirs(str(inference_dir), exist_ok=True)
    (inference_dir / storage_dir.name).symlink_to(storage_dir)
    emissions_tracker.stop()
    print(storage_dir)