Example #1
0
def filter_dataset(input_path, output_path, duration_in_hours, cer, seed):
    '''Shuffle a transcript dataset, optionally filter it by CER and total duration, and save it.

    Parameters:
        input_path: path read with `transcripts.load`
        output_path: path written with `transcripts.save`
        duration_in_hours: if not None, examples are taken from the shuffled
            list until their cumulative duration exceeds this many hours; the
            example that crosses the threshold is kept, so the result may
            slightly exceed the budget
        cer: if not None, examples whose 'cer' field is greater than this
            value are dropped (cer=0 keeps only examples with cer <= 0)
        seed: seed applied to the global `random` module before shuffling
    '''

    def total_hours(examples):
        # total duration of `examples`, in hours
        return sum(
            transcripts.compute_duration(t, hours=True) for t in examples)

    dataset = transcripts.load(input_path)

    # Seeds the module-level RNG (kept so callers observe the same global
    # random state as before).
    random.seed(seed)
    random.shuffle(dataset)

    print('initial set hours: ', total_hours(dataset), 'hours')

    # `is not None` instead of truthiness so that a threshold of 0 still
    # filters, consistent with the duration_in_hours check below.
    if cer is not None:
        dataset = [e for e in dataset if e['cer'] <= cer]
        print('after cer filtering hours: ', total_hours(dataset), 'hours')

    if duration_in_hours is not None:
        selected = []
        selected_hours = 0
        # pop() takes from the end of the shuffled list until the budget is
        # exceeded or the dataset is exhausted
        while selected_hours <= duration_in_hours and dataset:
            t = dataset.pop()
            selected_hours += transcripts.compute_duration(t, hours=True)
            selected.append(t)
        dataset = selected

    print('after duration filtering hours: ', total_hours(dataset), 'hours')
    print(output_path)
    transcripts.save(output_path, dataset)
Example #2
0
def eval(ref, hyp, html, debug_audio, sample_rate=100):
    '''Evaluate diarization of `hyp` against `ref`.

    Parameters:
        ref: reference RTTM file, or a directory of `*.rttm` files
        hyp: hypothesis RTTM file, or a directory holding `*.rttm` files with
             the same names as in `ref`; for every `<name>.rttm` it must also
             hold `<name>.json`, whose first entry provides `audio_path`
        html: directory mode only; if truthy, `vis.diarization` is called with
              it (examples sorted by descending `ser`) and the result printed
        debug_audio: forwarded to `vis.diarization`
        sample_rate: forwarded to `speaker_error`

    File mode prints the `der` result. Directory mode prints ser / der / der_
    per file and then their mean over all files (an empty dict when no
    `.rttm` files are found). If `ref` and `hyp` are neither both files nor
    both directories, nothing is done.
    '''
    if os.path.isfile(ref) and os.path.isfile(hyp):
        print(der(ref_rttm_path=ref, hyp_rttm_path=hyp))

    elif os.path.isdir(ref) and os.path.isdir(hyp):
        errs = []
        diarization_transcript = []
        for rttm in os.listdir(ref):
            if not rttm.endswith('.rttm'):
                continue

            print(rttm)
            ref_rttm_path = os.path.join(ref, rttm)
            hyp_rttm_path = os.path.join(hyp, rttm)
            # Swap only the file suffix: str.replace on the joined path would
            # also rewrite any '.rttm' occurring in directory names.
            hyp_json_path = os.path.join(hyp, rttm[:-len('.rttm')] + '.json')
            audio_path = transcripts.load(hyp_json_path)[0]['audio_path']

            ref_transcript, hyp_transcript = map(
                transcripts.load, [ref_rttm_path, hyp_rttm_path])
            # speaker error ignoring silence / overlapped speech; also yields
            # the hypothesis speaker permutation used for remapping below
            ser_err, hyp_perm = speaker_error(
                ref=ref_transcript,
                hyp=hyp_transcript,
                num_speakers=2,
                sample_rate=sample_rate,
                ignore_silence_and_overlapped_speech=True)
            der_err, *_ = speaker_error(
                ref=ref_transcript,
                hyp=hyp_transcript,
                num_speakers=2,
                sample_rate=sample_rate,
                ignore_silence_and_overlapped_speech=False)
            der_err_ = der(ref_rttm_path=ref_rttm_path,
                           hyp_rttm_path=hyp_rttm_path)
            transcripts.remap_speaker(hyp_transcript, hyp_perm)

            err = dict(ser=ser_err, der=der_err, der_=der_err_)
            diarization_transcript.append(
                dict(audio_path=audio_path,
                     audio_name=transcripts.audio_name(audio_path),
                     ref=ref_transcript,
                     hyp=hyp_transcript,
                     **err))
            print(rttm, '{ser:.2f}, {der:.2f} | {der_:.2f}'.format(**err))
            print()
            errs.append(err)
        print('===')
        # mean of each error kind over all evaluated files; guarded so an
        # empty directory prints {} instead of raising IndexError
        print({
            k: sum(err[k] for err in errs) / len(errs)
            for k in (errs[0] if errs else {})
        })

        if html:
            print(
                vis.diarization(
                    sorted(diarization_transcript,
                           key=lambda t: t['ser'],
                           reverse=True), html, debug_audio))
Example #3
0
def main(args, ext_json=['.json', '.json.gz']):
    '''Transcribe the inputs in `args.input_path` with the model from `setup(args)` and write the requested outputs.

    Inputs may be audio files, directories (their entries ending with one of
    `args.ext`), or transcript files ending with one of `ext_json` whose audio
    path was not already collected. For every batch of
    `datasets.AudioTextDataset` the model output is decoded with `generator`,
    optionally CTC-aligned against the reference (`args.align`), re-segmented,
    pruned with `transcripts.prune`, and written into `args.output_path` as
    `<audio_name>.json` / `.html` / `.txt` / `.pt` according to the
    `args.output_*` and `args.logits` flags. With `args.output_csv`, one line
    per processed record is collected and written to `transcripts.csv` at the
    end.

    Parameters:
        args: argparse-style namespace with all options read below
        ext_json: suffixes identifying transcript inputs; only read, never
                  mutated, so the list default is not shared mutable state
    '''
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)

    # Audio inputs: listed files plus the entries of listed directories,
    # keeping only existing files whose name ends with one of `args.ext`.
    audio_data_paths = set(
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext)))
    # Transcript inputs: listed json files whose path with the json suffix
    # stripped is not already one of the audio inputs above.
    json_data_paths = set(
        p for p in args.input_path if any(map(p.endswith, ext_json))
        and not utils.strip_suffixes(p, ext_json) in audio_data_paths)

    data_paths = list(audio_data_paths | json_data_paths)

    # With args.skip_processed: names of records that already have a .json
    # output in args.output_path.
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ]) if args.skip_processed else None

    data_paths = [
        path for path in data_paths
        if exclude is None or os.path.basename(path) not in exclude
    ]

    text_pipeline, frontend, model, generator = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [text_pipeline],
        args.sample_rate,
        frontend=frontend if not args.frontend_in_model else None,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        mode='batched_channels'
        if args.join_transcript else 'batched_transcript',
        string_array_encoding=args.dataset_string_array_encoding,
        debug_short_long_records_features_from_whole_normalized_signal=args.
        debug_short_long_records_features_from_whole_normalized_signal)
    print('Examples count: ', len(val_dataset))
    val_meta = val_dataset.pop_meta()
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    csv_lines = []  # only used if args.output_csv is True

    #TODO WARNING assumes frontend not in dataset
    # The condition depends only on args, so the banner is printed once here
    # instead of once per batch.
    if not args.frontend_in_model:
        print('\n' * 10 + 'WARNING\n' * 5)
        print(
            'transcribe.py assumes frontend in model, in other case time alignment was incorrect'
        )
        print('WARNING\n' * 5 + '\n')

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{len(val_dataset)}')
        # replace the loader's per-example meta with the full records kept in
        # val_meta, looked up by example_id
        meta = [val_meta[t['example_id']] for t in meta]

        audio_path = meta[0]['audio_path']
        audio_name = transcripts.audio_name(audio_path)
        begin_end = [dict(begin=t['begin'], end=t['end']) for t in meta]
        begin = torch.tensor([t['begin'] for t in begin_end],
                             dtype=torch.float)
        end = torch.tensor([t['end'] for t in begin_end], dtype=torch.float)

        duration = x.shape[-1] / args.sample_rate
        channel = [t['channel'] for t in meta]
        speaker = [t['speaker'] for t in meta]
        speaker_name = [t['speaker_name'] for t in meta]

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, logits, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(map(transcripts.compute_duration, meta)),
                        processing=time.time() - tic))

            # evenly spaced timestamps over [0, duration], one per output
            # frame, repeated for every item of the batch
            ts: shaping.Bt = duration * torch.linspace(
                0, 1, steps=log_probs.shape[-1],
                device=log_probs.device).unsqueeze(0).expand(x.shape[0], -1)

            ref_segments = [[
                dict(channel=channel[i],
                     begin=begin_end[i]['begin'],
                     end=begin_end[i]['end'],
                     ref=text_pipeline.postprocess(
                         text_pipeline.preprocess(meta[i]['ref'])))
            ] for i in range(len(meta))]
            hyp_segments = [
                alternatives[0] for alternatives in generator.generate(
                    tokenizer=text_pipeline.tokenizer,
                    log_probs=log_probs,
                    begin=begin,
                    end=end,
                    output_lengths=olen,
                    time_stamps=ts,
                    segment_text_key='hyp',
                    segment_extra_info=[
                        dict(speaker=s, speaker_name=sn, channel=c)
                        for s, sn, c in zip(speaker, speaker_name, channel)
                    ])
            ]
            hyp_segments = [
                transcripts.map_text(text_pipeline.postprocess, hyp=hyp)
                for hyp in hyp_segments
            ]
            hyp, ref = '\n'.join(
                transcripts.join(hyp=h)
                for h in hyp_segments).strip(), '\n'.join(
                    transcripts.join(ref=r) for r in ref_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                alignment: shaping.BY = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y[:, 0, :],  # assumed that 0 channel is char labels
                    olen,
                    ylen[:, 0],
                    blank=text_pipeline.tokenizer.eps_id,
                    pack_backpointers=args.pack_backpointers)
                aligned_ts: shaping.Bt = ts.gather(1, alignment)

                # re-decode the reference labels (as one-hot "log probs") with
                # the aligned timestamps to obtain time-stamped ref segments
                ref_segments = [
                    alternatives[0] for alternatives in generator.generate(
                        tokenizer=text_pipeline.tokenizer,
                        log_probs=torch.nn.functional.one_hot(
                            y[:,
                              0, :], num_classes=log_probs.shape[1]).permute(
                                  0, 2, 1),
                        begin=begin,
                        end=end,
                        output_lengths=ylen,
                        time_stamps=aligned_ts,
                        segment_text_key='ref',
                        segment_extra_info=[
                            dict(speaker=s, speaker_name=sn, channel=c)
                            for s, sn, c in zip(speaker, speaker_name, channel)
                        ])
                ]
                ref_segments = [
                    transcripts.map_text(text_pipeline.postprocess, hyp=ref)
                    for ref in ref_segments
                ]
            oom_handler.reset()
        except Exception:
            # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit
            # propagate instead of being handed to the OOM recovery logic.
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {len(val_dataset)}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        ref_transcript, hyp_transcript = [
            sorted(transcripts.flatten(segments), key=transcripts.sort_key)
            for segments in [ref_segments, hyp_segments]
        ]

        if args.max_segment_duration:
            if ref:
                ref_segments = list(
                    transcripts.segment_by_time(ref_transcript,
                                                args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment_by_ref(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment_by_time(hyp_transcript,
                                                args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        #### HACK for diarization
        elif args.ref_transcript_path and args.join_transcript:
            audio_name_hack = audio_name.split('.')[0]
            #TODO: normalize ref field
            ref_segments = [[t] for t in sorted(transcripts.load(
                os.path.join(args.ref_transcript_path, audio_name_hack +
                             '.json')),
                                                key=transcripts.sort_key)]
            hyp_segments = list(
                transcripts.segment_by_ref(hyp_transcript,
                                           ref_segments,
                                           set_speaker=True,
                                           soft=False))
        #### END OF HACK

        has_ref = bool(transcripts.join(ref=transcripts.flatten(ref_segments)))

        transcript = []
        for hyp_transcript, ref_transcript in zip(hyp_segments, ref_segments):
            hyp, ref = transcripts.join(hyp=hyp_transcript), transcripts.join(
                ref=ref_transcript)

            transcript.append(
                dict(audio_path=audio_path,
                     ref=ref,
                     hyp=hyp,
                     speaker_name=transcripts.speaker_name(ref=ref_transcript,
                                                           hyp=hyp_transcript),
                     words=metrics.align_words(
                         *metrics.align_strings(hyp=hyp, ref=ref))
                     if args.align_words else [],
                     words_ref=ref_transcript,
                     words_hyp=hyp_transcript,
                     **transcripts.summary(hyp_transcript),
                     **(dict(cer=metrics.cer(hyp=hyp, ref=ref))
                        if has_ref else {})))

        transcripts.collect_speaker_names(transcript,
                                          set_speaker_data=True,
                                          num_speakers=2)

        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.prune_cer,
                              duration=args.prune_duration,
                              gap=args.prune_gap,
                              allowed_unk_count=args.prune_unk,
                              num_speakers=args.prune_num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcripts.save(transcript_path, filtered_transcript))

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(
                vis.transcript(transcript_path, args.sample_rate, args.mono,
                               transcript, filtered_transcript))

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            with open(transcript_path, 'w') as f:
                f.write(' '.join(t['hyp'].strip()
                                 for t in filtered_transcript))
            print(transcript_path)

        if args.output_csv:
            # prune() may drop every segment; the line below needs at least
            # one (filtered_transcript[0], min/max), so skip such records
            # instead of aborting the whole run
            if filtered_transcript:
                assert len({t['audio_path'] for t in filtered_transcript}) == 1
                audio_path = filtered_transcript[0]['audio_path']
                hyp = ' '.join(t['hyp'].strip() for t in filtered_transcript)
                begin = min(t['begin'] for t in filtered_transcript)
                end = max(t['end'] for t in filtered_transcript)
                csv_lines.append(
                    csv_sep.join([audio_path, hyp,
                                  str(begin),
                                  str(end)]))
            else:
                print(f'No segments left after pruning, skipping CSV line for [{audio_path}].')

        if args.logits:
            logits_file_path = os.path.join(args.output_path,
                                            audio_name + '.pt')
            if args.logits_crop:
                begin_end = [
                    dict(
                        zip(['begin', 'end'], [
                            t['begin'] + c / float(o) * (t['end'] - t['begin'])
                            for c in args.logits_crop
                        ])) for o, t in zip(olen, begin_end)
                ]
                logits_crop = [slice(*args.logits_crop) for o in olen]
            else:
                logits_crop = [slice(int(o)) for o in olen]

            # TODO: filter ref / hyp by channel?
            # NOTE(review): `ref` / `hyp` hold whatever was assigned last above
            # (the last segment of the transcript loop, or the joined CSV hyp
            # when args.output_csv is set), not the whole record — confirm.
            torch.save([
                dict(audio_path=audio_path,
                     logits=l[..., logits_crop[i]],
                     **begin_end[i],
                     ref=ref,
                     hyp=hyp) for i, l in enumerate(logits.cpu())
            ], logits_file_path)
            print(logits_file_path)

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        transcript_path = os.path.join(args.output_path, 'transcripts.csv')
        with open(transcript_path, 'w') as f:
            f.write('\n'.join(csv_lines))
        print(transcript_path)
Example #4
0
def errors(
	input_paths: typing.List[str],
	output_path: typing.Optional[str] = None,
	include_metrics: typing.List[str] = ('cer', 'wer',),
	debug_audio: bool = False,
	filter_fn: typing.Optional[typing.Callable[[typing.Tuple[dict]], bool]] = lambda x: True,
	sort_fn: typing.Optional[typing.Callable[[typing.List[typing.Tuple[dict]]], typing.List[typing.Tuple[dict]]]] = lambda x: x
) -> str:
	'''
	Write an HTML report comparing per-example metrics of several analyzed-example files.

	Examples are grouped by `audio_path`; a group is kept only when its size
	equals len(input_paths) (i.e. the example is found in every file). The
	report lists how many examples were dropped per file, per-file metric
	averages, and every kept example with its metrics, alignment and hyp.

	Parameters:
		input_paths: paths to json files with list of analyzed examples
		output_path: path to output html (default: input_path[0]+.html)
		include_metrics: names of metrics shown in the averages and per example
		debug_audio: include audio data into html if true
		filter_fn: function to filter tuples of examples grouped by `audio_path`,
				   function input: tuple of examples in order same as in `input_paths`
				   function output: true to include examples into html, false otherwise
				   None keeps every group
		sort_fn: function to sort tuples of examples grouped by `audio_path`,
				 function input: list of tuples of examples, each tuple has same order as in `input_paths`
				 function output: same list but in sorted order
				 None keeps the original order
	Returns:
		path of the written html file
	'''
	import html

	grouped_examples = collections.defaultdict(list)
	examples_count = {}
	for path in input_paths:
		examples = transcripts.load(path)
		examples_count[path] = len(examples)
		for example in examples:
			grouped_examples[example['audio_path']].append(example)
	# NOTE(review): a file containing the same audio_path twice can make an
	# incomplete group look complete here.
	grouped_examples = list(filter(lambda x: len(x) == len(input_paths), grouped_examples.values()))
	not_found_examples_count = {path: count - len(grouped_examples) for path, count in examples_count.items()}
	if filter_fn is not None:
		grouped_examples = list(filter(filter_fn, grouped_examples))
	filtered_examples_count = {path: count - len(grouped_examples) - not_found_examples_count[path] for path, count in examples_count.items()}
	if sort_fn is not None:
		grouped_examples = sort_fn(grouped_examples)
	style = '''
				.filters_table b.warning {color: red;}
		        table.metrics_table {border-collapse:collapse;}
		        .metrics_table th {padding: 5px; padding-left: 10px; text-align: left}
		        .metrics_table tr {padding: 5px;}
		        .metrics_table tr.new_section {border-top: 1px solid black; padding: 5px;}
		        .metrics_table td {border-left: 1px dashed black; border-right: 1px dashed black; padding: 5px; padding-left: 10px;}   
	'''

	template = '''
		<html>
		<head>
		    <meta charset="utf-8">
		    <style>
		        {style}
		    </style>
		    <script>
		        {scripts}
		    </script>
		</head>
		<body>
			<b style="padding: 10px">Filters</b><br><br>
            Dropped (example not found in other files):<br>
            <table class="filters_table">
		        {filter_not_found_table}
		    </table><br>
		    Dropped (filter_fn):
		    <table class="filters_table">
		        {filter_fn_table}
		    </table><br>
		    <table class="metrics_table">
		        {metrics_table}
		    </table>
		</body>
		</html>
	'''

	# Rows of "file name | dropped count", highlighting non-zero counts
	def fmt_filter_table(filtered_count: dict) -> str:
		filtered_table = []
		for file_path, count in filtered_count.items():
			css_class = 'warning' if count > 0 else ''
			file_name = os.path.basename(file_path)
			filtered_table.append(f'<tr><td>{file_name}</td><td><b class="{css_class}">{count}</b></td></tr>')
		return '\n'.join(filtered_table)

	# Make filter "not found" table
	filter_not_found_table = fmt_filter_table(not_found_examples_count)

	# Make filter "filter_fn" table
	filter_fn_table = fmt_filter_table(filtered_examples_count)

	# Make averages table
	def fmt_averages_table(include_metrics: typing.List[str], averages: dict) -> str:
		header = '<tr><th>Averages</th>' + '<th></th>' * (len(include_metrics) + 2) + '</tr>\n'
		header += '<tr><th></th>' + ''.join(f'<th>{metric_name}</th>' for metric_name in include_metrics) + '<th></th>' * 2 + '</tr>\n'
		content = []
		for i, (file_name, metric_values) in enumerate(averages.items()):
			content_line = f'<td><b>{file_name}</b></td>' + ''.join(
				f'<td>{metric_value:.2%}</td>' for metric_value in metric_values) + '<td></td>' * 2
			if i == 0:
				content_line = '<tr class="new_section">' + content_line + '</tr>'
			else:
				content_line = '<tr>' + content_line + '</tr>'
			content.append(content_line)
		content = '\n'.join(content)
		footer = '<tr class="new_section" style="height: 30px">' + '<th></th>' * (len(include_metrics) + 3) + '</tr>\n'
		return header + content + footer

	averages = {}
	for i, input_file in enumerate(input_paths):
		file_name = os.path.basename(input_file)
		file_examples = [examples[i] for examples in grouped_examples]
		averages[file_name] = [metrics.nanmean(file_examples, metric_name) for metric_name in include_metrics]
	average_table = fmt_averages_table(include_metrics, averages)

	# Make examples table; transcript text is HTML-escaped so characters like
	# '<' or '&' in ref / hyp cannot break the markup
	def fmt_examples_table(include_metrics: typing.List[str], table_data: typing.List[dict], debug_audio: bool) -> str:
		header = '<tr><th>Examples</th>' + '<th></th>' * (len(include_metrics) + 2) + '</tr>\n'
		content = []
		for example_index, examples_data in enumerate(table_data):
			ref = html.escape(examples_data['ref'])
			audio_path = examples_data['audio_path']
			embedded_audio = fmt_audio(audio_path, example_index) if debug_audio else ''
			examples_header = f'<tr class="new_section"><td colspan="{len(include_metrics)+1}"><b>{example_index}.</b>{html.escape(audio_path)}</td><td>{embedded_audio}</td><td>ref: <pre>{ref}</pre></td></tr>'
			examples_content = []
			# separate index name so the outer example index is not shadowed
			for row_index, example_data in enumerate(examples_data['examples']):
				metric_values = [f'{value:.2%}' if value is not None else '-' for value in example_data['metric_values']]
				file_name = example_data['file_name']
				alignment = example_data['alignment']
				hyp = '<pre>' + html.escape(example_data['hyp']) + '</pre>'
				content_line = (f'<td>{file_name}</td>' + ''.join(map('<td>{}</td>'.format, metric_values)) + f'<td>{alignment}</td><td>{hyp}</td>')
				if row_index == 0:
					examples_content.append('<tr class="new_section">' + content_line + '</tr>')
				else:
					examples_content.append('<tr>' + content_line + '</tr>')
			content.append(examples_header)
			content.extend(examples_content)
		return header + '\n'.join(content)

	table_data = []
	for examples in grouped_examples:
		examples_data = dict(
			audio_path = examples[0]['audio_path'],
			ref = examples[0]['ref_orig'],
			examples = [])
		for i, input_file in enumerate(input_paths):
			examples_data['examples'].append(dict(
				file_name = os.path.basename(input_file),
				metric_values = [metrics.extract_metric_value(examples[i], metric_name) for metric_name in include_metrics],
				alignment = fmt_alignment(examples[i]['alignment']),
				hyp = examples[i]["hyp"]))
		table_data.append(examples_data)

	examples_data = fmt_examples_table(include_metrics, table_data, debug_audio)

	# make output html
	metrics_table = average_table + examples_data
	report = template.format(style = style,
	                         scripts = play_script if debug_audio else '',
	                         filter_not_found_table = filter_not_found_table,
	                         filter_fn_table = filter_fn_table,
	                         metrics_table = metrics_table)
	html_path = output_path or (input_paths[0] + '.html')
	# the page declares utf-8, so write it as utf-8 regardless of the platform
	# default encoding; `with` closes the handle deterministically
	with open(html_path, 'w', encoding='utf-8') as f:
		f.write(report)
	return html_path
Example #5
0
    def __init__(
        self,
        data_paths: typing.List[str],
        text_pipelines: typing.List[text_processing.ProcessingPipeline],
        sample_rate: int,
        mode: str = DEFAULT_MODE,
        frontend: typing.Optional[torch.nn.Module] = None,
        speaker_names: typing.Optional[typing.List[str]] = None,
        max_audio_file_size: typing.Optional[float] = None,  #bytes
        min_duration: typing.Optional[float] = None,
        max_duration: typing.Optional[float] = None,
        max_num_channels: int = 2,
        mono: bool = True,
        audio_dtype: str = 'float32',
        time_padding_multiple: int = 1,
        audio_backend: typing.Optional[str] = None,
        exclude: typing.Optional[typing.Set] = None,
        bucket_fn: typing.Callable[[typing.List[typing.Dict]],
                                   int] = lambda transcript: 0,
        pop_meta: bool = False,
        string_array_encoding: str = 'utf_16_le',
        _print: typing.Callable = print,
        debug_short_long_records_features_from_whole_normalized_signal:
        bool = False,
        duration_from_transcripts: bool = False,
    ):
        '''Build the dataset index from transcript files, audio files or directories of audio files.

        Segment metadata is loaded, missing fields are filled with the
        `transcripts.*_missing` sentinels, segments are grouped according to
        `mode`, pruned, and finally stored column-wise in torch tensors and
        `utils.TensorBackedStringArray`s.

        Parameters (those whose effect is visible in this method):
            data_paths: one path or a list of paths; directories are expanded
                to the entries accepted by `audio.is_audio`; an audio file
                becomes one segment (mono) or `max_num_channels` segments (one
                per channel); any other path is read with `transcripts.load`
            mode: DEFAULT_MODE makes every segment its own group; any other
                mode groups segments by `transcripts.group_key`, and
                BATCHED_CHANNELS_MODE additionally merges each group with
                `transcripts.join_transcript`
            speaker_names: passed to `transcripts.collect_speaker_names`
            max_audio_file_size, min_duration, max_duration: passed to
                `transcripts.prune`; the duration range defaults to 0 .. 24h
            max_num_channels: channels created per non-mono audio file; also
                passed as `num_speakers` to `collect_speaker_names`
            mono: if True, every segment's channel is forced to
                `transcripts.channel_missing`
            exclude: audio names (as returned by `transcripts.audio_name`)
                that are left out of `allowed_audio_names` given to `prune`
            bucket_fn: maps a pruned group to a bucket id, stored per group in
                `self.bucket` (torch.short)
            pop_meta: if True, `self.meta` is empty; otherwise it maps
                example_id -> segment dict
            string_array_encoding: encoding of the string columns
            _print: function used to report the timing of each phase
            duration_from_transcripts: forwarded to `transcripts.join_transcript`
        The remaining arguments are stored on `self` as given.
        '''
        self.debug_short_long_records_features_from_whole_normalized_signal = debug_short_long_records_features_from_whole_normalized_signal
        self.mode = mode
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.max_audio_file_size = max_audio_file_size
        self.text_pipelines = text_pipelines
        self.frontend = frontend
        self.sample_rate = sample_rate
        self.time_padding_multiple = time_padding_multiple
        self.mono = mono
        self.audio_backend = audio_backend
        self.audio_dtype = audio_dtype

        # a single path is accepted as well as a list (note: a tuple would be
        # wrapped as one element)
        data_paths = data_paths if isinstance(data_paths,
                                              list) else [data_paths]

        # expand directories into the audio files they contain
        data_paths_ = []
        for data_path in data_paths:
            if os.path.isdir(data_path):
                data_paths_.extend(
                    os.path.join(data_path, filename) for filename in filter(
                        audio.is_audio, os.listdir(data_path)))
            else:
                data_paths_.append(data_path)
        data_paths = data_paths_

        tic = time.time()

        # Phase 1: read raw segment dicts. Audio files get synthetic segments
        # (no ref / begin / end); other files are loaded as transcripts.
        segments = []
        for path in data_paths:
            if audio.is_audio(path):
                assert self.mono or self.mode != AudioTextDataset.DEFAULT_MODE, 'Only mono audio files allowed as dataset input in default mode'
                if self.mono:
                    transcript = [
                        dict(audio_path=path,
                             channel=transcripts.channel_missing)
                    ]
                else:
                    transcript = [
                        dict(audio_path=path, channel=c)
                        for c in range(max_num_channels)
                    ]
            else:
                transcript = transcripts.load(path)
            segments.extend(transcript)

        _print('Dataset reading time: ', time.time() - tic)
        tic = time.time()

        # get_or_else is needed because a dictionary may contain explicit None
        # values that must also be replaced; dict.get(key, default) only
        # covers missing keys.
        get_or_else = lambda dictionary, key, default: dictionary[
            key] if dictionary.get(key) is not None else default
        # Phase 2: normalize fields in place (segment dicts are mutated)
        for t in segments:
            t['ref'] = get_or_else(t, 'ref', transcripts.ref_missing)
            t['begin'] = get_or_else(t, 'begin', transcripts.time_missing)
            t['end'] = get_or_else(t, 'end', transcripts.time_missing)
            t['channel'] = get_or_else(
                t, 'channel', transcripts.channel_missing
            ) if not self.mono else transcripts.channel_missing

        transcripts.collect_speaker_names(segments,
                                          speaker_names=speaker_names or [],
                                          num_speakers=max_num_channels,
                                          set_speaker_data=True)

        # Phase 3: group, prune and bucket. The lists below are filled in
        # lockstep: one entry per kept group (buckets, transcripts_len) or per
        # kept segment (grouped_segments, speakers_len).
        buckets = []
        grouped_segments = []
        transcripts_len = []
        speakers_len = []
        if self.mode == AudioTextDataset.DEFAULT_MODE:
            groupped_transcripts = ((i, [t]) for i, t in enumerate(segments))
        else:
            # itertools.groupby only merges adjacent items, hence the sort by
            # the same key first
            groupped_transcripts = itertools.groupby(
                sorted(segments, key=transcripts.group_key),
                transcripts.group_key)

        for group_key, transcript in groupped_transcripts:
            transcript = sorted(transcript, key=transcripts.sort_key)
            if self.mode == AudioTextDataset.BATCHED_CHANNELS_MODE:
                transcript = transcripts.join_transcript(
                    transcript,
                    self.mono,
                    duration_from_transcripts=duration_from_transcripts)

            if exclude is not None:
                allowed_audio_names = set(
                    transcripts.audio_name(t) for t in transcript
                    if transcripts.audio_name(t) not in exclude)
            else:
                allowed_audio_names = None

            transcript = transcripts.prune(
                transcript,
                allowed_audio_names=allowed_audio_names,
                duration=(
                    min_duration if min_duration is not None else 0.0,
                    max_duration if max_duration is not None else 24.0 * 3600,
                ),  #24h
                max_audio_file_size=max_audio_file_size)
            transcript = list(transcript)
            for t in transcript:
                t['example_id'] = AudioTextDataset.get_example_id(t)

            # groups emptied by pruning contribute nothing
            if len(transcript) == 0:
                continue

            bucket = bucket_fn(transcript)
            for t in transcript:
                t['bucket'] = bucket
                # number of speakers of this segment: list length, or 1 for a
                # scalar speaker
                speakers_len.append(
                    len(t['speaker']) if (
                        isinstance(t['speaker'], list)) else 1)
            buckets.append(bucket)
            grouped_segments.extend(transcript)
            transcripts_len.append(len(transcript))

        _print('Dataset construction time: ', time.time() - tic)
        tic = time.time()

        # Phase 4: store columns as tensors / tensor-backed string arrays
        # (presumably to avoid per-object Python overhead, e.g. in DataLoader
        # workers — confirm)
        self.bucket = torch.tensor(buckets, dtype=torch.short)
        self.audio_path = utils.TensorBackedStringArray(
            [t['audio_path'] for t in grouped_segments],
            encoding=string_array_encoding)
        self.ref = utils.TensorBackedStringArray(
            [t['ref'] for t in grouped_segments],
            encoding=string_array_encoding)
        self.begin = torch.tensor([t['begin'] for t in grouped_segments],
                                  dtype=torch.float64)
        self.end = torch.tensor([t['end'] for t in grouped_segments],
                                dtype=torch.float64)
        self.channel = torch.tensor([t['channel'] for t in grouped_segments],
                                    dtype=torch.int8)
        self.example_id = utils.TensorBackedStringArray(
            [t['example_id'] for t in grouped_segments],
            encoding=string_array_encoding)
        if self.mode == AudioTextDataset.BATCHED_CHANNELS_MODE:
            # per-segment speaker lists are flattened; speaker_len records how
            # many entries belong to each segment
            self.speaker = torch.tensor([
                speaker for t in grouped_segments for speaker in t['speaker']
            ],
                                        dtype=torch.int64)
        else:
            self.speaker = torch.tensor(
                [t['speaker'] for t in grouped_segments], dtype=torch.int64)
        self.speaker_len = torch.tensor(speakers_len, dtype=torch.int16)
        # cumulative number of segments per group (end offsets into the
        # per-segment columns).
        # NOTE(review): per-group counts are first stored as int16, so a single
        # group with more than 32767 segments would overflow — confirm this
        # cannot happen.
        self.transcript_cumlen = torch.tensor(transcripts_len,
                                              dtype=torch.int16).cumsum(
                                                  dim=0, dtype=torch.int64)
        if pop_meta:
            self.meta = {}
        else:
            self.meta = {t['example_id']: t for t in grouped_segments}
        _print('Dataset tensors creation time: ', time.time() - tic)