def filter_dataset(input_path, output_path, duration_in_hours, cer, seed):
    dataset = transcripts.load(input_path)
    random.seed(seed)
    random.shuffle(dataset)
    print('initial set hours: ',
          sum(transcripts.compute_duration(t, hours=True) for t in dataset),
          'hours')

    if cer:
        dataset = [e for e in dataset if e['cer'] <= cer]
        print('after cer filtering hours: ',
              sum(transcripts.compute_duration(t, hours=True) for t in dataset),
              'hours')

    if duration_in_hours is not None:
        s = []
        set_duration = 0
        while set_duration <= duration_in_hours and len(dataset) > 0:
            t = dataset.pop()
            set_duration += transcripts.compute_duration(t, hours=True)
            s.append(t)
        dataset = s
        print('after duration filtering hours: ',
              sum(transcripts.compute_duration(t, hours=True) for t in dataset),
              'hours')

    print(output_path)
    transcripts.save(output_path, dataset)
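# Hypothetical usage sketch (not part of the original module): cut a shuffled subset of
# roughly 100 hours with segment CER <= 0.3 from a transcript JSON. The paths and values
# below are assumptions, not defaults taken from this repo.
#
# filter_dataset(
#     input_path='data/train.json',
#     output_path='data/train_100h.json',
#     duration_in_hours=100,
#     cer=0.3,
#     seed=42)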
def eval(ref, hyp, html, debug_audio, sample_rate=100):
    if os.path.isfile(ref) and os.path.isfile(hyp):
        print(der(ref_rttm_path=ref, hyp_rttm_path=hyp))

    elif os.path.isdir(ref) and os.path.isdir(hyp):
        errs = []
        diarization_transcript = []
        for rttm in os.listdir(ref):
            if not rttm.endswith('.rttm'):
                continue

            print(rttm)

            audio_path = transcripts.load(
                os.path.join(hyp, rttm).replace('.rttm', '.json'))[0]['audio_path']
            ref_rttm_path, hyp_rttm_path = os.path.join(ref, rttm), os.path.join(hyp, rttm)
            ref_transcript, hyp_transcript = map(transcripts.load, [ref_rttm_path, hyp_rttm_path])

            ser_err, hyp_perm = speaker_error(
                ref=ref_transcript,
                hyp=hyp_transcript,
                num_speakers=2,
                sample_rate=sample_rate,
                ignore_silence_and_overlapped_speech=True)
            der_err, *_ = speaker_error(
                ref=ref_transcript,
                hyp=hyp_transcript,
                num_speakers=2,
                sample_rate=sample_rate,
                ignore_silence_and_overlapped_speech=False)
            der_err_ = der(ref_rttm_path=ref_rttm_path, hyp_rttm_path=hyp_rttm_path)

            transcripts.remap_speaker(hyp_transcript, hyp_perm)

            err = dict(ser=ser_err, der=der_err, der_=der_err_)
            diarization_transcript.append(
                dict(audio_path=audio_path,
                     audio_name=transcripts.audio_name(audio_path),
                     ref=ref_transcript,
                     hyp=hyp_transcript,
                     **err))
            print(rttm, '{ser:.2f}, {der:.2f} | {der_:.2f}'.format(**err))
            print()
            errs.append(err)

        print('===')
        print({
            k: sum(e) / len(e)
            for k in errs[0]
            for e in [[err[k] for err in errs]]
        })

        if html:
            print(
                vis.diarization(
                    sorted(diarization_transcript, key=lambda t: t['ser'], reverse=True),
                    html,
                    debug_audio))
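# Hypothetical usage sketch (not part of the original module): score a directory of
# hypothesis RTTMs against a directory of reference RTTMs and render an HTML report.
# The paths are assumptions; note that eval() expects a matching .json next to each
# hypothesis .rttm, from which it reads the audio_path.
#
# eval(ref='data/ref_rttm', hyp='exp/hyp_rttm', html='exp/diarization.html', debug_audio=False)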
def main(args, ext_json=['.json', '.json.gz']):
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
        'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)

    # collect audio files passed directly or found in input directories
    audio_data_paths = set(
        p for f in args.input_path
        for p in ([os.path.join(f, g) for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext)))
    json_data_paths = set(
        p for p in args.input_path
        if any(map(p.endswith, ext_json)) and not utils.strip_suffixes(p, ext_json) in audio_data_paths)
    data_paths = list(audio_data_paths | json_data_paths)

    # optionally skip inputs that already have a .json in the output directory
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ]) if args.skip_processed else None
    data_paths = [
        path for path in data_paths
        if exclude is None or os.path.basename(path) not in exclude
    ]

    text_pipeline, frontend, model, generator = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths,
        [text_pipeline],
        args.sample_rate,
        frontend=frontend if not args.frontend_in_model else None,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        mode='batched_channels' if args.join_transcript else 'batched_transcript',
        string_array_encoding=args.dataset_string_array_encoding,
        debug_short_long_records_features_from_whole_normalized_signal=args.debug_short_long_records_features_from_whole_normalized_signal)
    print('Examples count: ', len(val_dataset))
    val_meta = val_dataset.pop_meta()
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    csv_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{len(val_dataset)}')

        meta = [val_meta[t['example_id']] for t in meta]
        audio_path = meta[0]['audio_path']
        audio_name = transcripts.audio_name(audio_path)
        begin_end = [dict(begin=t['begin'], end=t['end']) for t in meta]
        begin = torch.tensor([t['begin'] for t in begin_end], dtype=torch.float)
        end = torch.tensor([t['end'] for t in begin_end], dtype=torch.float)

        # TODO WARNING: assumes the frontend is not in the dataset
        if not args.frontend_in_model:
            print('\n' * 10 + 'WARNING\n' * 5)
            print('transcribe.py assumes the frontend is in the model; otherwise time alignment will be incorrect')
            print('WARNING\n' * 5 + '\n')

        duration = x.shape[-1] / args.sample_rate
        channel = [t['channel'] for t in meta]
        speaker = [t['speaker'] for t in meta]
        speaker_name = [t['speaker_name'] for t in meta]

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, logits, olen = model(x.squeeze(1).to(args.device), xlen.to(args.device))

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1], '| target time steps:', y.shape[-1])
            print('Time: audio {audio:.02f} sec | processing {processing:.02f} sec'.format(
                audio=sum(map(transcripts.compute_duration, meta)),
                processing=time.time() - tic))

            ts: shaping.Bt = duration * torch.linspace(
                0, 1, steps=log_probs.shape[-1],
                device=log_probs.device).unsqueeze(0).expand(x.shape[0], -1)

            ref_segments = [[
                dict(channel=channel[i],
                     begin=begin_end[i]['begin'],
                     end=begin_end[i]['end'],
                     ref=text_pipeline.postprocess(text_pipeline.preprocess(meta[i]['ref'])))
            ] for i in range(len(meta))]
            hyp_segments = [
                alternatives[0] for alternatives in generator.generate(
                    tokenizer=text_pipeline.tokenizer,
                    log_probs=log_probs,
                    begin=begin,
                    end=end,
                    output_lengths=olen,
                    time_stamps=ts,
                    segment_text_key='hyp',
                    segment_extra_info=[
                        dict(speaker=s, speaker_name=sn, channel=c)
                        for s, sn, c in zip(speaker, speaker_name, channel)
                    ])
            ]
            hyp_segments = [
                transcripts.map_text(text_pipeline.postprocess, hyp=hyp)
                for hyp in hyp_segments
            ]
            hyp = '\n'.join(transcripts.join(hyp=h) for h in hyp_segments).strip()
            ref = '\n'.join(transcripts.join(ref=r) for r in ref_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                alignment: shaping.BY = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y[:, 0, :],  # assumed that channel 0 holds the char labels
                    olen,
                    ylen[:, 0],
                    blank=text_pipeline.tokenizer.eps_id,
                    pack_backpointers=args.pack_backpointers)
                aligned_ts: shaping.Bt = ts.gather(1, alignment)

                ref_segments = [
                    alternatives[0] for alternatives in generator.generate(
                        tokenizer=text_pipeline.tokenizer,
                        log_probs=torch.nn.functional.one_hot(
                            y[:, 0, :], num_classes=log_probs.shape[1]).permute(0, 2, 1),
                        begin=begin,
                        end=end,
                        output_lengths=ylen,
                        time_stamps=aligned_ts,
                        segment_text_key='ref',
                        segment_extra_info=[
                            dict(speaker=s, speaker_name=sn, channel=c)
                            for s, sn, c in zip(speaker, speaker_name, channel)
                        ])
                ]
                ref_segments = [
                    transcripts.map_text(text_pipeline.postprocess, hyp=ref)
                    for ref in ref_segments
                ]
            oom_handler.reset()
        except:
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {len(val_dataset)}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() - tic_alignment))

        ref_transcript, hyp_transcript = [
            sorted(transcripts.flatten(segments), key=transcripts.sort_key)
            for segments in [ref_segments, hyp_segments]
        ]

        if args.max_segment_duration:
            if ref:
                ref_segments = list(transcripts.segment_by_time(ref_transcript, args.max_segment_duration))
                hyp_segments = list(transcripts.segment_by_ref(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(transcripts.segment_by_time(hyp_transcript, args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        #### HACK for diarization
        elif args.ref_transcript_path and args.join_transcript:
            audio_name_hack = audio_name.split('.')[0]
            # TODO: normalize ref field
            ref_segments = [[t] for t in sorted(
                transcripts.load(os.path.join(args.ref_transcript_path, audio_name_hack + '.json')),
                key=transcripts.sort_key)]
            hyp_segments = list(
                transcripts.segment_by_ref(hyp_transcript, ref_segments, set_speaker=True, soft=False))
        #### END OF HACK

        has_ref = bool(transcripts.join(ref=transcripts.flatten(ref_segments)))

        transcript = []
        for hyp_transcript, ref_transcript in zip(hyp_segments, ref_segments):
            hyp, ref = transcripts.join(hyp=hyp_transcript), transcripts.join(ref=ref_transcript)
            transcript.append(
                dict(audio_path=audio_path,
                     ref=ref,
                     hyp=hyp,
                     speaker_name=transcripts.speaker_name(ref=ref_transcript, hyp=hyp_transcript),
                     words=metrics.align_words(*metrics.align_strings(hyp=hyp, ref=ref)) if args.align_words else [],
                     words_ref=ref_transcript,
                     words_hyp=hyp_transcript,
                     **transcripts.summary(hyp_transcript),
                     **(dict(cer=metrics.cer(hyp=hyp, ref=ref)) if has_ref else {})))
        transcripts.collect_speaker_names(transcript, set_speaker_data=True, num_speakers=2)

        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.prune_cer,
                              duration=args.prune_duration,
                              gap=args.prune_gap,
                              allowed_unk_count=args.prune_unk,
                              num_speakers=args.prune_num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of', len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path, audio_name + '.json')
            print(transcripts.save(transcript_path, filtered_transcript))

        if args.output_html:
            transcript_path = os.path.join(args.output_path, audio_name + '.html')
            print(vis.transcript(transcript_path, args.sample_rate, args.mono, transcript, filtered_transcript))

        if args.output_txt:
            transcript_path = os.path.join(args.output_path, audio_name + '.txt')
            with open(transcript_path, 'w') as f:
                f.write(' '.join(t['hyp'].strip() for t in filtered_transcript))
            print(transcript_path)

        if args.output_csv:
            assert len({t['audio_path'] for t in filtered_transcript}) == 1
            audio_path = filtered_transcript[0]['audio_path']
            hyp = ' '.join(t['hyp'].strip() for t in filtered_transcript)
            begin = min(t['begin'] for t in filtered_transcript)
            end = max(t['end'] for t in filtered_transcript)
            csv_lines.append(csv_sep.join([audio_path, hyp, str(begin), str(end)]))

        if args.logits:
            logits_file_path = os.path.join(args.output_path, audio_name + '.pt')
            if args.logits_crop:
                begin_end = [
                    dict(zip(['begin', 'end'], [
                        t['begin'] + c / float(o) * (t['end'] - t['begin'])
                        for c in args.logits_crop
                    ])) for o, t in zip(olen, begin_end)
                ]
                logits_crop = [slice(*args.logits_crop) for o in olen]
            else:
                logits_crop = [slice(int(o)) for o in olen]

            # TODO: filter ref / hyp by channel?
            torch.save([
                dict(audio_path=audio_path,
                     logits=l[..., logits_crop[i]],
                     **begin_end[i],
                     ref=ref,
                     hyp=hyp)
                for i, l in enumerate(logits.cpu())
            ], logits_file_path)
            print(logits_file_path)

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        transcript_path = os.path.join(args.output_path, 'transcripts.csv')
        with open(transcript_path, 'w') as f:
            f.write('\n'.join(csv_lines))
        print(transcript_path)
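# Hypothetical reader sketch (not part of the original module) for the transcripts.csv
# written above, assuming the tab separator was selected (args.csv_sep == 'tab'); each
# line is audio_path, hyp, begin, end joined by the separator. The path is an assumption.
#
# import csv
# with open('output/transcripts.csv', newline='') as f:
#     for audio_path, hyp, begin, end in csv.reader(f, delimiter='\t'):
#         print(audio_path, float(begin), float(end), hyp[:40])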
def errors(
    input_paths: typing.List[str],
    output_path: typing.Optional[str] = None,
    include_metrics: typing.Sequence[str] = ('cer', 'wer'),
    debug_audio: bool = False,
    filter_fn: typing.Optional[typing.Callable[[typing.Tuple[dict]], bool]] = lambda x: True,
    sort_fn: typing.Optional[typing.Callable[[typing.List[typing.Tuple[dict]]],
                                             typing.List[typing.Tuple[dict]]]] = lambda x: x
) -> str:
    '''
    Parameters:
        input_paths: paths to json files with lists of analyzed examples
        output_path: path to the output html (default: input_paths[0] + '.html')
        debug_audio: include audio data in the html if true
        filter_fn: function to filter tuples of examples grouped by `audio_path`;
            input: a tuple of examples in the same order as `input_paths`,
            output: true to include the examples in the html, false otherwise
        sort_fn: function to sort tuples of examples grouped by `audio_path`;
            input: a list of tuples of examples, each tuple in the same order as `input_paths`,
            output: the same list in sorted order
    '''
    grouped_examples = collections.defaultdict(list)
    examples_count = {}
    for path in input_paths:
        examples = transcripts.load(path)
        examples_count[path] = len(examples)
        for example in examples:
            grouped_examples[example['audio_path']].append(example)

    grouped_examples = list(filter(lambda x: len(x) == len(input_paths), grouped_examples.values()))
    not_found_examples_count = {
        path: count - len(grouped_examples)
        for path, count in examples_count.items()
    }

    grouped_examples = list(filter(filter_fn, grouped_examples))
    filtered_examples_count = {
        path: count - len(grouped_examples) - not_found_examples_count[path]
        for path, count in examples_count.items()
    }

    grouped_examples = sort_fn(grouped_examples)

    style = '''
        .filters_table b.warning {color: red;}
        table.metrics_table {border-collapse: collapse;}
        .metrics_table th {padding: 5px; padding-left: 10px; text-align: left}
        .metrics_table tr {padding: 5px;}
        .metrics_table tr.new_section {border-top: 1px solid black; padding: 5px;}
        .metrics_table td {border-left: 1px dashed black; border-right: 1px dashed black; padding: 5px; padding-left: 10px;}
    '''

    template = '''
        <html>
        <head>
            <meta charset="utf-8">
            <style>
                {style}
            </style>
            <script>
                {scripts}
            </script>
        </head>
        <body>
            <b style="padding: 10px">Filters</b><br><br>
            Dropped (example not found in other files):<br>
            <table class="filters_table">
                {filter_not_found_table}
            </table><br>
            Dropped (filter_fn):
            <table class="filters_table">
                {filter_fn_table}
            </table><br>
            <table class="metrics_table">
                {metrics_table}
            </table>
        </body>
        </html>
    '''

    # Make the filter "not found" table
    def fmt_filter_table(filtered_count: dict) -> str:
        filtered_table = []
        for file_path, count in filtered_count.items():
            css_class = 'warning' if count > 0 else ''
            file_name = os.path.basename(file_path)
            filtered_table.append(f'<tr><td>{file_name}</td><td><b class="{css_class}">{count}</b></td></tr>')
        return '\n'.join(filtered_table)

    filter_not_found_table = fmt_filter_table(not_found_examples_count)

    # Make the filter "filter_fn" table
    filter_fn_table = fmt_filter_table(filtered_examples_count)

    # Make the averages table
    def fmt_averages_table(include_metrics: typing.Sequence[str], averages: dict) -> str:
        header = '<tr><th>Averages</th>' + '<th></th>' * (len(include_metrics) + 2) + '</tr>\n'
        header += '<tr><th></th>' + ''.join(f'<th>{metric_name}</th>' for metric_name in include_metrics) + '<th></th>' * 2 + '</tr>\n'
        content = []
        for i, (file_name, metric_values) in enumerate(averages.items()):
            content_line = f'<td><b>{file_name}</b></td>' + ''.join(
                f'<td>{metric_value:.2%}</td>' for metric_value in metric_values) + '<td></td>' * 2
            if i == 0:
                content_line = '<tr class="new_section">' + content_line + '</tr>'
            else:
                content_line = '<tr>' + content_line + '</tr>'
            content.append(content_line)
        content = '\n'.join(content)
        footer = '<tr class="new_section" style="height: 30px">' + '<th></th>' * (len(include_metrics) + 3) + '</tr>\n'
        return header + content + footer

    averages = {}
    for i, input_file in enumerate(input_paths):
        file_name = os.path.basename(input_file)
        file_examples = [examples[i] for examples in grouped_examples]
        averages[file_name] = [metrics.nanmean(file_examples, metric_name) for metric_name in include_metrics]
    average_table = fmt_averages_table(include_metrics, averages)

    # Make the examples table
    def fmt_examples_table(include_metrics: typing.Sequence[str], table_data: typing.List[dict], debug_audio: bool) -> str:
        header = '<tr><th>Examples</th>' + '<th></th>' * (len(include_metrics) + 2) + '</tr>\n'
        content = []
        for i, examples_data in enumerate(table_data):
            ref = '<pre>' + examples_data['ref'] + '</pre>'
            audio_path = examples_data['audio_path']
            embedded_audio = fmt_audio(audio_path, i) if debug_audio else ''
            examples_header = f'<tr class="new_section"><td colspan="{len(include_metrics) + 1}"><b>{i}.</b>{audio_path}</td><td>{embedded_audio}</td><td>ref: {ref}</td></tr>'
            examples_content = []
            for j, example_data in enumerate(examples_data['examples']):
                metric_values = [f'{value:.2%}' if value is not None else '-' for value in example_data['metric_values']]
                file_name = example_data['file_name']
                alignment = example_data['alignment']
                hyp = '<pre>' + example_data['hyp'] + '</pre>'
                content_line = (f'<td>{file_name}</td>' +
                                ''.join(map('<td>{}</td>'.format, metric_values)) +
                                f'<td>{alignment}</td><td>{hyp}</td>')
                if j == 0:
                    examples_content.append('<tr class="new_section">' + content_line + '</tr>')
                else:
                    examples_content.append('<tr>' + content_line + '</tr>')
            content.append(examples_header)
            content.extend(examples_content)
        return header + '\n'.join(content)

    table_data = []
    for examples in grouped_examples:
        examples_data = dict(
            audio_path=examples[0]['audio_path'],
            ref=examples[0]['ref_orig'],
            examples=[])
        for i, input_file in enumerate(input_paths):
            examples_data['examples'].append(dict(
                file_name=os.path.basename(input_file),
                metric_values=[metrics.extract_metric_value(examples[i], metric_name) for metric_name in include_metrics],
                alignment=fmt_alignment(examples[i]['alignment']),
                hyp=examples[i]['hyp']))
        table_data.append(examples_data)
    examples_table = fmt_examples_table(include_metrics, table_data, debug_audio)

    # Make the output html
    metrics_table = average_table + examples_table
    report = template.format(
        style=style,
        scripts=play_script if debug_audio else '',
        filter_not_found_table=filter_not_found_table,
        filter_fn_table=filter_fn_table,
        metrics_table=metrics_table)
    html_path = output_path or (input_paths[0] + '.html')
    with open(html_path, 'w') as f:
        f.write(report)
    return html_path
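# Hypothetical usage sketch (not part of the original module): compare two result files,
# keep only audio files where the first run's CER is worse than the second's, and sort by
# that CER. The paths are assumptions; a 'cer' field is assumed to be present in the
# loaded examples, as in filter_dataset above.
#
# html_path = errors(
#     ['exp/baseline.json', 'exp/new_model.json'],
#     output_path='exp/errors.html',
#     include_metrics=['cer', 'wer'],
#     filter_fn=lambda examples: examples[0]['cer'] > examples[1]['cer'],
#     sort_fn=lambda groups: sorted(groups, key=lambda examples: examples[0]['cer'], reverse=True))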
def __init__(
    self,
    data_paths: typing.List[str],
    text_pipelines: typing.List[text_processing.ProcessingPipeline],
    sample_rate: int,
    mode: str = DEFAULT_MODE,
    frontend: typing.Optional[torch.nn.Module] = None,
    speaker_names: typing.Optional[typing.List[str]] = None,
    max_audio_file_size: typing.Optional[float] = None,  # bytes
    min_duration: typing.Optional[float] = None,
    max_duration: typing.Optional[float] = None,
    max_num_channels: int = 2,
    mono: bool = True,
    audio_dtype: str = 'float32',
    time_padding_multiple: int = 1,
    audio_backend: typing.Optional[str] = None,
    exclude: typing.Optional[typing.Set] = None,
    bucket_fn: typing.Callable[[typing.List[typing.Dict]], int] = lambda transcript: 0,
    pop_meta: bool = False,
    string_array_encoding: str = 'utf_16_le',
    _print: typing.Callable = print,
    debug_short_long_records_features_from_whole_normalized_signal: bool = False,
    duration_from_transcripts: bool = False,
):
    self.debug_short_long_records_features_from_whole_normalized_signal = debug_short_long_records_features_from_whole_normalized_signal
    self.mode = mode
    self.min_duration = min_duration
    self.max_duration = max_duration
    self.max_audio_file_size = max_audio_file_size
    self.text_pipelines = text_pipelines
    self.frontend = frontend
    self.sample_rate = sample_rate
    self.time_padding_multiple = time_padding_multiple
    self.mono = mono
    self.audio_backend = audio_backend
    self.audio_dtype = audio_dtype

    # expand directories into the audio files they contain
    data_paths = data_paths if isinstance(data_paths, list) else [data_paths]
    data_paths_ = []
    for data_path in data_paths:
        if os.path.isdir(data_path):
            data_paths_.extend(
                os.path.join(data_path, filename)
                for filename in filter(audio.is_audio, os.listdir(data_path)))
        else:
            data_paths_.append(data_path)
    data_paths = data_paths_

    tic = time.time()

    # audio files become stub segments (one per channel unless mono), other paths are loaded as transcripts
    segments = []
    for path in data_paths:
        if audio.is_audio(path):
            assert self.mono or self.mode != AudioTextDataset.DEFAULT_MODE, \
                'Only mono audio files allowed as dataset input in default mode'
            if self.mono:
                transcript = [dict(audio_path=path, channel=transcripts.channel_missing)]
            else:
                transcript = [dict(audio_path=path, channel=c) for c in range(max_num_channels)]
        else:
            transcript = transcripts.load(path)
        segments.extend(transcript)

    _print('Dataset reading time: ', time.time() - tic)
    tic = time.time()

    # get_or_else is required because the dictionary may contain None values that we want to replace;
    # dict.get doesn't work in this case
    get_or_else = lambda dictionary, key, default: dictionary[key] if dictionary.get(key) is not None else default

    for t in segments:
        t['ref'] = get_or_else(t, 'ref', transcripts.ref_missing)
        t['begin'] = get_or_else(t, 'begin', transcripts.time_missing)
        t['end'] = get_or_else(t, 'end', transcripts.time_missing)
        t['channel'] = get_or_else(t, 'channel', transcripts.channel_missing) if not self.mono else transcripts.channel_missing

    transcripts.collect_speaker_names(segments,
                                      speaker_names=speaker_names or [],
                                      num_speakers=max_num_channels,
                                      set_speaker_data=True)

    buckets = []
    grouped_segments = []
    transcripts_len = []
    speakers_len = []

    if self.mode == AudioTextDataset.DEFAULT_MODE:
        grouped_transcripts = ((i, [t]) for i, t in enumerate(segments))
    else:
        grouped_transcripts = itertools.groupby(
            sorted(segments, key=transcripts.group_key), transcripts.group_key)

    for group_key, transcript in grouped_transcripts:
        transcript = sorted(transcript, key=transcripts.sort_key)

        if self.mode == AudioTextDataset.BATCHED_CHANNELS_MODE:
            transcript = transcripts.join_transcript(
                transcript, self.mono, duration_from_transcripts=duration_from_transcripts)

        if exclude is not None:
            allowed_audio_names = set(
                transcripts.audio_name(t) for t in transcript
                if transcripts.audio_name(t) not in exclude)
        else:
            allowed_audio_names = None

        transcript = transcripts.prune(
            transcript,
            allowed_audio_names=allowed_audio_names,
            duration=(
                min_duration if min_duration is not None else 0.0,
                max_duration if max_duration is not None else 24.0 * 3600,  # 24h
            ),
            max_audio_file_size=max_audio_file_size)
        transcript = list(transcript)

        for t in transcript:
            t['example_id'] = AudioTextDataset.get_example_id(t)

        if len(transcript) == 0:
            continue

        bucket = bucket_fn(transcript)
        for t in transcript:
            t['bucket'] = bucket
            speakers_len.append(len(t['speaker']) if isinstance(t['speaker'], list) else 1)
        buckets.append(bucket)
        grouped_segments.extend(transcript)
        transcripts_len.append(len(transcript))

    _print('Dataset construction time: ', time.time() - tic)
    tic = time.time()

    self.bucket = torch.tensor(buckets, dtype=torch.short)
    self.audio_path = utils.TensorBackedStringArray(
        [t['audio_path'] for t in grouped_segments], encoding=string_array_encoding)
    self.ref = utils.TensorBackedStringArray(
        [t['ref'] for t in grouped_segments], encoding=string_array_encoding)
    self.begin = torch.tensor([t['begin'] for t in grouped_segments], dtype=torch.float64)
    self.end = torch.tensor([t['end'] for t in grouped_segments], dtype=torch.float64)
    self.channel = torch.tensor([t['channel'] for t in grouped_segments], dtype=torch.int8)
    self.example_id = utils.TensorBackedStringArray(
        [t['example_id'] for t in grouped_segments], encoding=string_array_encoding)
    if self.mode == AudioTextDataset.BATCHED_CHANNELS_MODE:
        self.speaker = torch.tensor(
            [speaker for t in grouped_segments for speaker in t['speaker']], dtype=torch.int64)
    else:
        self.speaker = torch.tensor([t['speaker'] for t in grouped_segments], dtype=torch.int64)
    self.speaker_len = torch.tensor(speakers_len, dtype=torch.int16)
    self.transcript_cumlen = torch.tensor(transcripts_len, dtype=torch.int16).cumsum(dim=0, dtype=torch.int64)
    self.meta = {} if pop_meta else {t['example_id']: t for t in grouped_segments}
    _print('Dataset tensors creation time: ', time.time() - tic)
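# Hypothetical construction sketch (not part of the original module), mirroring how main()
# in transcribe.py wires the dataset into a DataLoader. The paths, sample rate, and the
# text_pipeline object are assumptions.
#
# val_dataset = AudioTextDataset(
#     ['data/sample.wav', 'data/val.json'],
#     [text_pipeline],
#     sample_rate=8000,
#     mono=True,
#     mode='batched_transcript')
# val_meta = val_dataset.pop_meta()
# loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=None, collate_fn=val_dataset.collate_fn, num_workers=0)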