Example #1
def fmt_svg_speaker_barcode(transcript,
                            begin,
                            end,
                            colors=speaker_colors,
                            max_segment_seconds=60,
                            onclick=None):
    if onclick is None:
        onclick = 'onclick_svg(event)'
    color = lambda s: colors[s] if s < len(colors) else transcripts.speaker_missing
    html = ''

    segments = transcripts.segment_by_time(
        transcript,
        max_segment_seconds=max_segment_seconds,
        break_on_speaker_change=False,
        break_on_channel_change=False)

    for segment in segments:
        summary = transcripts.summary(segment)
        duration = transcripts.compute_duration(summary)
        if duration <= max_segment_seconds:
            duration = max_segment_seconds
        header = '<div style="width: 100%; height: 15px; border: 1px black solid"><svg viewBox="0 0 1 1" style="width:100%; height:100%" preserveAspectRatio="none">'
        # one <rect> per utterance of the current segment, positioned relative to the segment start
        body = '\n'.join(
            '<rect data-begin="{begin}" data-end="{end}" x="{x}" width="{width}" height="1" style="fill:{color}" onclick="{onclick}"><title>speaker{speaker} | {begin:.2f} - {end:.2f} [{duration:.2f}]</title></rect>'
            .format(onclick=onclick,
                    x=(t['begin'] - summary['begin']) / duration,
                    width=(t['end'] - t['begin']) / duration,
                    color=color(t['speaker']),
                    duration=transcripts.compute_duration(t),
                    **t) for t in segment)
        footer = '</svg></div>'
        html += header + body + footer
    return html
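A minimal usage sketch for the helper above, assuming the local `transcripts` module provides `load` and `summary` as used elsewhere in this listing; the file paths are hypothetical.

import transcripts

# a transcript is a list of utterance dicts with 'begin' / 'end' / 'speaker' fields
transcript = transcripts.load('data/call_0.json')  # hypothetical path
summary = transcripts.summary(transcript)
barcode_html = fmt_svg_speaker_barcode(transcript,
                                       begin=summary['begin'],
                                       end=summary['end'])
with open('barcode.html', 'w') as f:  # hypothetical output page
    f.write('<html><body>' + barcode_html + '</body></html>')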
Example #2
def filter_dataset(input_path, output_path, duration_in_hours, cer, seed):
    dataset = transcripts.load(input_path)

    random.seed(seed)
    random.shuffle(dataset)

    print('initial set hours: ',
          sum(transcripts.compute_duration(t, hours=True) for t in dataset),
          'hours')
    if cer:
        dataset = [e for e in dataset if e['cer'] <= cer]
        print(
            'after cer filtering hours: ',
            sum(transcripts.compute_duration(t, hours=True) for t in dataset),
            'hours')

    if duration_in_hours is not None:
        s = []
        set_duration = 0
        while set_duration <= duration_in_hours and len(dataset) > 0:
            t = dataset.pop()
            set_duration += transcripts.compute_duration(t, hours=True)
            s.append(t)
        dataset = s

    print('after duration filtering hours: ',
          sum(transcripts.compute_duration(t, hours=True) for t in dataset),
          'hours')
    print(output_path)
    transcripts.save(output_path, dataset)
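A minimal invocation sketch for the filter above; the paths, CER threshold and target duration are hypothetical.

filter_dataset(input_path='data/dataset.json',       # hypothetical paths
               output_path='data/dataset_100h.json',
               duration_in_hours=100,
               cer=0.3,
               seed=1)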
Example #3
def du(input_path):
    transcript = json.load(open(input_path))
    print(
        input_path, int(os.path.getsize(input_path) // 1e6), 'MB', '|',
        len(transcript) // 1000, 'K utt |',
        int(
            sum(transcripts.compute_duration(t)
                for t in transcript) / (60 * 60)), 'hours')
Example #4
def split(input_path, output_path, test_duration_in_hours,
          val_duration_in_hours, microval_duration_in_hours, old_microval_path,
          seed):
    transcripts_train = json.load(open(input_path))

    random.seed(seed)
    random.shuffle(transcripts_train)

    for t in transcripts_train:
        t.pop('alignment')
        t.pop('words')
        t['meta'].pop('words_hyp')
        t['meta'].pop('words_ref')

    if old_microval_path:
        old_microval = json.load(
            open(os.path.join(output_path, old_microval_path)))
        old_microval_paths = set(e['audio_path'] for e in old_microval)
        transcripts_train = [
            e for e in transcripts_train
            if e['audio_path'] not in old_microval_paths
        ]

    for set_name, duration in [('test', test_duration_in_hours),
                               ('val', val_duration_in_hours),
                               ('microval', microval_duration_in_hours)]:
        if duration is not None:
            print(set_name)
            s = []
            set_duration = 0
            while set_duration <= duration and len(transcripts_train) > 0:
                t = transcripts_train.pop()
                set_duration += transcripts.compute_duration(t, hours=True)
                s.append(t)
            json.dump(s,
                      open(
                          os.path.join(
                              output_path,
                              os.path.basename(output_path) +
                              f'_{set_name}.json'), 'w'),
                      ensure_ascii=False,
                      sort_keys=True,
                      indent=2)

    json.dump(transcripts_train,
              open(
                  os.path.join(output_path,
                               os.path.basename(output_path) + '_train.json'),
                  'w'),
              ensure_ascii=False,
              sort_keys=True,
              indent=2)
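A hypothetical invocation of the split above. Note that `output_path` is treated as a directory whose basename also prefixes the generated files, e.g. `splits/calls/calls_test.json`; the directory must already exist.

split(input_path='data/calls.json',   # hypothetical paths
      output_path='splits/calls',
      test_duration_in_hours=10,
      val_duration_in_hours=5,
      microval_duration_in_hours=0.5,
      old_microval_path=None,
      seed=1)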
Example #5
def diarization(diarization_transcript, html_path, debug_audio):
    with open(html_path, 'w') as html:
        html.write(
            '<html><head>' + meta_charset +
            '<style>.nowrap{white-space:nowrap} table {border-collapse:collapse} .border-hyp {border-bottom: 2px black solid}</style></head><body>\n'
        )
        html.write(f'<script>{play_script}</script>\n')
        html.write(f'<script>{onclick_img_script}</script>')
        html.write('<table>\n')
        html.write(
            '<tr><th>audio_name</th><th>duration</th><th>refhyp</th><th>ser</th><th>der</th><th>der_</th><th>audio</th><th>barcode</th></tr>\n'
        )
        avg = lambda l: sum(l) / len(l)
        html.write(
            '<tr class="border-hyp"><td>{num_files}</td><td>{total_duration:.02f}</td><td>avg</td><td>{avg_ser:.02f}</td><td>{avg_der:.02f}</td><td>{avg_der_:.02f}</td><td></td><td></td></tr>\n'
            .format(num_files=len(diarization_transcript),
                    total_duration=sum(
                        map(transcripts.compute_duration,
                            diarization_transcript)),
                    avg_ser=avg([t['ser'] for t in diarization_transcript]),
                    avg_der=avg([t['der'] for t in diarization_transcript]),
                    avg_der_=avg([t['der_'] for t in diarization_transcript])))
        for i, dt in enumerate(diarization_transcript):
            # assumption: each entry carries its own audio path; show the first channel
            audio_html = fmt_audio(dt['audio_path'],
                                   channel=0) if debug_audio else ''
            begin, end = 0.0, transcripts.compute_duration(dt)
            for refhyp in ['ref', 'hyp']:
                html.write(
                    '<tr class="border-{refhyp}"><td class="nowrap">{audio_name}</td><td>{end:.02f}</td><td>{refhyp}</td><td>{ser:.02f}</td><td>{der:.02f}</td><td>{der_:.02f}</td><td rospan="{rowspan}">{audio_html}</td><td>{barcode}</td></tr>\n'
                    .format(audio_name=dt['audio_name'],
                            audio_html=audio_html if refhyp == 'ref' else '',
                            rowspan=2 if refhyp == 'ref' else 1,
                            refhyp=refhyp,
                            end=end,
                            ser=dt['ser'],
                            der=dt['der'],
                            der_=dt['der_'],
                            barcode=fmt_img_speaker_barcode(
                                dt[refhyp],
                                begin=begin,
                                end=end,
                                onclick=None if debug_audio else '',
                                dataset=dict(channel=i))))

        html.write('</table></body></html>')
    return html_path
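The report above relies on helpers defined elsewhere in the module (`meta_charset`, `play_script`, `onclick_img_script`, `fmt_audio`, `fmt_img_speaker_barcode`). Below is a rough sketch of the per-file input structure it reads, inferred from the keys accessed above; the values are illustrative only.

diarization_transcript = [
    dict(audio_name='call_0.wav',                     # hypothetical example
         audio_path='data/call_0.wav',
         ser=0.12, der=0.25, der_=0.21,               # precomputed error rates
         ref=[dict(begin=0.0, end=3.5, speaker=1)],   # reference speaker turns
         hyp=[dict(begin=0.1, end=3.4, speaker=1)])   # hypothesized speaker turns
]
diarization(diarization_transcript, 'diarization.html', debug_audio=False)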
Example #6
def dump(by_group, splits, subset_name, gz=True):
    for split_name, transcript in by_group.items():
        input_path = os.path.join(
            splits, f'{subset_name}_{split_name}.json') + ('.gz' if gz else '')
        with (gzip.open(input_path, 'wt')
              if gz else open(input_path, 'w')) as f:
            json.dump(transcript,
                      f,
                      indent=2,
                      sort_keys=True,
                      ensure_ascii=False)
        print(
            input_path, '|', int(os.path.getsize(input_path) // 1e6), 'MB',
            '|',
            len(transcript) // 1000, 'K utt |',
            int(
                sum(
                    transcripts.compute_duration(t, hours=True)
                    for t in transcript)), 'hours')
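A hypothetical call to the dump above: each group is written to `<splits>/<subset_name>_<group>.json[.gz]` and its file size and total duration are printed; `train_transcript` and `val_transcript` stand in for lists of utterance dicts.

by_group = dict(train=train_transcript, val=val_transcript)  # hypothetical utterance lists
dump(by_group, splits='data/splits', subset_name='calls', gz=True)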
Example #7
def speaker_error(ref,
                  hyp,
                  num_speakers,
                  sample_rate=8000,
                  hyp_speaker_mapping=None,
                  ignore_silence_and_overlapped_speech=True):
    assert num_speakers == 2
    duration = transcripts.compute_duration(dict(ref=ref, hyp=hyp))
    ref_mask = speaker_mask(ref, num_speakers, duration, sample_rate)
    hyp_mask_ = speaker_mask(hyp, num_speakers, duration, sample_rate)

    print('duration', duration)
    vals = []
    for hyp_perm in ([[0, 1, 2], [0, 2, 1]]
                     if hyp_speaker_mapping is None else hyp_speaker_mapping):
        hyp_mask = hyp_mask_[hyp_perm]
        speaker_mismatch = (ref_mask[1] != hyp_mask[1]) | (ref_mask[2] !=
                                                           hyp_mask[2])
        if ignore_silence_and_overlapped_speech:
            silence_or_overlap_mask = ref_mask[1] == ref_mask[2]
            speaker_mismatch = speaker_mismatch[~silence_or_overlap_mask]

        confusion = (hyp_mask[1] & ref_mask[2] &
                     (~ref_mask[1])) | (hyp_mask[2] & ref_mask[1] &
                                        (~ref_mask[2]))
        false_alarm = (hyp_mask[1]
                       | hyp_mask[2]) & (~ref_mask[1]) & (~ref_mask[2])
        miss = (~hyp_mask[1]) & (~hyp_mask[2]) & (ref_mask[1] | ref_mask[2])
        total = ref_mask[1] | ref_mask[2]

        confusion, false_alarm, miss, total = [
            float(x.float().mean()) * duration
            for x in [confusion, false_alarm, miss, total]
        ]

        print('my', 'confusion', confusion, 'false_alarm', false_alarm, 'miss',
              miss, 'total', total)
        err = float(speaker_mismatch.float().mean())
        vals.append((err, hyp_perm))

    return min(vals)
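`speaker_error` depends on a `speaker_mask` helper that is not shown in this listing. Below is a minimal sketch of one plausible interface, guessed from the indexing above: a boolean activity matrix with one row per speaker id (row 0 left for silence/none), sampled at `sample_rate`. This is an assumption, not the repository implementation.

import torch

def speaker_mask(transcript, num_speakers, duration, sample_rate):
    # rows: 0 = unused/none, 1..num_speakers = per-speaker voice activity per sample
    mask = torch.zeros(1 + num_speakers,
                       int(duration * sample_rate),
                       dtype=torch.bool)
    for t in transcript:
        begin = int(t['begin'] * sample_rate)
        end = int(t['end'] * sample_rate)
        mask[t['speaker'], begin:end] = True
    return mask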
Example #8
def main(args):
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)
    data_paths = [
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext))
    ] + [
        p
        for p in args.input_path if any(map(p.endswith, ['.json', '.json.gz']))
    ]
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ] if args.skip_processed else [])
    data_paths = [
        path for path in data_paths if os.path.basename(path) not in exclude
    ]

    labels, frontend, model, decoder = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [labels],
        args.sample_rate,
        frontend=None,
        segmented=True,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        join_transcript=args.join_transcript,
        string_array_encoding=args.dataset_string_array_encoding)
    num_examples = len(val_dataset)
    print('Examples count: ', num_examples)
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    output_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{num_examples}')

        meta = [val_dataset.meta.get(m['example_id']) for m in meta]
        audio_path = meta[0]['audio_path']

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        begin = meta[0]['begin']
        end = meta[0]['end']
        audio_name = transcripts.audio_name(audio_path)

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            #speech = vad.detect_speech(x.squeeze(1), args.sample_rate, args.window_size, aggressiveness = args.vad, window_size_dilate = args.window_size_dilate)
            #speech = vad.upsample(speech, log_probs)
            #log_probs.masked_fill_(models.silence_space_mask(log_probs, speech, space_idx = labels.space_idx, blank_idx = labels.blank_idx), float('-inf'))

            decoded = decoder.decode(log_probs, olen)

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(
                    transcripts.compute_duration(t) for t in meta),
                        processing=time.time() - tic))

            ts = (x.shape[-1] / args.sample_rate) * torch.linspace(
                0, 1,
                steps=log_probs.shape[-1]).unsqueeze(0) + torch.FloatTensor(
                    [t['begin'] for t in meta]).unsqueeze(1)
            channel = [t['channel'] for t in meta]
            speaker = [t['speaker'] for t in meta]
            ref_segments = [[
                dict(channel=channel[i],
                     begin=meta[i]['begin'],
                     end=meta[i]['end'],
                     ref=labels.decode(y[i, 0, :ylen[i]].tolist()))
            ] for i in range(len(decoded))]
            hyp_segments = [
                labels.decode(decoded[i],
                              ts[i],
                              channel=channel[i],
                              replace_blank=True,
                              replace_blank_series=args.replace_blank_series,
                              replace_repeat=True,
                              replace_space=False,
                              speaker=speaker[i] if isinstance(
                                  speaker[i], str) else None)
                for i in range(len(decoded))
            ]

            ref, hyp = '\n'.join(
                transcripts.join(ref=r)
                for r in ref_segments).strip(), '\n'.join(
                    transcripts.join(hyp=h) for h in hyp_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                #if ref_full:# and not ref:
                #	#assert len(set(t['channel'] for t in meta)) == 1 or all(t['type'] != 'channel' for t in meta)
                #	#TODO: add space at the end
                #	channel = torch.ByteTensor(channel).repeat_interleave(log_probs.shape[-1]).reshape(1, -1)
                #	ts = ts.reshape(1, -1)
                #	log_probs = log_probs.transpose(0, 1).unsqueeze(0).flatten(start_dim = -2)
                #	olen = torch.tensor([log_probs.shape[-1]], device = log_probs.device, dtype = torch.long)
                #	y = y_full[None, None, :].to(y.device)
                #	ylen = torch.tensor([[y.shape[-1]]], device = log_probs.device, dtype = torch.long)
                #	segments = [([], sum([h for r, h in segments], []))]

                alignment = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y.squeeze(1),
                    olen,
                    ylen.squeeze(1),
                    blank=labels.blank_idx,
                    pack_backpointers=args.pack_backpointers)
                ref_segments = [
                    labels.decode(y[i, 0, :ylen[i]].tolist(),
                                  ts[i],
                                  alignment[i],
                                  channel=channel[i],
                                  speaker=speaker[i],
                                  key='ref',
                                  speakers=val_dataset.speakers)
                    for i in range(len(decoded))
                ]
            oom_handler.reset()
        except:
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {num_examples}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        if args.max_segment_duration:
            ref_transcript, hyp_transcript = [
                list(sorted(sum(segments, []), key=transcripts.sort_key))
                for segments in [ref_segments, hyp_segments]
            ]
            if ref:
                ref_segments = list(
                    transcripts.segment(ref_transcript,
                                        args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment(hyp_transcript,
                                        args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        transcript = [
            dict(audio_path=audio_path,
                 ref=ref,
                 hyp=hyp,
                 speaker=transcripts.speaker(ref=ref_transcript,
                                             hyp=hyp_transcript),
                 cer=metrics.cer(hyp=hyp, ref=ref),
                 words=metrics.align_words(hyp=hyp, ref=ref)[-1]
                 if args.align_words else [],
                 alignment=dict(ref=ref_transcript, hyp=hyp_transcript),
                 **transcripts.summary(hyp_transcript)) for ref_transcript,
            hyp_transcript in zip(ref_segments, hyp_segments)
            for ref, hyp in [(transcripts.join(ref=ref_transcript),
                              transcripts.join(hyp=hyp_transcript))]
        ]
        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.cer,
                              duration=args.duration,
                              gap=args.gap,
                              unk=args.unk,
                              num_speakers=args.num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                json.dump(filtered_transcript,
                          f,
                          ensure_ascii=False,
                          sort_keys=True,
                          indent=2)

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(transcript_path)
            vis.transcript(transcript_path, args.sample_rate, args.mono,
                           transcript, filtered_transcript)

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                f.write(hyp)

        if args.output_csv:
            output_lines.append(
                csv_sep.join((audio_path, hyp, str(begin), str(end))) + '\n')

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        with open(os.path.join(args.output_path, 'transcripts.csv'), 'w') as f:
            f.writelines(output_lines)