def config():
    """Configuration for strong-label CRNN inference.

    NOTE(review): presumably a Sacred ``@ex.config`` function (decorator not
    visible in this chunk; it references ``ex``, ``storage_root``,
    ``timeStamped``, ``load_json`` and ``FileStorageObserver`` defined
    elsewhere in the file). Under Sacred every local defined here becomes a
    config entry, so local names must not be changed.
    """
    debug = False
    # Unique run id; '_debug' suffix marks debug runs.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    # Directory of a finished strong-label hyper-parameter tuning run.
    # Default '' deliberately fails the assert until set on the command line.
    strong_label_crnn_hyper_params_dir = ''
    assert len(
        strong_label_crnn_hyper_params_dir
    ) > 0, 'Set strong_label_crnn_hyper_params_dir on the command line.'
    strong_label_crnn_tuning_config = load_json(
        Path(strong_label_crnn_hyper_params_dir) / '1' / 'config.json')
    strong_label_crnn_dirs = strong_label_crnn_tuning_config[
        'strong_label_crnn_dirs']
    assert len(strong_label_crnn_dirs
               ) > 0, 'strong_label_crnn_dirs must not be empty.'
    strong_label_crnn_checkpoints = strong_label_crnn_tuning_config[
        'strong_label_crnn_checkpoints']
    data_provider = strong_label_crnn_tuning_config['data_provider']
    database_name = strong_label_crnn_tuning_config['database_name']
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name
                      / 'inference' / timestamp)
    # Refuse to overwrite an existing run directory.
    assert not Path(storage_dir).exists()
    # The weak-label tuning run is taken from the strong-label tuning config
    # rather than being passed directly on the command line.
    weak_label_crnn_hyper_params_dir = strong_label_crnn_tuning_config[
        'weak_label_crnn_hyper_params_dir']
    assert len(
        weak_label_crnn_hyper_params_dir
    ) > 0, 'Set weak_label_crnn_hyper_params_dir on the command line.'
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(
        weak_label_crnn_dirs) > 0, 'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']
    # Drop the full tuning configs so only the extracted entries remain as
    # config values (Sacred captures every remaining local).
    del strong_label_crnn_tuning_config
    del weak_label_crnn_tuning_config
    # Names match the sed_hyper_params_{name}.json files consumed by
    # sound_event_detection() below.
    sed_hyper_params_name = ['f', 'psds1', 'psds2']
    device = 0  # CUDA device index
    dataset_name = 'eval_public'
    ground_truth_filepath = None
    # Segment-wise processing of long clips; None = process whole clips,
    # in which case no segment overlap is needed either.
    max_segment_length = None
    if max_segment_length is None:
        segment_overlap = None
    else:
        segment_overlap = 100
    save_scores = False
    save_detections = False
    weak_pseudo_labeling = False
    strong_pseudo_labeling = False
    pseudo_labelled_dataset_name = dataset_name
    # Widening (presumably in seconds) applied to pseudo-label event
    # boundaries -- confirm unit against sound_event_detection().
    pseudo_widening = .0
    ex.observers.append(FileStorageObserver.create(storage_dir))
def config():
    """Configuration for strong-label CRNN hyper-parameter tuning.

    NOTE(review): presumably a Sacred ``@ex.config`` function (decorator not
    visible in this chunk). Locals become config entries; do not rename.
    """
    debug = False
    # Unique run id; '_debug' suffix marks debug runs.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    # Directory of a finished weak-label tuning run; default '' deliberately
    # fails the assert until set on the command line.
    weak_label_crnn_hyper_params_dir = ''
    assert len(
        weak_label_crnn_hyper_params_dir
    ) > 0, 'Set weak_label_crnn_hyper_params_dir on the command line.'
    weak_label_crnn_tuning_config = load_json(
        Path(weak_label_crnn_hyper_params_dir) / '1' / 'config.json')
    weak_label_crnn_dirs = weak_label_crnn_tuning_config['crnn_dirs']
    assert len(
        weak_label_crnn_dirs) > 0, 'weak_label_crnn_dirs must not be empty.'
    weak_label_crnn_checkpoints = weak_label_crnn_tuning_config[
        'crnn_checkpoints']
    del weak_label_crnn_tuning_config
    # One group dir (str) or several (list); individual training runs inside
    # are discovered by their timestamp-prefixed dir names ('202*').
    strong_label_crnn_group_dir = ''
    if isinstance(strong_label_crnn_group_dir, list):
        strong_label_crnn_dirs = sorted([
            str(d) for g in strong_label_crnn_group_dir
            for d in Path(g).glob('202*') if d.is_dir()
        ])
    else:
        strong_label_crnn_dirs = sorted([
            str(d) for d in Path(strong_label_crnn_group_dir).glob('202*')
            if d.is_dir()
        ])
    assert len(strong_label_crnn_dirs
               ) > 0, 'strong_label_crnn_dirs must not be empty.'
    strong_label_crnn_checkpoints = 'ckpt_best_macro_fscore_strong.pth'
    # Reuse data-provider settings from the first training run's config.
    strong_crnn_config = load_json(
        Path(strong_label_crnn_dirs[0]) / '1' / 'config.json')
    data_provider = strong_crnn_config['data_provider']
    database_name = strong_crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'strong_label_crnn' / database_name
                      / 'hyper_params' / timestamp)
    # Refuse to overwrite an existing run directory.
    assert not Path(storage_dir).exists()
    del strong_crnn_config
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None
    device = 0  # CUDA device index
    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None
    # Median filter lengths (presumably in frames) to sweep during tuning;
    # a single short filter in debug mode.
    medfilt_lengths = [31] if debug else [
        301, 251, 201, 151, 101, 81, 61, 51, 41, 31, 21, 11
    ]
    ex.observers.append(FileStorageObserver.create(storage_dir))
def config():
    """Configuration for weak-label CRNN hyper-parameter tuning.

    NOTE(review): presumably a Sacred ``@ex.config`` function (decorator not
    visible in this chunk). Locals become config entries; do not rename.
    """
    debug = False
    # Unique run id; '_debug' suffix marks debug runs.
    timestamp = timeStamped('')[1:] + ('_debug' if debug else '')
    # One group dir (str) or several (list); individual training runs inside
    # are discovered by their timestamp-prefixed dir names ('202*').
    group_dir = ''
    if isinstance(group_dir, list):
        crnn_dirs = sorted([
            str(d) for g in group_dir
            for d in Path(g).glob('202*') if d.is_dir()
        ])
    else:
        crnn_dirs = sorted(
            [str(d) for d in Path(group_dir).glob('202*') if d.is_dir()])
    assert len(crnn_dirs) > 0, 'crnn_dirs must not be empty.'
    crnn_checkpoints = 'ckpt_best_macro_fscore_weak.pth'
    # Reuse data-provider settings from the first training run's config.
    crnn_config = load_json(Path(crnn_dirs[0]) / '1' / 'config.json')
    data_provider = crnn_config['data_provider']
    database_name = crnn_config.get('database_name', 'desed')
    storage_dir = str(storage_root / 'weak_label_crnn' / database_name
                      / 'hyper_params' / timestamp)
    # Refuse to overwrite an existing run directory.
    assert not Path(storage_dir).exists()
    del crnn_config
    data_provider['min_audio_length'] = .01
    data_provider['cached_datasets'] = None
    device = 0  # CUDA device index
    validation_set_name = 'validation'
    validation_ground_truth_filepath = None
    eval_set_name = 'eval_public'
    eval_ground_truth_filepath = None
    # Boundary (step) filter lengths to sweep; single value in debug mode.
    boundaries_filter_lengths = [20] if debug else [
        100, 80, 60, 50, 40, 30, 20, 10, 0
    ]
    # Scenario 1: per-class sliding-window detection with median filtering.
    tune_detection_scenario_1 = True
    detection_window_lengths_scenario_1 = [11] if debug else [
        51, 41, 31, 21, 11
    ]
    detection_window_shift_scenario_1 = 1
    detection_medfilt_lengths_scenario_1 = [11] if debug else [
        101, 81, 61, 51, 41, 31, 21, 11
    ]
    # Scenario 2: coarse non-overlapping windows, no median filtering.
    tune_detection_scenario_2 = True
    detection_window_lengths_scenario_2 = [250]
    detection_window_shift_scenario_2 = 250
    detection_medfilt_lengths_scenario_2 = [1]
    ex.observers.append(FileStorageObserver.create(storage_dir))
def tagging(
        crnns, dataset, device, timestamps, event_classes, hyper_params_dir,
        ground_truth, audio_durations, psds_params=(),
        max_segment_length=None, segment_overlap=None,
):
    """Run clip-level audio tagging and (optionally) evaluate it.

    Loads per-class tagging thresholds from
    ``<hyper_params_dir>/tagging_hyper_params_f.json``, computes tagging
    scores with ``base.tagging`` and thresholds them into binary tags.
    If ``ground_truth`` is provided, clip-based f-scores and
    intersection-based PSDS metrics are computed and collected in
    ``results``.

    Args:
        crnns: model ensemble passed through to ``base.tagging``.
        dataset: dataset to tag.
        device: device passed through to ``base.tagging``.
        timestamps: frame timestamps used when converting scores to
            dataframes for evaluation.
        event_classes: ordered list of event class names; defines the
            column/threshold order.
        hyper_params_dir: directory containing the tuned hyper-parameters.
        ground_truth: ground-truth events for evaluation, or None to skip
            evaluation entirely.
        audio_durations: clip durations required by the PSDS computation.
        psds_params: sequence of kwargs dicts, one per PSDS scenario.
        max_segment_length: optional segment length for segment-wise
            processing of long clips.
        segment_overlap: overlap between segments (only meaningful together
            with ``max_segment_length``).

    Returns:
        Tuple ``(tags, tagging_scores, results)`` where ``tags`` maps
        audio_id -> boolean array, ``tagging_scores`` maps audio_id -> score
        array and ``results`` maps metric names -> values (empty without
        ground truth).
    """
    print()
    print('Tagging')
    hyper_params = load_json(
        Path(hyper_params_dir) / 'tagging_hyper_params_f.json')
    # Per-class decision thresholds tuned for the f-score.
    thresholds = {
        event_class: hyper_params[event_class]['threshold']
        for event_class in hyper_params
    }
    # merge_score_segments=False: keep per-segment scores; merging is done
    # explicitly below only when evaluation needs dataframes.
    tagging_scores = base.tagging(
        crnns, dataset, device,
        max_segment_length=max_segment_length,
        segment_overlap=segment_overlap,
        merge_score_segments=False,
    )
    results = {}
    if ground_truth is not None:
        tagging_scores_merged = merge_segments(
            tagging_scores, segment_overlap=0)
        tagging_scores_df = base.scores_to_dataframes(
            tagging_scores_merged,
            timestamps=timestamps,
            event_classes=event_classes,
        )
        # Second check: ground_truth may be non-None but empty/falsy, in
        # which case only the dataframes are built and no metrics computed.
        if ground_truth:
            f, p, r, stats = clip_based.fscore(
                tagging_scores_df, ground_truth, thresholds, num_jobs=8)
            print('f', f)
            print('p', p)
            print('r', r)
            for key in f:
                results.update({
                    f'{key}_f': f[key],
                    f'{key}_p': p[key],
                    f'{key}_r': r[key],
                })
            # One PSDS evaluation per parameterization (scenario).
            for j in range(len(psds_params)):
                psds, psd_roc, classwise_rocs = intersection_based.psds(
                    tagging_scores_df, ground_truth, audio_durations,
                    **psds_params[j],
                    num_jobs=8,
                )
                print(f'psds[{j}]', psds)
                results[f'psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[f'{event_class}_auc[{j}]'] = staircase_auc(
                        tpr, efpr, psds_params[j].get('max_efpr', 100))
                # Reference (approximate) PSDS over a fixed threshold grid,
                # for comparison with the exact computation above.
                psds, psd_roc, classwise_rocs = \
                    intersection_based.reference.approximate_psds(
                        tagging_scores_df, ground_truth, audio_durations,
                        **psds_params[j],
                        thresholds=np.linspace(.01, .99, 50),
                    )
                print(f'approx_psds[{j}]', psds)
                results[f'approx_psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[f'{event_class}_approx_auc[{j}]'] = staircase_auc(
                        tpr, efpr, psds_params[j].get('max_efpr', 100))
    # Vector of thresholds in event_classes order for array comparison.
    thresholds = np.array(
        [thresholds[event_class] for event_class in event_classes])
    # NOTE(review): [0] presumably selects the (single) segment's score
    # array returned when merge_score_segments=False -- confirm against
    # base.tagging's return structure.
    tagging_scores = {
        audio_id: tagging_scores[audio_id][0]
        for audio_id in tagging_scores.keys()
    }
    tags = {
        audio_id: tagging_scores[audio_id] > thresholds
        for audio_id in tagging_scores.keys()
    }
    return tags, tagging_scores, results
def sound_event_detection(
        crnns, dataset, device, timestamps, event_classes, tags,
        hyper_params_dir, hyper_params_name, ground_truth, audio_durations,
        collar_based_params=(), psds_params=(),
        max_segment_length=None, segment_overlap=None,
        pseudo_widening=.0,
        score_storage_dir=None, detection_storage_dir=None,
):
    """Run sound event detection for one or more hyper-parameter sets.

    For every name in ``hyper_params_name`` the tuned per-class detection
    hyper-parameters are loaded from
    ``<hyper_params_dir>/sed_hyper_params_<name>.json``; detection scores
    are computed once (batched over hyper-parameter sets) with
    ``base.sound_event_detection`` and then thresholded/evaluated per set.

    Args:
        crnns: model ensemble passed through to ``base.sound_event_detection``.
        dataset: dataset to run detection on.
        device: device passed through to the detection backend.
        timestamps: frame timestamps; subsampled by ``window_shift`` below.
        event_classes: ordered list of event class names.
        tags: clip-level tag masks (as produced by ``tagging``) used to mask
            detections of untagged classes.
        hyper_params_dir: directory containing the tuned hyper-parameters.
        hyper_params_name: a single name or a list of names of
            hyper-parameter sets to run.
        ground_truth: ground-truth events for evaluation, or falsy to skip
            evaluation.
        audio_durations: clip durations required by PSDS.
        collar_based_params: kwargs for ``collar_based.fscore``; falsy skips
            the collar-based evaluation.
        psds_params: a dict or list of kwargs dicts, one per PSDS scenario.
        max_segment_length: optional segment length for segment-wise
            processing; must be a multiple of the window shift.
        segment_overlap: overlap between segments.
        pseudo_widening: widening applied to detected event boundaries when
            building pseudo labels.
        score_storage_dir: list/str/Path/None; where to store score files
            (one subdir per hyper-parameter set when a single path is given).
        detection_storage_dir: list/str/Path/None; where to store detection
            tsv files (same convention).

    Returns:
        Tuple ``(event_detections, results)``: per hyper-parameter set an
        event-list dict (or None when the set has no thresholds) and a dict
        of metric values.
    """
    print()
    print('Sound Event Detection')
    # Normalize hyper_params_name to a list of names.
    if isinstance(hyper_params_name, (str, Path)):
        hyper_params_name = [hyper_params_name]
    assert isinstance(hyper_params_name, (list, tuple))
    hyper_params = [
        load_json(Path(hyper_params_dir) / f'sed_hyper_params_{name}.json')
        for name in hyper_params_name
    ]
    # Normalize score_storage_dir to a list aligned with hyper_params.
    if isinstance(score_storage_dir, (list, tuple)):
        assert len(score_storage_dir) == len(hyper_params), (
            len(score_storage_dir), len(hyper_params))
    elif isinstance(score_storage_dir, (str, Path)):
        score_storage_dir = [
            Path(score_storage_dir) / name for name in hyper_params_name
        ]
    elif score_storage_dir is not None:
        raise ValueError('score_storage_dir must be list, str, Path or None.')
    # Same normalization for detection_storage_dir.
    if isinstance(detection_storage_dir, (list, tuple)):
        assert len(detection_storage_dir) == len(hyper_params), (
            len(detection_storage_dir), len(hyper_params))
    elif isinstance(detection_storage_dir, (str, Path)):
        detection_storage_dir = [
            Path(detection_storage_dir) / name for name in hyper_params_name
        ]
    elif detection_storage_dir is not None:
        raise ValueError(
            'detection_storage_dir must be list, str, Path or None.')
    # Collect per-(hyper-parameter-set, class) window/filter settings into
    # arrays so the backend can evaluate all sets in one pass.
    window_lengths = np.zeros((len(hyper_params), len(event_classes)))
    medfilt_lengths = np.zeros((len(hyper_params), len(event_classes)))
    tag_masked = np.zeros((len(hyper_params), len(event_classes)))
    window_shift = set()
    for i, hyper_params_i in enumerate(hyper_params):
        for j, event_class in enumerate(event_classes):
            window_lengths[i, j] = hyper_params_i[event_class]['window_length']
            medfilt_lengths[i, j] = \
                hyper_params_i[event_class]['medfilt_length']
            tag_masked[i, j] = hyper_params_i[event_class]['tag_masked']
            window_shift.add(hyper_params_i[event_class]['window_shift'])
    # All classes and sets must agree on a single window shift.
    if not len(window_shift) == 1:
        raise ValueError(
            'Inference with multiple window shifts is not supported.')
    window_shift = list(window_shift)[0]
    if max_segment_length is not None:
        # Segment boundaries must align with the window grid.
        assert max_segment_length % window_shift == 0, (
            max_segment_length, window_shift)
        assert (segment_overlap // 2) % window_shift == 0, (
            segment_overlap, window_shift)
    # NOTE(review): score_segment_overlap below evaluates
    # segment_overlap // window_shift unconditionally -- if segment_overlap
    # is None (the default when max_segment_length is None) this raises a
    # TypeError. Confirm callers always pass an int here, or whether a
    # None-guard got lost.
    detection_scores = base.sound_event_detection(
        crnns, dataset, device,
        model_kwargs={
            'window_length': window_lengths,
            'window_shift': window_shift
        },
        medfilt_length=medfilt_lengths,
        apply_mask=tag_masked,
        masks=tags,
        timestamps=timestamps[::window_shift],
        event_classes=event_classes,
        max_segment_length=max_segment_length,
        segment_overlap=segment_overlap,
        merge_score_segments=True,
        score_segment_overlap=segment_overlap // window_shift,
        score_storage_dir=score_storage_dir,
    )
    event_detections = []
    results = []
    for i, name in enumerate(hyper_params_name):
        if ground_truth:
            print()
            print(name)
        results.append({})
        if detection_storage_dir and detection_storage_dir[i]:
            # Detections over a fixed threshold grid, e.g. for later PSDS
            # computation from files.
            io.write_detections_for_multiple_thresholds(
                detection_scores[i],
                thresholds=np.linspace(.01, .99, 50),
                dir_path=detection_storage_dir[i],
            )
        # Only hyper-parameter sets with tuned thresholds produce discrete
        # event lists (f-score-style sets); PSDS-style sets do not.
        if 'threshold' in hyper_params[i][event_classes[0]]:
            thresholds = {
                event_class: hyper_params[i][event_class]['threshold']
                for event_class in event_classes
            }
            event_detections.append(
                scores_to_event_list(
                    detection_scores[i], thresholds,
                    event_classes=event_classes,
                ))
            if detection_storage_dir and detection_storage_dir[i]:
                io.write_detection(
                    detection_scores[i], thresholds,
                    Path(detection_storage_dir[i]) / 'cbf.tsv',
                )
            if ground_truth and collar_based_params:
                f, p, r, stats = collar_based.fscore(
                    detection_scores[i], ground_truth, thresholds,
                    **collar_based_params,
                    return_onset_offset_dist_sum=True,
                    num_jobs=8,
                )
                print('f', f)
                print('p', p)
                print('r', r)
                for key in f:
                    results[-1].update({
                        f'{key}_f': f[key],
                        f'{key}_p': p[key],
                        f'{key}_r': r[key],
                    })
                    if key in stats:
                        # Mean onset/offset deviation per true positive;
                        # max(..., 1) guards against division by zero.
                        results[-1].update({
                            f'{key}_onset_bias':
                                stats[key]['onset_dist_sum']
                                / max(stats[key]['tps'], 1),
                            f'{key}_offset_bias':
                                stats[key]['offset_dist_sum']
                                / max(stats[key]['tps'], 1),
                        })
            # Widen events and compensate tuned onset/offset biases; drop
            # events that collapse to non-positive length.
            for clip_id in event_detections[-1]:
                events_in_clip = []
                for onset, offset, event_label in event_detections[-1][
                        clip_id]:
                    onset = max(
                        onset - pseudo_widening
                        - hyper_params[i][event_label].get('onset_bias', 0),
                        0)
                    offset = offset + pseudo_widening - hyper_params[i][
                        event_label].get('offset_bias', 0)
                    if offset > onset:
                        events_in_clip.append((onset, offset, event_label))
                event_detections[-1][clip_id] = events_in_clip
        else:
            event_detections.append(None)
        if ground_truth:
            if not isinstance(psds_params, (tuple, list)):
                psds_params = [psds_params]
            for j in range(len(psds_params)):
                psds, psd_roc, classwise_rocs = intersection_based.psds(
                    detection_scores[i], ground_truth, audio_durations,
                    **psds_params[j],
                    num_jobs=8,
                )
                print(f'psds[{j}]', psds)
                results[-1][f'psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[-1][f'{event_class}_auc[{j}]'] = staircase_auc(
                        tpr, efpr, psds_params[j].get('max_efpr', 100))
                if score_storage_dir and score_storage_dir[i] is not None:
                    # Sanity check: PSDS recomputed from the stored score
                    # files should match the in-memory value.
                    psds, psd_roc, classwise_rocs = intersection_based.psds(
                        score_storage_dir[i], ground_truth, audio_durations,
                        **psds_params[j],
                        num_jobs=8,
                    )
                    print(f'psds[{j}] (from files)', psds)
                # Reference (approximate) PSDS over a fixed threshold grid.
                psds, psd_roc, classwise_rocs = \
                    intersection_based.reference.approximate_psds(
                        detection_scores[i], ground_truth, audio_durations,
                        **psds_params[j],
                        thresholds=np.linspace(.01, .99, 50),
                    )
                print(f'approx_psds[{j}]', psds)
                results[-1][f'approx_psds[{j}]'] = psds
                for event_class, (tpr, efpr, *_) in classwise_rocs.items():
                    results[-1][
                        f'{event_class}_approx_auc[{j}]'] = staircase_auc(
                            tpr, efpr, psds_params[j].get('max_efpr', 100))
                if detection_storage_dir and detection_storage_dir[
                        i] is not None:
                    # Same sanity check for detections stored on disk.
                    psds, psd_roc, classwise_rocs = intersection_based.\
                        reference.approximate_psds_from_detections_dir(
                            detection_storage_dir[i], ground_truth,
                            audio_durations,
                            **psds_params[j],
                            thresholds=np.linspace(.01, .99, 50),
                        )
                    print(f'approx_psds[{j}] (from files)', psds)
    return event_detections, results
def boundaries_detection(
        crnns, dataset, device, timestamps, event_classes, tags,
        hyper_params_dir, ground_truth, collar_based_params,
        max_segment_length=None, segment_overlap=None, pseudo_widening=.0,
):
    """Detect event boundaries and (optionally) evaluate them collar-based.

    Loads per-class step-filter lengths and thresholds from
    ``<hyper_params_dir>/boundaries_detection_hyper_params_f.json``, runs
    ``base.boundaries_detection`` masked by the clip-level ``tags`` and
    converts the resulting scores to widened, bias-compensated boundary
    event lists.

    Args:
        crnns: model ensemble passed through to ``base.boundaries_detection``.
        dataset: dataset to run boundary detection on.
        device: device passed through to the detection backend.
        timestamps: frame timestamps for score-to-event conversion.
        event_classes: ordered list of event class names.
        tags: clip-level tag masks used to mask boundary scores.
        hyper_params_dir: directory containing the tuned hyper-parameters.
        ground_truth: ground-truth events; falsy skips evaluation.
        collar_based_params: kwargs forwarded to ``collar_based.fscore``.
        max_segment_length: optional segment length for segment-wise
            processing of long clips.
        segment_overlap: overlap between segments.
        pseudo_widening: widening applied to each boundary before the tuned
            onset/offset bias correction.

    Returns:
        Tuple ``(boundaries_detection, results)``: dict clip_id -> list of
        ``(onset, offset, event_label)`` tuples, and a dict of metric values
        (empty without ground truth).
    """
    print()
    print('Boundaries Detection')
    hyper_params = load_json(
        Path(hyper_params_dir) / f'boundaries_detection_hyper_params_f.json')
    # Per-class step filter lengths in event_classes order.
    stepfilt_length = np.array([
        hyper_params[event_class]['stepfilt_length']
        for event_class in event_classes
    ])
    thresholds = {
        event_class: hyper_params[event_class]['threshold']
        for event_class in event_classes
    }
    boundary_scores = base.boundaries_detection(
        crnns, dataset, device,
        stepfilt_length=stepfilt_length,
        apply_mask=True, masks=tags,
        max_segment_length=max_segment_length,
        segment_overlap=segment_overlap,
        merge_score_segments=True,
        timestamps=timestamps,
        event_classes=event_classes,
    )
    results = {}
    if ground_truth:
        # Boundary ground truth is derived from the event ground truth.
        boundary_ground_truth = base.tuning.boundaries_from_events(
            ground_truth)
        f, p, r, stats = collar_based.fscore(
            boundary_scores, boundary_ground_truth, thresholds,
            **collar_based_params,
            return_onset_offset_dist_sum=True,
            num_jobs=8,
        )
        print('f', f)
        print('p', p)
        print('r', r)
        for key in f:
            results.update({
                f'{key}_f': f[key],
                f'{key}_p': p[key],
                f'{key}_r': r[key],
            })
            if key in stats:
                # Mean onset/offset deviation per true positive;
                # max(..., 1) guards against division by zero.
                results.update({
                    f'{key}_onset_bias':
                        stats[key]['onset_dist_sum']
                        / max(stats[key]['tps'], 1),
                    f'{key}_offset_bias':
                        stats[key]['offset_dist_sum']
                        / max(stats[key]['tps'], 1),
                })
    # NOTE(review): this local deliberately(?) shadows the function name.
    boundaries_detection = scores_to_event_list(
        boundary_scores, thresholds, event_classes=event_classes)
    # Widen boundaries and compensate tuned onset/offset biases (rounded to
    # ms); drop boundaries that collapse to non-positive length.
    for clip_id in boundaries_detection:
        boundaries_in_clip = []
        for onset, offset, event_label in boundaries_detection[clip_id]:
            onset = max(
                np.round(
                    onset - pseudo_widening
                    - hyper_params[event_label]['onset_bias'], 3),
                0)
            offset = np.round(
                offset + pseudo_widening
                - hyper_params[event_label]['offset_bias'], 3)
            if offset > onset:
                boundaries_in_clip.append((onset, offset, event_label))
        boundaries_detection[clip_id] = boundaries_in_clip
    return boundaries_detection, results