Example #1
def eval(ref, hyp, html, debug_audio, sample_rate=100):
    if os.path.isfile(ref) and os.path.isfile(hyp):
        print(der(ref_rttm_path=ref, hyp_rttm_path=hyp))

    elif os.path.isdir(ref) and os.path.isdir(hyp):
        errs = []
        diarization_transcript = []
        for rttm in os.listdir(ref):
            if not rttm.endswith('.rttm'):
                continue

            print(rttm)
            audio_path = transcripts.load(
                os.path.join(hyp, rttm).replace('.rttm',
                                                '.json'))[0]['audio_path']

            ref_rttm_path, hyp_rttm_path = os.path.join(ref,
                                                        rttm), os.path.join(
                                                            hyp, rttm)
            ref_transcript, hyp_transcript = map(
                transcripts.load, [ref_rttm_path, hyp_rttm_path])
            ser_err, hyp_perm = speaker_error(
                ref=ref_transcript,
                hyp=hyp_transcript,
                num_speakers=2,
                sample_rate=sample_rate,
                ignore_silence_and_overlapped_speech=True)
            der_err, *_ = speaker_error(
                ref=ref_transcript,
                hyp=hyp_transcript,
                num_speakers=2,
                sample_rate=sample_rate,
                ignore_silence_and_overlapped_speech=False)
            der_err_ = der(ref_rttm_path=ref_rttm_path,
                           hyp_rttm_path=hyp_rttm_path)
            transcripts.remap_speaker(hyp_transcript, hyp_perm)

            err = dict(ser=ser_err, der=der_err, der_=der_err_)
            diarization_transcript.append(
                dict(audio_path=audio_path,
                     audio_name=transcripts.audio_name(audio_path),
                     ref=ref_transcript,
                     hyp=hyp_transcript,
                     **err))
            print(rttm, '{ser:.2f}, {der:.2f} | {der_:.2f}'.format(**err))
            print()
            errs.append(err)
        print('===')
        print({
            k: sum(err[k] for err in errs) / len(errs)
            for k in errs[0]
        })

        if html:
            print(
                vis.diarization(
                    sorted(diarization_transcript,
                           key=lambda t: t['ser'],
                           reverse=True), html, debug_audio))
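A minimal invocation sketch for the function above. The directory names are assumptions; the directory branch expects matching *.rttm files under both paths and the corresponding *.json transcripts next to the hypothesis RTTMs.

# Hypothetical paths; eval() is the function defined above.
eval(ref='data/ref_rttm', hyp='data/hyp_rttm', html='diarization.html', debug_audio=False)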
Example #2
    def __init__(
            self,
            data_paths,
            text_pipelines: typing.List[
                language_processing.ProcessingPipeline],
            sample_rate,
            frontend=None,
            speaker_names=None,
            waveform_transform_debug_dir=None,
            min_duration=None,
            max_duration=None,
            duration_filter=True,
            min_ref_len=None,
            max_ref_len=None,
            max_num_channels=2,
            ref_len_filter=True,
            mono=True,
            audio_dtype='float32',
            segmented=False,
            time_padding_multiple=1,
            audio_backend=None,
            exclude=set(),
            join_transcript=False,
            bucket=None,
            pop_meta=False,
            string_array_encoding='utf_16_le',
            _print=print,
            debug_short_long_records_features_from_whole_normalized_signal=False
    ):
        self.debug_short_long_records_features_from_whole_normalized_signal = debug_short_long_records_features_from_whole_normalized_signal
        self.join_transcript = join_transcript
        self.max_duration = max_duration
        self.text_pipelines = text_pipelines
        self.frontend = frontend
        self.sample_rate = sample_rate
        self.waveform_transform_debug_dir = waveform_transform_debug_dir
        self.segmented = segmented
        self.time_padding_multiple = time_padding_multiple
        self.mono = mono
        self.audio_backend = audio_backend
        self.audio_dtype = audio_dtype

        data_paths = data_paths if isinstance(data_paths,
                                              list) else [data_paths]
        exclude = set(exclude)

        tic = time.time()

        transcripts_read = list(map(transcripts.load, data_paths))
        _print('Dataset reading time: ',
               time.time() - tic)
        tic = time.time()

        #TODO group only segmented = True
        segments_by_audio_path = []
        for transcript in transcripts_read:
            transcript = sorted(transcript, key=transcripts.sort_key)
            transcript = itertools.groupby(transcript,
                                           key=transcripts.group_key)
            for _, example in transcript:
                segments_by_audio_path.append(list(example))

        speaker_names_filtered = set()
        examples_filtered = []
        examples_lens = []
        transcript = []

        duration = lambda example: sum(
            map(transcripts.compute_duration, example))
        segments_by_audio_path.sort(key=duration)

        # TODO: non-segmented mode may fail if several examples share the same audio_path
        for example in segments_by_audio_path:
            exclude_ok = ((not exclude) or
                          (transcripts.audio_name(example[0]) not in exclude))
            duration_ok = (
                (not duration_filter) or
                (min_duration is None or min_duration <= duration(example)) and
                (max_duration is None or duration(example) <= max_duration))

            if duration_ok and exclude_ok:
                b = bucket(example) if bucket is not None else 0
                for t in example:
                    t['bucket'] = b
                    t['ref'] = t.get('ref', transcripts.ref_missing)
                    t['begin'] = t.get('begin', transcripts.time_missing)
                    t['end'] = t.get('end', transcripts.time_missing)
                    t['channel'] = t.get('channel',
                                         transcripts.channel_missing)

                examples_filtered.append(example)
                transcript.extend(example)
                examples_lens.append(len(example))

        self.speaker_names = transcripts.collect_speaker_names(
            transcript,
            speaker_names=speaker_names or [],
            num_speakers=max_num_channels,
            set_speaker=True)

        _print('Dataset construction time: ',
               time.time() - tic)
        tic = time.time()

        self.bucket = torch.ShortTensor(
            [e[0]['bucket'] for e in examples_filtered])
        self.audio_path = utils.TensorBackedStringArray(
            [e[0]['audio_path'] for e in examples_filtered],
            encoding=string_array_encoding)
        self.ref = utils.TensorBackedStringArray(
            [t['ref'] for t in transcript], encoding=string_array_encoding)
        self.begin = torch.DoubleTensor([t['begin'] for t in transcript])
        self.end = torch.DoubleTensor([t['end'] for t in transcript])
        self.channel = torch.CharTensor([t['channel'] for t in transcript])
        self.speaker = torch.LongTensor([t['speaker'] for t in transcript])
        self.cumlen = torch.ShortTensor(examples_lens).cumsum(
            dim=0, dtype=torch.int64)
        if pop_meta:
            self.meta = {}
        else:
            self.meta = {self.example_id(t): t for t in transcript}
            if self.join_transcript:
                #TODO: harmonize dummy transcript of replace_transcript case (and fix channel)
                self.meta.update({
                    self.example_id(t_src): t_tgt
                    for e in examples_filtered for t_src, t_tgt in [(
                        dict(audio_path=e[0]['audio_path'],
                             begin=transcripts.time_missing,
                             end=transcripts.time_missing,
                             channel=transcripts.channel_missing,
                             speaker=transcripts.speaker_missing),
                        dict(audio_path=e[0]['audio_path'],
                             begin=0.0,
                             end=audio.compute_duration(e[0]['audio_path'],
                                                        backend=None),
                             channel=transcripts.channel_missing,
                             speaker=transcripts.speaker_missing,
                             ref=' '.join(
                                 filter(bool, [t.get('ref', '')
                                               for t in e]))))]
                })

        _print('Dataset tensors creation time: ', time.time() - tic)
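The constructor above keeps every string field in utils.TensorBackedStringArray instead of a Python list. That utility is project code, but a minimal self-contained sketch of the idea (one flat byte tensor plus cumulative offsets, decoded on access, which keeps the data cheap to share with DataLoader workers) might look like this:

import torch


class TensorBackedStringArraySketch:
    # Illustrative only: stores strings as a flat uint8 tensor plus cumulative end offsets.
    def __init__(self, strings, encoding='utf_16_le'):
        self.encoding = encoding
        encoded = [s.encode(encoding) for s in strings]
        self.data = torch.tensor(list(b''.join(encoded)), dtype=torch.uint8)
        self.cumlen = torch.tensor([len(b) for b in encoded],
                                   dtype=torch.int64).cumsum(dim=0)

    def __len__(self):
        return len(self.cumlen)

    def __getitem__(self, i):
        begin = int(self.cumlen[i - 1]) if i > 0 else 0
        end = int(self.cumlen[i])
        return bytes(self.data[begin:end].tolist()).decode(self.encoding)


strings = TensorBackedStringArraySketch(['ref one', 'ref two'])
assert strings[1] == 'ref two'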
Example #3
def main(args, ext_json=['.json', '.json.gz']):
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)

    audio_data_paths = set(
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext)))
    json_data_paths = set(
        p for p in args.input_path if any(map(p.endswith, ext_json))
        and not utils.strip_suffixes(p, ext_json) in audio_data_paths)

    data_paths = list(audio_data_paths | json_data_paths)

    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ]) if args.skip_processed else None

    data_paths = [
        path for path in data_paths
        if exclude is None or os.path.basename(path) not in exclude
    ]

    text_pipeline, frontend, model, generator = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [text_pipeline],
        args.sample_rate,
        frontend=frontend if not args.frontend_in_model else None,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        mode='batched_channels'
        if args.join_transcript else 'batched_transcript',
        string_array_encoding=args.dataset_string_array_encoding,
        debug_short_long_records_features_from_whole_normalized_signal=args.
        debug_short_long_records_features_from_whole_normalized_signal)
    print('Examples count: ', len(val_dataset))
    val_meta = val_dataset.pop_meta()
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    csv_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{len(val_dataset)}')
        meta = [val_meta[t['example_id']] for t in meta]

        audio_path = meta[0]['audio_path']
        audio_name = transcripts.audio_name(audio_path)
        begin_end = [dict(begin=t['begin'], end=t['end']) for t in meta]
        begin = torch.tensor([t['begin'] for t in begin_end],
                             dtype=torch.float)
        end = torch.tensor([t['end'] for t in begin_end], dtype=torch.float)
        #TODO WARNING assumes frontend not in dataset
        if not args.frontend_in_model:
            print('\n' * 10 + 'WARNING\n' * 5)
            print(
                'transcribe.py assumes the frontend is in the model; otherwise time alignment will be incorrect'
            )
            print('WARNING\n' * 5 + '\n')

        duration = x.shape[-1] / args.sample_rate
        channel = [t['channel'] for t in meta]
        speaker = [t['speaker'] for t in meta]
        speaker_name = [t['speaker_name'] for t in meta]

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, logits, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(map(transcripts.compute_duration, meta)),
                        processing=time.time() - tic))

            ts: shaping.Bt = duration * torch.linspace(
                0, 1, steps=log_probs.shape[-1],
                device=log_probs.device).unsqueeze(0).expand(x.shape[0], -1)

            ref_segments = [[
                dict(channel=channel[i],
                     begin=begin_end[i]['begin'],
                     end=begin_end[i]['end'],
                     ref=text_pipeline.postprocess(
                         text_pipeline.preprocess(meta[i]['ref'])))
            ] for i in range(len(meta))]
            hyp_segments = [
                alternatives[0] for alternatives in generator.generate(
                    tokenizer=text_pipeline.tokenizer,
                    log_probs=log_probs,
                    begin=begin,
                    end=end,
                    output_lengths=olen,
                    time_stamps=ts,
                    segment_text_key='hyp',
                    segment_extra_info=[
                        dict(speaker=s, speaker_name=sn, channel=c)
                        for s, sn, c in zip(speaker, speaker_name, channel)
                    ])
            ]
            hyp_segments = [
                transcripts.map_text(text_pipeline.postprocess, hyp=hyp)
                for hyp in hyp_segments
            ]
            hyp, ref = '\n'.join(
                transcripts.join(hyp=h)
                for h in hyp_segments).strip(), '\n'.join(
                    transcripts.join(ref=r) for r in ref_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                alignment: shaping.BY = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y[:, 0, :],  # assumed that 0 channel is char labels
                    olen,
                    ylen[:, 0],
                    blank=text_pipeline.tokenizer.eps_id,
                    pack_backpointers=args.pack_backpointers)
                aligned_ts: shaping.Bt = ts.gather(1, alignment)

                ref_segments = [
                    alternatives[0] for alternatives in generator.generate(
                        tokenizer=text_pipeline.tokenizer,
                        log_probs=torch.nn.functional.one_hot(
                            y[:,
                              0, :], num_classes=log_probs.shape[1]).permute(
                                  0, 2, 1),
                        begin=begin,
                        end=end,
                        output_lengths=ylen,
                        time_stamps=aligned_ts,
                        segment_text_key='ref',
                        segment_extra_info=[
                            dict(speaker=s, speaker_name=sn, channel=c)
                            for s, sn, c in zip(speaker, speaker_name, channel)
                        ])
                ]
                ref_segments = [
                    transcripts.map_text(text_pipeline.postprocess, hyp=ref)
                    for ref in ref_segments
                ]
            oom_handler.reset()
        except Exception:
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {len(val_dataset)}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        ref_transcript, hyp_transcript = [
            sorted(transcripts.flatten(segments), key=transcripts.sort_key)
            for segments in [ref_segments, hyp_segments]
        ]

        if args.max_segment_duration:
            if ref:
                ref_segments = list(
                    transcripts.segment_by_time(ref_transcript,
                                                args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment_by_ref(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment_by_time(hyp_transcript,
                                                args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        #### HACK for diarization
        elif args.ref_transcript_path and args.join_transcript:
            audio_name_hack = audio_name.split('.')[0]
            #TODO: normalize ref field
            ref_segments = [[t] for t in sorted(transcripts.load(
                os.path.join(args.ref_transcript_path, audio_name_hack +
                             '.json')),
                                                key=transcripts.sort_key)]
            hyp_segments = list(
                transcripts.segment_by_ref(hyp_transcript,
                                           ref_segments,
                                           set_speaker=True,
                                           soft=False))
        #### END OF HACK

        has_ref = bool(transcripts.join(ref=transcripts.flatten(ref_segments)))

        transcript = []
        for hyp_transcript, ref_transcript in zip(hyp_segments, ref_segments):
            hyp, ref = transcripts.join(hyp=hyp_transcript), transcripts.join(
                ref=ref_transcript)

            transcript.append(
                dict(audio_path=audio_path,
                     ref=ref,
                     hyp=hyp,
                     speaker_name=transcripts.speaker_name(ref=ref_transcript,
                                                           hyp=hyp_transcript),
                     words=metrics.align_words(
                         *metrics.align_strings(hyp=hyp, ref=ref))
                     if args.align_words else [],
                     words_ref=ref_transcript,
                     words_hyp=hyp_transcript,
                     **transcripts.summary(hyp_transcript),
                     **(dict(cer=metrics.cer(hyp=hyp, ref=ref))
                        if has_ref else {})))

        transcripts.collect_speaker_names(transcript,
                                          set_speaker_data=True,
                                          num_speakers=2)

        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.prune_cer,
                              duration=args.prune_duration,
                              gap=args.prune_gap,
                              allowed_unk_count=args.prune_unk,
                              num_speakers=args.prune_num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcripts.save(transcript_path, filtered_transcript))

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(
                vis.transcript(transcript_path, args.sample_rate, args.mono,
                               transcript, filtered_transcript))

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            with open(transcript_path, 'w') as f:
                f.write(' '.join(t['hyp'].strip()
                                 for t in filtered_transcript))
            print(transcript_path)

        if args.output_csv:
            assert len({t['audio_path'] for t in filtered_transcript}) == 1
            audio_path = filtered_transcript[0]['audio_path']
            hyp = ' '.join(t['hyp'].strip() for t in filtered_transcript)
            begin = min(t['begin'] for t in filtered_transcript)
            end = max(t['end'] for t in filtered_transcript)
            csv_lines.append(
                csv_sep.join([audio_path, hyp,
                              str(begin),
                              str(end)]))

        if args.logits:
            logits_file_path = os.path.join(args.output_path,
                                            audio_name + '.pt')
            if args.logits_crop:
                begin_end = [
                    dict(
                        zip(['begin', 'end'], [
                            t['begin'] + c / float(o) * (t['end'] - t['begin'])
                            for c in args.logits_crop
                        ])) for o, t in zip(olen, begin_end)
                ]
                logits_crop = [slice(*args.logits_crop) for o in olen]
            else:
                logits_crop = [slice(int(o)) for o in olen]

            # TODO: filter ref / hyp by channel?
            torch.save([
                dict(audio_path=audio_path,
                     logits=l[..., logits_crop[i]],
                     **begin_end[i],
                     ref=ref,
                     hyp=hyp) for i, l in enumerate(logits.cpu())
            ], logits_file_path)
            print(logits_file_path)

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        transcript_path = os.path.join(args.output_path, 'transcripts.csv')
        with open(transcript_path, 'w') as f:
            f.write('\n'.join(csv_lines))
        print(transcript_path)
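The ts tensor above spreads the batch duration linearly over the log_probs time axis, so every decoder frame gets a time stamp in seconds. A self-contained sketch with made-up shapes (the real duration is x.shape[-1] / args.sample_rate):

import torch

batch_size, num_frames = 2, 5
duration = 10.0  # seconds covered by the padded batch (illustrative value)

# One row of evenly spaced time stamps per example, matching the log_probs time axis.
ts = duration * torch.linspace(0, 1, steps=num_frames).unsqueeze(0).expand(batch_size, -1)
print(ts[0].tolist())  # [0.0, 2.5, 5.0, 7.5, 10.0]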
Example #4
def main(args):
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)
    data_paths = [
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext))
    ] + [
        p
        for p in args.input_path if any(map(p.endswith, ['.json', '.json.gz']))
    ]
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ] if args.skip_processed else [])
    data_paths = [
        path for path in data_paths if os.path.basename(path) not in exclude
    ]

    labels, frontend, model, decoder = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [labels],
        args.sample_rate,
        frontend=None,
        segmented=True,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        join_transcript=args.join_transcript,
        string_array_encoding=args.dataset_string_array_encoding)
    num_examples = len(val_dataset)
    print('Examples count: ', num_examples)
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    output_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{num_examples}')

        meta = [val_dataset.meta.get(m['example_id']) for m in meta]
        audio_path = meta[0]['audio_path']

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        begin = meta[0]['begin']
        end = meta[0]['end']
        audio_name = transcripts.audio_name(audio_path)

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            #speech = vad.detect_speech(x.squeeze(1), args.sample_rate, args.window_size, aggressiveness = args.vad, window_size_dilate = args.window_size_dilate)
            #speech = vad.upsample(speech, log_probs)
            #log_probs.masked_fill_(models.silence_space_mask(log_probs, speech, space_idx = labels.space_idx, blank_idx = labels.blank_idx), float('-inf'))

            decoded = decoder.decode(log_probs, olen)

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(
                    transcripts.compute_duration(t) for t in meta),
                        processing=time.time() - tic))

            ts = (x.shape[-1] / args.sample_rate) * torch.linspace(
                0, 1,
                steps=log_probs.shape[-1]).unsqueeze(0) + torch.FloatTensor(
                    [t['begin'] for t in meta]).unsqueeze(1)
            channel = [t['channel'] for t in meta]
            speaker = [t['speaker'] for t in meta]
            ref_segments = [[
                dict(channel=channel[i],
                     begin=meta[i]['begin'],
                     end=meta[i]['end'],
                     ref=labels.decode(y[i, 0, :ylen[i]].tolist()))
            ] for i in range(len(decoded))]
            hyp_segments = [
                labels.decode(decoded[i],
                              ts[i],
                              channel=channel[i],
                              replace_blank=True,
                              replace_blank_series=args.replace_blank_series,
                              replace_repeat=True,
                              replace_space=False,
                              speaker=speaker[i] if isinstance(
                                  speaker[i], str) else None)
                for i in range(len(decoded))
            ]

            ref, hyp = '\n'.join(
                transcripts.join(ref=r)
                for r in ref_segments).strip(), '\n'.join(
                    transcripts.join(hyp=h) for h in hyp_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                #if ref_full:# and not ref:
                #	#assert len(set(t['channel'] for t in meta)) == 1 or all(t['type'] != 'channel' for t in meta)
                #	#TODO: add space at the end
                #	channel = torch.ByteTensor(channel).repeat_interleave(log_probs.shape[-1]).reshape(1, -1)
                #	ts = ts.reshape(1, -1)
                #	log_probs = log_probs.transpose(0, 1).unsqueeze(0).flatten(start_dim = -2)
                #	olen = torch.tensor([log_probs.shape[-1]], device = log_probs.device, dtype = torch.long)
                #	y = y_full[None, None, :].to(y.device)
                #	ylen = torch.tensor([[y.shape[-1]]], device = log_probs.device, dtype = torch.long)
                #	segments = [([], sum([h for r, h in segments], []))]

                alignment = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y.squeeze(1),
                    olen,
                    ylen.squeeze(1),
                    blank=labels.blank_idx,
                    pack_backpointers=args.pack_backpointers)
                ref_segments = [
                    labels.decode(y[i, 0, :ylen[i]].tolist(),
                                  ts[i],
                                  alignment[i],
                                  channel=channel[i],
                                  speaker=speaker[i],
                                  key='ref',
                                  speakers=val_dataset.speakers)
                    for i in range(len(decoded))
                ]
            oom_handler.reset()
        except Exception:
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {num_examples}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        if args.max_segment_duration:
            ref_transcript, hyp_transcript = [
                list(sorted(sum(segments, []), key=transcripts.sort_key))
                for segments in [ref_segments, hyp_segments]
            ]
            if ref:
                ref_segments = list(
                    transcripts.segment(ref_transcript,
                                        args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment(hyp_transcript,
                                        args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        transcript = [
            dict(audio_path=audio_path,
                 ref=ref,
                 hyp=hyp,
                 speaker=transcripts.speaker(ref=ref_transcript,
                                             hyp=hyp_transcript),
                 cer=metrics.cer(hyp=hyp, ref=ref),
                 words=metrics.align_words(hyp=hyp, ref=ref)[-1]
                 if args.align_words else [],
                 alignment=dict(ref=ref_transcript, hyp=hyp_transcript),
                 **transcripts.summary(hyp_transcript)) for ref_transcript,
            hyp_transcript in zip(ref_segments, hyp_segments)
            for ref, hyp in [(transcripts.join(ref=ref_transcript),
                              transcripts.join(hyp=hyp_transcript))]
        ]
        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.cer,
                              duration=args.duration,
                              gap=args.gap,
                              unk=args.unk,
                              num_speakers=args.num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                json.dump(filtered_transcript,
                          f,
                          ensure_ascii=False,
                          sort_keys=True,
                          indent=2)

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(transcript_path)
            vis.transcript(transcript_path, args.sample_rate, args.mono,
                           transcript, filtered_transcript)

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                f.write(hyp)

        if args.output_csv:
            output_lines.append(
                csv_sep.join((audio_path, hyp, str(begin), str(end))) + '\n')

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        with open(os.path.join(args.output_path, 'transcripts.csv'), 'w') as f:
            f.writelines(output_lines)
Example #5
def logits(lang, logits, audio_name = None, MAX_ENTROPY = 1.0):
	good_audio_name = set(map(str.strip, open(audio_name[0])) if os.path.exists(audio_name[0]) else audio_name) if audio_name is not None else []
	labels = datasets.Labels(datasets.Language(lang))
	decoder = decoders.GreedyDecoder()
	tick_params = lambda ax, labelsize = 2.5, length = 0, **kwargs: ax.tick_params(axis = 'both', which = 'both', labelsize = labelsize, length = length, **kwargs) or [ax.set_linewidth(0) for ax in ax.spines.values()]
	logits_path = logits + '.html'
	html = open(logits_path, 'w')
	html.write('<html><head>' + meta_charset + f'</head><body><script>{play_script}{onclick_img_script}</script>')

	for i, t in enumerate(torch.load(logits)):
		audio_path, logits = t['audio_path'], t['logits']
		words = t.get('words', [t])
		y = t.get('y', torch.zeros(1, 0, dtype = torch.long))
		begin = t.get('begin', '')
		end = t.get('end', '')
		audio_name = transcripts.audio_name(audio_path)
		extra_metrics = dict(cer = t['cer']) if 'cer' in t else {}

		if good_audio_name and audio_name not in good_audio_name:
			continue

		log_probs = F.log_softmax(logits, dim = 0)
		entropy = models.entropy(log_probs, dim = 0, sum = False)
		log_probs_ = F.log_softmax(logits[:-1], dim = 0)
		entropy_ = models.entropy(log_probs_, dim = 0, sum = False)
		margin = models.margin(log_probs, dim = 0)
		#energy = features.exp().sum(dim = 0)[::2]

		plt.figure(figsize = (6, 2))
		ax = plt.subplot(211)
		plt.imshow(logits, aspect = 'auto')
		plt.xlim(0, logits.shape[-1] - 1)
		#plt.yticks([])
		plt.axis('off')
		tick_params(plt.gca())
		#plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

		plt.subplot(212, sharex = ax)
		prob_top1, prob_top2 = log_probs.exp().topk(2, dim = 0).values
		plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth = 0.2)
		artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth = 0.3)
		artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth = 0.3)
		artist_entropy, = plt.plot(entropy, 'r', linewidth = 0.3)
		artist_entropy_, = plt.plot(entropy_, 'yellow', linewidth = 0.3)
		plt.legend([artist_entropy, artist_entropy_, artist_prob_top1, artist_prob_top2],
					['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
					loc = 1,
					fontsize = 'xx-small',
					frameon = False)
		
		for b, e, v in zip(*models.rle1d(entropy > MAX_ENTROPY)):
			if bool(v):
				plt.axvspan(int(b), int(e), color='red', alpha=0.2)

		plt.ylim(0, 3.0)
		plt.xlim(0, entropy.shape[-1] - 1)

		decoded = decoder.decode(log_probs.unsqueeze(0), K = 5)[0]
		xlabels = list(
			map(
				'\n'.join,
				zip(
					*[
						labels.decode(d, replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False)
						for d in decoded
					]
				)
			)
		)
		plt.xticks(torch.arange(entropy.shape[-1]), xlabels, fontfamily = 'monospace')
		tick_params(plt.gca())

		if y.numel() > 0:
			alignment = ctc.alignment(
				log_probs.unsqueeze(0).permute(2, 0, 1),
				y.unsqueeze(0).long(),
				torch.LongTensor([log_probs.shape[-1]]),
				torch.LongTensor([len(y)]),
				blank = len(log_probs) - 1
			).squeeze(0)
			
			ax = plt.gca().secondary_xaxis('top')
			ref, ref_ = labels.decode(y.tolist(), replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False), alignment
			ax.set_xticklabels(ref)
			ax.set_xticks(ref_)
			tick_params(ax, colors = 'red')

		#k = 0
		#for i, c in enumerate(ref + ' '):
		#	if c == ' ':
		#		plt.axvspan(ref_[k] - 1, ref_[i - 1] + 1, facecolor = 'gray', alpha = 0.2)
		#		k = i + 1

		plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

		buf = io.BytesIO()
		plt.savefig(buf, format = 'jpg', dpi = 600)
		plt.close()

		html.write(f'<h4>{audio_name}')
		html.write(' | '.join('{k}: {v:.02f}'.format(k = k, v = v) for k, v in extra_metrics.items()))
		html.write('</h4>')
		html.write(fmt_alignment(words))
		html.write('<img data-begin="{begin}" data-end="{end}" data-channel="{channel}" onclick="onclick_img(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>\n'.format(channel = i, begin = begin, end = end, encoded = base64.b64encode(buf.getvalue()).decode()))
		html.write(fmt_audio(audio_path = audio_path, channel = i))
		html.write('<hr/>')
	html.write('</body></html>')
	return logits_path
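models.entropy and models.margin above are project utilities. A plain-torch sketch of frame-wise entropy over the class axis (dim 0 here, matching the [num_classes, time] logits layout; the project version may normalize differently):

import torch
import torch.nn.functional as F

logits = torch.randn(32, 100)  # [num_classes, time], illustrative shape

log_probs = F.log_softmax(logits, dim=0)
# Frame-wise entropy: -sum_c p_c * log p_c, one value per time step.
entropy = -(log_probs.exp() * log_probs).sum(dim=0)
assert entropy.shape == (100,)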
Example #6
def transcript(html_path, sample_rate, mono, transcript, filtered_transcript = [], duration = None, NA = 'N/A', default_channel = 0):
	if isinstance(transcript, str):
		transcript = json.load(open(transcript))

	audio_path = transcript[0]['audio_path']
	audio_name = transcripts.audio_name(audio_path)

	signal, sample_rate = audio.read_audio(audio_path, sample_rate = sample_rate, mono = mono, duration = duration)
	channel_or_default = lambda channel: default_channel if channel == transcripts.channel_missing else channel 

	def fmt_link(ref = '', hyp = '', channel = default_channel, begin = transcripts.time_missing, end = transcripts.time_missing, speaker = transcripts.speaker_missing, i = '', j = '', audio_path = '', special_begin = 0, special_end = 1, **kwargs):
		span = ref in [special_begin, special_end] or begin == transcripts.time_missing or end == transcripts.time_missing
		tag_open = '<span' if span else f'<a onclick="return play(event, {channel_or_default(channel)}, {begin}, {end})"'
		tag_attr = f' title="channel{channel}. speaker{speaker}: {begin:.04f} - {end:.04f} | {i} - {j}" href="#" target="_blank">'
		tag_contents = (ref + hyp) if isinstance(ref, str) else (f'{begin:.02f}' if begin != transcripts.time_missing else NA) if ref == special_begin else (f'{end:.02f}' if end != transcripts.time_missing else NA) if ref == special_end else (f'{end - begin:.02f}' if begin != transcripts.time_missing and end != transcripts.time_missing else NA)
		tag_close = '</span>' if span else '</a>'
		return tag_open + tag_attr + tag_contents + tag_close
	
	fmt_words = lambda rh: ' '.join(fmt_link(**w) for w in rh)
	fmt_begin_end = 'data-begin="{begin}" data-end="{end}"'.format

	html = open(html_path, 'w')
	style = ' '.join(f'.speaker{i} {{background-color : {c}; }}' for i, c in enumerate(speaker_colors)) + ' '.join(f'.channel{i} {{background-color : {c}; }}' for i, c in enumerate(channel_colors)) + ' a {text-decoration: none;} .reference{opacity:0.4} .channel{margin:0px} .ok{background-color:green} .m0{margin:0px} .top{vertical-align:top}' 
	
	html.write(f'<html><head>' + meta_charset + f'<style>{style}</style></head><body>')
	html.write(f'<script>{play_script}{onclick_svg_script}</script>')
	html.write(
		f'<div style="overflow:auto"><h4 style="float:left">{audio_name}</h4><h5 style="float:right">0.000000</h5></div>'
	)
	html_speaker_barcode = fmt_svg_speaker_barcode(transcript, begin = 0.0, end = signal.shape[-1] / sample_rate)

	html.writelines(
		f'<figure class="m0"><figcaption><a href="#" download="channel{c}.{audio_name}" onclick="return download_audio(event, {c})">channel #{c}:</a></figcaption><audio ontimeupdate="update_span(ontimeupdate_(event), event)" onpause="onpause_(event)" id="audio{c}" style="width:100%" controls src="{uri_audio}"></audio>{html_speaker_barcode}</figure><hr />'
		for c,
		uri_audio in enumerate(
			audio_data_uri(signal[channel], sample_rate) for channel in ([0, 1] if len(signal) == 2 else []) + [...]
		)
	)
	html.write('<pre class="channel"><h3 class="channel0 channel">hyp #0:<span class="subtitle"></span></h3></pre>')
	html.write('<pre class="channel"><h3 class="channel0 reference channel">ref #0:<span class="subtitle"></span></h3></pre>')
	html.write('<pre class="channel" style="margin-top: 10px"><h3 class="channel1 channel">hyp #1:<span class="subtitle"></span></h3></pre>')
	html.write('<pre class="channel"><h3 class="channel1 reference channel">ref #1:<span class="subtitle"></span></h3></pre>')

	def fmt_th():
		idx_th = '<th>#</th>'
		speaker_th = '<th>speaker</th>'
		begin_th = '<th>begin</th>'
		end_th = '<th>end</th>'
		duration_th = '<th>dur</th>'
		hyp_th = '<th style="width:50%">hyp</th>'
		ref_th = '<th style="width:50%">ref</th>' + begin_th + end_th + duration_th + '<th>cer</th>' 
		return '<tr>' + idx_th + speaker_th + begin_th + end_th + duration_th + hyp_th + ref_th
 
	def fmt_tr(i, ok, t, words, hyp, ref, channel, speaker, speaker_name, cer):
		idx_td = f'''<td class="top {ok and 'ok'}">#{i}</td>'''
		speaker_td = f'<td class="speaker{speaker}" title="speaker{speaker}">{speaker_name}</td>'
		left_td = f'<td class="top">{fmt_link(0, **transcripts.summary(hyp, ij = True))}</td><td class="top">{fmt_link(1, **transcripts.summary(hyp, ij = True))}</td><td class="top">{fmt_link(2, **transcripts.summary(hyp, ij = True))}</td>'
		hyp_td = f'<td class="top hyp" data-channel="{channel}" data-speaker="{speaker}" {fmt_begin_end(**transcripts.summary(hyp, ij = True))}>{fmt_words(hyp)}{fmt_alignment(words, hyp = True, prefix = "", tag = "<template>")}</td>'
		ref_td = f'<td class="top reference ref" data-channel="{channel}" data-speaker="{speaker}" {fmt_begin_end(**transcripts.summary(ref, ij = True))}>{fmt_words(ref)}{fmt_alignment(words, ref = True, prefix = "", tag = "<template>")}</td>'
		right_td = f'<td class="top">{fmt_link(0, **transcripts.summary(ref, ij = True))}</td><td class="top">{fmt_link(1, **transcripts.summary(ref, ij = True))}</td><td class="top">{fmt_link(2, **transcripts.summary(ref, ij = True))}</td>'
		cer_td = f'<td class="top">' + (f'{cer:.2%}' if cer != transcripts._er_missing else NA) + '</td>' 
		return f'<tr class="channel{channel} speaker{speaker}">' + idx_td + speaker_td + left_td + hyp_td + ref_td + right_td + cer_td + '</tr>\n'

	html.write('<hr/><table style="width:100%">')
	html.write(fmt_th())
	html.writelines(fmt_tr(i, t in filtered_transcript, t, t.get('words', [t]), t.get('words_hyp', [t]), t.get('words_ref', [t]), t.get('channel', transcripts.channel_missing), t.get('speaker', transcripts.speaker_missing), t.get('speaker_name', 'speaker{}'.format(t.get('speaker', transcripts.speaker_missing))), t.get('cer', transcripts._er_missing)) for i, t in enumerate(transcripts.sort(transcript)))
	html.write(f'</tbody></table><script>{subtitle_script}</script></body></html>')
	return html_path
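audio_data_uri above is project code that embeds each channel as a playable data URI, which makes the HTML self-contained. A minimal standard-library sketch of that idea, assuming a mono float waveform in [-1, 1]:

import base64
import io
import struct
import wave


def wav_data_uri_sketch(signal, sample_rate):
	# Illustrative only: pack a mono float waveform into 16-bit PCM WAV and base64-encode it.
	buf = io.BytesIO()
	with wave.open(buf, 'wb') as w:
		w.setnchannels(1)
		w.setsampwidth(2)
		w.setframerate(sample_rate)
		w.writeframes(b''.join(
			struct.pack('<h', int(max(-1.0, min(1.0, x)) * 32767)) for x in signal))
	return 'data:audio/wav;base64,' + base64.b64encode(buf.getvalue()).decode()


uri = wav_data_uri_sketch([0.0, 0.5, -0.5], sample_rate = 8000)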
Example #7
def label(output_path, transcript, info, page_size, prefix):
	if isinstance(transcript, str):
		transcript = json.load(open(transcript))
	if isinstance(info, str):
		info = json.load(open(info))
	transcript = {transcripts.audio_name(t): t for t in transcript}

	page_count = int(math.ceil(len(info) / page_size))
	for p in range(page_count):
		html_path = output_path + f'.page{p}.html'
		html = open(html_path, 'w')
		html.write(
			'<html><head>' + meta_charset + '<style>figure{margin:0} h6{margin:0}</style></head><body onkeydown="return onkeydown_(event)">'
		)
		html.write('''<script>
function export_user_input()
{
	const data_text_plain_base64_encode_utf8 = str => 'data:text/plain;base64,' + btoa(encodeURIComponent(str).replace(/%([0-9A-F]{2})/g, function(match, p1) {return String.fromCharCode(parseInt(p1, 16)) }));
	
	const after = Array.from(document.querySelectorAll('input.after'));
	const data = after.map(input => ({audio_name : input.name, before : input.dataset.before, after : input.value}));

	const href = data_text_plain_base64_encode_utf8(JSON.stringify(data, null, 2));
	const unixtime = Math.round((new Date()).getTime() / 1000);
	let a = document.querySelector('a');
	const {page, prefix} = a.dataset;
	a.download = `${prefix}_page${page}_time${unixtime}.json`;
	a.href = href;
}

function onkeydown_(evt)
{
		const tab = evt.keyCode == 9, shift = evt.shiftKey;
		const tabIndex = (document.activeElement || {tabIndex : -1}).tabIndex;
		if(tab)
		{
				const newTabIndex = shift ? Math.max(0, tabIndex - 1) : tabIndex + 1;
				const newElem = document.querySelector(`[tabindex="${newTabIndex}"]`);
				if(newElem)
						newElem.focus();
				return false;
		}
		return true;
}
		</script>'''
		)
		html.write(
			f'<a data-page="{p}" data-prefix="{prefix}" download="export.json" onclick="export_user_input(); return true" href="#">Export</a>\n'
		)

		k = p * page_size
		for j, i in enumerate(info[k:k + page_size]):
			i['after'] = i.get('after', '')
			t = transcript[i['audio_name']]
			audio_path = t['audio_path'][len('/data/'):]
			html.write('<hr/>\n')
			html.write(
				f'<figure><figcaption>page {p}/{page_count}:<strong>{k + j}</strong><pre>{transcripts.audio_name(t)}</pre></figcaption>{fmt_audio(audio_path)}<figcaption><pre>{t["ref"]}</pre></figcaption></figure>'
			)
			html.write('<h6>before</h6>')
			html.write('<pre name="{audio_name}" class="before">{before}</pre>'.format(**i))
			html.write('<h6>after</h6>')
			html.write(
				'<input tabindex="{tabindex}" name="{audio_name}" class="after" type="text" value="{after}" data-before="{before}">'
				.format(tabindex = j, **i)
			)
		html.write('</body></html>')
		print(html_path)
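A hypothetical call for the paginated labeling tool above; the file names, page size and prefix are assumptions, while the keyword arguments follow the signature shown:

# 'transcript.json' is a list of transcript dicts; 'info.json' is a list of dicts with at least
# 'audio_name' and 'before' fields, as used in the loop above.
label(output_path = 'labeling/batch1', transcript = 'transcript.json', info = 'info.json',
      page_size = 100, prefix = 'batch1')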
Example #8
    def __init__(
        self,
        data_paths: typing.List[str],
        text_pipelines: typing.List[text_processing.ProcessingPipeline],
        sample_rate: int,
        mode: str = DEFAULT_MODE,
        frontend: typing.Optional[torch.nn.Module] = None,
        speaker_names: typing.Optional[typing.List[str]] = None,
        max_audio_file_size: typing.Optional[float] = None,  #bytes
        min_duration: typing.Optional[float] = None,
        max_duration: typing.Optional[float] = None,
        max_num_channels: int = 2,
        mono: bool = True,
        audio_dtype: str = 'float32',
        time_padding_multiple: int = 1,
        audio_backend: typing.Optional[str] = None,
        exclude: typing.Optional[typing.Set] = None,
        bucket_fn: typing.Callable[[typing.List[typing.Dict]],
                                   int] = lambda transcript: 0,
        pop_meta: bool = False,
        string_array_encoding: str = 'utf_16_le',
        _print: typing.Callable = print,
        debug_short_long_records_features_from_whole_normalized_signal:
        bool = False,
        duration_from_transcripts: bool = False,
    ):
        self.debug_short_long_records_features_from_whole_normalized_signal = debug_short_long_records_features_from_whole_normalized_signal
        self.mode = mode
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.max_audio_file_size = max_audio_file_size
        self.text_pipelines = text_pipelines
        self.frontend = frontend
        self.sample_rate = sample_rate
        self.time_padding_multiple = time_padding_multiple
        self.mono = mono
        self.audio_backend = audio_backend
        self.audio_dtype = audio_dtype

        data_paths = data_paths if isinstance(data_paths,
                                              list) else [data_paths]

        data_paths_ = []
        for data_path in data_paths:
            if os.path.isdir(data_path):
                data_paths_.extend(
                    os.path.join(data_path, filename) for filename in filter(
                        audio.is_audio, os.listdir(data_path)))
            else:
                data_paths_.append(data_path)
        data_paths = data_paths_

        tic = time.time()

        segments = []
        for path in data_paths:
            if audio.is_audio(path):
                assert self.mono or self.mode != AudioTextDataset.DEFAULT_MODE, 'Only mono audio files allowed as dataset input in default mode'
                if self.mono:
                    transcript = [
                        dict(audio_path=path,
                             channel=transcripts.channel_missing)
                    ]
                else:
                    transcript = [
                        dict(audio_path=path, channel=c)
                        for c in range(max_num_channels)
                    ]
            else:
                transcript = transcripts.load(path)
            segments.extend(transcript)

        _print('Dataset reading time: ', time.time() - tic)
        tic = time.time()

        # get_or_else is required because the dictionary may contain explicit None values
        # that we want to replace; dict.get with a default does not cover that case
        get_or_else = lambda dictionary, key, default: dictionary[
            key] if dictionary.get(key) is not None else default
        for t in segments:
            t['ref'] = get_or_else(t, 'ref', transcripts.ref_missing)
            t['begin'] = get_or_else(t, 'begin', transcripts.time_missing)
            t['end'] = get_or_else(t, 'end', transcripts.time_missing)
            t['channel'] = get_or_else(
                t, 'channel', transcripts.channel_missing
            ) if not self.mono else transcripts.channel_missing

        transcripts.collect_speaker_names(segments,
                                          speaker_names=speaker_names or [],
                                          num_speakers=max_num_channels,
                                          set_speaker_data=True)

        buckets = []
        grouped_segments = []
        transcripts_len = []
        speakers_len = []
        if self.mode == AudioTextDataset.DEFAULT_MODE:
            grouped_transcripts = ((i, [t]) for i, t in enumerate(segments))
        else:
            grouped_transcripts = itertools.groupby(
                sorted(segments, key=transcripts.group_key),
                transcripts.group_key)

        for group_key, transcript in grouped_transcripts:
            transcript = sorted(transcript, key=transcripts.sort_key)
            if self.mode == AudioTextDataset.BATCHED_CHANNELS_MODE:
                transcript = transcripts.join_transcript(
                    transcript,
                    self.mono,
                    duration_from_transcripts=duration_from_transcripts)

            if exclude is not None:
                allowed_audio_names = set(
                    transcripts.audio_name(t) for t in transcript
                    if transcripts.audio_name(t) not in exclude)
            else:
                allowed_audio_names = None

            transcript = transcripts.prune(
                transcript,
                allowed_audio_names=allowed_audio_names,
                duration=(
                    min_duration if min_duration is not None else 0.0,
                    max_duration if max_duration is not None else 24.0 * 3600,
                ),  #24h
                max_audio_file_size=max_audio_file_size)
            transcript = list(transcript)
            for t in transcript:
                t['example_id'] = AudioTextDataset.get_example_id(t)

            if len(transcript) == 0:
                continue

            bucket = bucket_fn(transcript)
            for t in transcript:
                t['bucket'] = bucket
                speakers_len.append(
                    len(t['speaker']) if (
                        isinstance(t['speaker'], list)) else 1)
            buckets.append(bucket)
            grouped_segments.extend(transcript)
            transcripts_len.append(len(transcript))

        _print('Dataset construction time: ', time.time() - tic)
        tic = time.time()

        self.bucket = torch.tensor(buckets, dtype=torch.short)
        self.audio_path = utils.TensorBackedStringArray(
            [t['audio_path'] for t in grouped_segments],
            encoding=string_array_encoding)
        self.ref = utils.TensorBackedStringArray(
            [t['ref'] for t in grouped_segments],
            encoding=string_array_encoding)
        self.begin = torch.tensor([t['begin'] for t in grouped_segments],
                                  dtype=torch.float64)
        self.end = torch.tensor([t['end'] for t in grouped_segments],
                                dtype=torch.float64)
        self.channel = torch.tensor([t['channel'] for t in grouped_segments],
                                    dtype=torch.int8)
        self.example_id = utils.TensorBackedStringArray(
            [t['example_id'] for t in grouped_segments],
            encoding=string_array_encoding)
        if self.mode == AudioTextDataset.BATCHED_CHANNELS_MODE:
            self.speaker = torch.tensor([
                speaker for t in grouped_segments for speaker in t['speaker']
            ],
                                        dtype=torch.int64)
        else:
            self.speaker = torch.tensor(
                [t['speaker'] for t in grouped_segments], dtype=torch.int64)
        self.speaker_len = torch.tensor(speakers_len, dtype=torch.int16)
        self.transcript_cumlen = torch.tensor(transcripts_len,
                                              dtype=torch.int16).cumsum(
                                                  dim=0, dtype=torch.int64)
        if pop_meta:
            self.meta = {}
        else:
            self.meta = {t['example_id']: t for t in grouped_segments}
        _print('Dataset tensors creation time: ', time.time() - tic)
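A hypothetical construction of the dataset above; the data path, the pipeline object and the duration-based bucket_fn are illustrative assumptions, while the keyword names follow the signature shown:

# pipeline is assumed to be a text_processing.ProcessingPipeline instance.
dataset = AudioTextDataset(
    data_paths=['data/val.json'],
    text_pipelines=[pipeline],
    sample_rate=8000,
    mono=True,
    min_duration=0.5,
    max_duration=20.0,
    # Bucket by whole seconds of total segment duration (an illustrative policy, not from the source).
    bucket_fn=lambda transcript: int(sum(t['end'] - t['begin'] for t in transcript)),
)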
Example #9
def evaluate_model(args,
                   val_data_loaders,
                   model,
                   labels,
                   decoder,
                   error_analyzer,
                   optimizer=None,
                   sampler=None,
                   tensorboard_sink=None,
                   logfile_sink=None,
                   epoch=None,
                   iteration=None):
    _print = utils.get_root_logger_print()

    training = epoch is not None and iteration is not None
    columns = {}
    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for val_dataset_name, val_data_loader in val_data_loaders.items():
        _print(f'\n{val_dataset_name}@{iteration}')
        transcript, logits_, y_ = [], [], []
        analyze = args.analyze == [] or (args.analyze is not None
                                         and val_dataset_name in args.analyze)

        model.eval()
        if args.adapt_bn:
            models.reset_bn_running_stats_(model)
            for _ in apply_model(val_data_loader, model, labels, decoder,
                                 args.device, oom_handler):
                pass
        model.eval()
        cpu_list = lambda l: [[t.cpu() for t in t_] for t_ in l]

        tic = time.time()
        ref_, hyp_, audio_path_, loss_, entropy_ = [], [], [], [], []
        for batch_idx, (meta, loss, entropy, hyp, logits, y) in enumerate(
                apply_model(val_data_loader, model, labels, decoder,
                            args.device, oom_handler)):
            loss_.extend(loss.tolist())
            entropy_.extend(entropy.tolist())
            logits_.extend(
                zip(*cpu_list(logits)) if not training and args.logits else [])
            y_.extend(cpu_list(y))
            audio_path_.extend(m['audio_path'] for m in meta)
            ref_.extend(labels[0].normalize_text(m['ref']) for m in meta)
            hyp_.extend(zip(*hyp))
        toc_apply_model = time.time()
        time_sec_val_apply_model = toc_apply_model - tic
        perf.update(dict(time_sec_val_apply_model=time_sec_val_apply_model),
                    prefix=f'datasets_{val_dataset_name}')
        _print(f"Apply model {time_sec_val_apply_model:.1f} sec")

        analyze_args_gen = ((
            hyp,
            ref,
            analyze,
            dict(labels_name=label.name,
                 audio_path=audio_path,
                 audio_name=transcripts.audio_name(audio_path),
                 loss=loss,
                 entropy=entropy),
            label.postprocess_transcript,
        ) for ref, hyp_tuple, audio_path, loss, entropy in zip(
            ref_, hyp_, audio_path_, loss_, entropy_)
                            for label, hyp in zip(labels, hyp_tuple))

        if args.analyze_num_workers <= 0:
            transcript = [
                error_analyzer.analyze(*args) for args in analyze_args_gen
            ]
        else:
            with multiprocessing.pool.Pool(
                    processes=args.analyze_num_workers) as pool:
                transcript = pool.starmap(error_analyzer.analyze,
                                          analyze_args_gen)

        toc_analyze = time.time()
        time_sec_val_analyze = toc_analyze - toc_apply_model
        time_sec_val_total = toc_analyze - tic
        perf.update(dict(time_sec_val_analyze=time_sec_val_analyze,
                         time_sec_val_total=time_sec_val_total),
                    prefix=f'datasets_{val_dataset_name}')
        _print(
            f"Analyze {time_sec_val_analyze:.1f} sec, Total {time_sec_val_total:.1f} sec"
        )

        for i, t in enumerate(transcript if args.verbose else []):
            _print(
                f'{val_dataset_name}@{iteration}: {i // len(labels)} / {len(audio_path_)} | {args.experiment_id}'
            )
            # TODO: don't forget to fix aligned hyp & ref output!
            # hyp = new_transcript['alignment']['hyp'] if analyze else new_transcript['hyp']
            # ref = new_transcript['alignment']['ref'] if analyze else new_transcript['ref']
            _print('REF: {labels_name} "{ref}"'.format(**t))
            _print('HYP: {labels_name} "{hyp}"'.format(**t))
            _print('WER: {labels_name} {wer:.02%} | CER: {cer:.02%}\n'.format(
                **t))

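        # During training, transcripts are written under experiment_dir with the train
        # format; standalone validation uses the val format parameterized by decoder.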
        transcripts_path = os.path.join(
            args.experiment_dir,
            args.train_transcripts_format.format(
                val_dataset_name=val_dataset_name,
                epoch=epoch,
                iteration=iteration)
        ) if training else args.val_transcripts_format.format(
            val_dataset_name=val_dataset_name, decoder=args.decoder)
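        # transcript is interleaved by label (see the standalone sketch after this
        # function): transcript[i::len(labels)] selects the results for labels[i].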
        for i, label in enumerate(labels):
            transcript_by_label = transcript[i::len(labels)]
            aggregated = error_analyzer.aggregate(transcript_by_label)
            if analyze:
                with open(f'{transcripts_path}.errors.csv', 'w') as f:
                    f.writelines('{hyp},{ref},{error_tag}\n'.format(**w)
                                 for w in aggregated['errors']['words'])

            _print('errors', aggregated['errors']['distribution'])
            cer = torch.FloatTensor([r['cer'] for r in transcript_by_label])
            loss = torch.FloatTensor([r['loss'] for r in transcript_by_label])
            _print('cer', metrics.quantiles(cer))
            _print('loss', metrics.quantiles(loss))
            _print(
                f'{args.experiment_id} {val_dataset_name} {label.name}',
                f'| epoch {epoch} iter {iteration}' if training else '',
                f'| {transcripts_path} |',
                ('Entropy: {entropy:.02f} Loss: {loss:.02f} | WER:  {wer:.02%} CER: {cer:.02%} [{words_easy_errors_easy__cer_pseudo:.02%}],  MER: {mer_wordwise:.02%} DER: {hyp_der:.02%}/{ref_der:.02%}\n'
                 ).format(**aggregated))
            #columns[val_dataset_name + '_' + labels_name] = {'cer' : aggregated['cer_avg'], '.wer' : aggregated['wer_avg'], '.loss' : aggregated['loss_avg'], '.entropy' : aggregated['entropy_avg'], '.cer_easy' : aggregated['cer_easy_avg'], '.cer_hard':  aggregated['cer_hard_avg'], '.cer_missing' : aggregated['cer_missing_avg'], 'E' : dict(value = aggregated['errors_distribution']), 'L' : dict(value = vis.histc_vega(loss, min = 0, max = 3, bins = 20), type = 'vega'), 'C' : dict(value = vis.histc_vega(cer, min = 0, max = 1, bins = 20), type = 'vega'), 'T' : dict(value = [('audio_name', 'cer', 'mer', 'alignment')] + [(r['audio_name'], r['cer'], r['mer'], vis.word_alignment(r['words'])) for r in sorted(r_, key = lambda r: r['mer'], reverse = True)] if analyze else [], type = 'table')}

            if training:
                perf.update(
                    dict(wer=aggregated['wer'],
                         cer=aggregated['cer'],
                         loss=aggregated['loss']),
                    prefix=f'datasets_val_{val_dataset_name}_{label.name}')
                tensorboard_sink.val_stats(iteration, val_dataset_name,
                                           label.name, perf.default())

        with open(transcripts_path, 'w') as f:
            json.dump(transcript,
                      f,
                      ensure_ascii=False,
                      indent=2,
                      sort_keys=True)
        if analyze:
            vis.errors([transcripts_path], audio=args.vis_errors_audio)

        # TODO: transcript got flattened make this code work:
        # if args.logits:
        # 	logits_file_path = args.logits.format(val_dataset_name = val_dataset_name)
        # 	torch.save(
        # 		list(
        # 			sorted([
        # 				dict(
        # 					**r_,
        # 					logits = l_ if not args.logits_topk else models.sparse_topk(l_, args.logits_topk, dim = 0),
        # 					y = y_,
        # 					ydecoded = labels_.decode(y_.tolist())
        # 				) for r,
        # 				l,
        # 				y in zip(transcript, logits_, y_) for labels_,
        # 				r_,
        # 				l_,
        # 				y_ in zip(labels, r, l, y)
        # 			],
        # 					key = lambda r: r['cer'],
        # 					reverse = True)
        # 		),
        # 		logits_file_path
        # 	)
        # 	_print('Logits saved:', logits_file_path)

    checkpoint_path = os.path.join(
        args.experiment_dir,
        args.checkpoint_format.format(epoch=epoch, iteration=iteration)
    ) if training and not args.checkpoint_skip else None
    #if args.exphtml:
    #	columns['checkpoint_path'] = checkpoint_path
    #	exphtml.expjson(args.exphtml, args.experiment_id, epoch = epoch, iteration = iteration, meta = vars(args), columns = columns, tag = 'train' if training else 'test', git_http = args.githttp)
    #	exphtml.exphtml(args.exphtml)

    if training and not args.checkpoint_skip:
        torch.save(
            dict(model_state_dict=models.master_module(model).state_dict(),
                 optimizer_state_dict=optimizer.state_dict()
                 if optimizer is not None else None,
                 amp_state_dict=apex.amp.state_dict() if args.fp16 else None,
                 sampler_state_dict=sampler.state_dict()
                 if sampler is not None else None,
                 epoch=epoch,
                 iteration=iteration,
                 args=vars(args),
                 time=time.time(),
                 labels=[(l.name, str(l)) for l in labels]), checkpoint_path)

    model.train()
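
A standalone sketch of the interleaving used above, assuming two hypothetical label sets named 'char' and 'bpe' (names not from the source): per-utterance results are flattened label by label, so slicing with a stride of len(labels) recovers each label's transcripts.

labels = ['char', 'bpe']  # hypothetical label-set names
flat = [f'utt{u}_{name}' for u in range(3) for name in labels]
for i, name in enumerate(labels):
    print(name, flat[i::len(labels)])
# char ['utt0_char', 'utt1_char', 'utt2_char']
# bpe ['utt0_bpe', 'utt1_bpe', 'utt2_bpe']
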
Example No. 10
def errors(input_path,
           include=[],
           exclude=[],
           audio=False,
           output_path=None,
           sortdesc=None,
           topk=None,
           duration=None,
           cer=None,
           wer=None,
           mer=None,
           filter_transcripts=None,
           strip_audio_path_prefix=''):
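    # Build include/exclude sets of audio names, either from .json transcripts
    # (via transcripts.audio_name) or from plain text files with one name per line.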
    include, exclude = (set(
        sum([
            list(map(transcripts.audio_name, json.load(open(file_path))))
            if file_path.endswith('.json') else
            open(file_path).read().splitlines() for file_path in clude
        ], [])) for clude in [include, exclude])
    read_transcript = lambda path: list(
        filter(
            lambda r: (not include or r['audio_name'] in include) and
            (not exclude or r['audio_name'] not in exclude),
            json.load(open(path))
            if isinstance(path, str) else path)) if path is not None else []
    ours, theirs = list(
        transcripts.prune(read_transcript(input_path[0]),
                          duration=duration,
                          cer=cer,
                          wer=wer,
                          mer=mer)), [{
                              r['audio_name']: r
                              for r in read_transcript(transcript)
                          } for transcript in input_path[1:]]

    if filter_transcripts is None:
        if sortdesc is not None:
            filter_transcripts = lambda cat: list(
                sorted(cat, key=lambda utt: utt[0][sortdesc], reverse=True))
        else:
            filter_transcripts = lambda cat: cat

    cat = filter_transcripts([
        [a] + list(filter(None, [t.get(a['audio_name'], None)
                                 for t in theirs])) for a in ours
    ])[slice(topk)]
    cat_by_labels = collections.defaultdict(list)
    for c in cat:
        transcripts_by_labels = collections.defaultdict(list)
        for transcript in c:
            transcripts_by_labels[transcript['labels_name']].append(transcript)
        for labels_name, grouped_transcripts in transcripts_by_labels.items():
            cat_by_labels[labels_name].append(grouped_transcripts)

    # TODO: add sorting https://stackoverflow.com/questions/14267781/sorting-html-table-with-javascript
    html_path = output_path or (input_path[0] + '.html')

    f = open(html_path, 'w')
    f.write(
        '<html><meta charset="utf-8"><style> table{border-collapse:collapse; width: 100%;} audio {width:100%} .br{border-right:2px black solid} tr.first>td {border-top: 1px solid black} tr.any>td {border-top: 1px dashed black}  .nowrap{white-space:nowrap} th.col{width:80px}</style>'
    )
    f.write(
        '<body><table><tr><th></th><th class="col">cer_easy</th><th class="col">cer</th><th class="col">wer_easy</th><th class="col">wer</th><th class="col">mer</th><th></th></tr>'
    )
    f.write('<tr><td><strong>averages</strong></td></tr>')
    f.write('\n'.join(
        '<tr><td class="br">{input_name}</td><td>{cer_easy:.02%}</td><td>{cer:.02%}</td><td>{wer_easy:.02%}</td><td>{wer:.02%}</td><td>{mer:.02%}</td></tr>'
        .format(
            input_name=os.path.basename(input_path[i]),
            cer=metrics.nanmean(c, 'cer'),
            wer=metrics.nanmean(c, 'wer'),
            mer=metrics.nanmean(c, 'mer'),
            cer_easy=metrics.nanmean(c, 'words_easy_errors_easy.cer_pseudo'),
            wer_easy=metrics.nanmean(c, 'words_easy_errors_easy.wer_pseudo'),
        ) for i, c in enumerate(zip(*cat))))
    if len(cat_by_labels.keys()) > 1:
        for labels_name, labels_transcripts in cat_by_labels.items():
            f.write(
                f'<tr><td><strong>averages ({labels_name})</strong></td></tr>')
            f.write('\n'.join(
                '<tr><td class="br">{input_name}</td><td>{cer_easy:.02%}</td><td>{cer:.02%}</td><td>{wer_easy:.02%}</td><td>{wer:.02%}</td><td>{mer:.02%}</td></tr>'
                .format(
                    input_name=os.path.basename(input_path[i]),
                    cer=metrics.nanmean(c, 'cer'),
                    wer=metrics.nanmean(c, 'wer'),
                    mer=metrics.nanmean(c, 'mer_wordwise'),
                    cer_easy=metrics.nanmean(
                        c, 'words_easy_errors_easy.cer_pseudo'),
                    wer_easy=metrics.nanmean(
                        c, 'words_easy_errors_easy.wer_pseudo'),
                ) for i, c in enumerate(zip(*labels_transcripts))))
    f.write('<tr><td>&nbsp;</td></tr>')
    f.write('\n'.join(
        f'''<tr class="first"><td colspan="6">''' +
        (f'<audio controls src="{audio_data_uri(utt[0]["audio_path"][len(strip_audio_path_prefix):])}"></audio>'
         if audio else '') +
        f'<div class="nowrap">{utt[0]["audio_name"]}</div></td><td>{word_alignment(utt[0], ref = True, flat = True)}</td><td>{word_alignment(utt[0], ref = True, flat = True)}</td></tr>'
        + '\n'.join(
            '<tr class="any"><td class="br">{audio_name}</td><td>{cer_easy:.02%}</td><td>{cer:.02%}</td><td>{wer_easy:.02%}</td><td>{wer:.02%}</td><td class="br">{mer:.02%}</td><td>{word_alignment}</td><td>{word_alignment_flat}</td></tr>'
            .format(audio_name=transcripts.audio_name(input_path[i]),
                    cer_easy=a.get("words_easy_errors_easy", {}).get(
                        "cer_pseudo", -1),
                    cer=a.get("cer", 1),
                    wer_easy=a.get("words_easy_errors_easy", {}).get(
                        "wer_pseudo", -1),
                    wer=a.get("wer", 1),
                    mer=a.get("mer_wordwise", 1),
                    word_alignment=word_alignment(a.get('words', [])),
                    word_alignment_flat=word_alignment(a, hyp=True, flat=True))
            for i, a in enumerate(utt)) for utt in cat))
    f.write('</table></body></html>')
    print(html_path)
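
A hypothetical invocation sketch (the file names are assumptions, not from the source): compare our transcripts against a baseline, keep the 200 worst utterances by CER, and write the HTML report to an explicit path.

errors(['exp/ours.json', 'exp/baseline.json'],
       sortdesc='cer',                        # worst CER first
       topk=200,                              # keep only the 200 worst utterances
       audio=False,                           # do not embed audio as data URIs
       output_path='exp/ours_vs_baseline.html')
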
Example No. 11
    def __init__(self,
                 data_paths,
                 labels,
                 sample_rate,
                 frontend=None,
                 speaker_names=None,
                 waveform_transform_debug_dir=None,
                 min_duration=None,
                 max_duration=None,
                 duration_filter=True,
                 min_ref_len=None,
                 max_ref_len=None,
                 max_num_channels=2,
                 ref_len_filter=True,
                 mono=True,
                 segmented=False,
                 time_padding_multiple=1,
                 audio_backend=None,
                 exclude=set(),
                 join_transcript=False,
                 bucket=None,
                 pop_meta=False,
                 string_array_encoding='utf_16_le',
                 _print=print):
        self.join_transcript = join_transcript
        self.max_duration = max_duration
        self.labels = labels
        self.frontend = frontend
        self.sample_rate = sample_rate
        self.waveform_transform_debug_dir = waveform_transform_debug_dir
        self.segmented = segmented
        self.time_padding_multiple = time_padding_multiple
        self.mono = mono
        self.audio_backend = audio_backend

        data_paths = data_paths if isinstance(data_paths,
                                              list) else [data_paths]
        exclude = set(exclude)

        def read_transcript(data_path):
            assert os.path.exists(data_path)
            if data_path.endswith('.json') or data_path.endswith('.json.gz'):
                return json.load(utils.open_maybe_gz(data_path))
            if os.path.exists(data_path + '.json'):
                return json.load(open(data_path + '.json'))
            return [dict(audio_path=data_path)]

        tic = time.time()

        transcripts_read = list(map(read_transcript, data_paths))
        _print('Dataset reading time: ',
               time.time() - tic)
        tic = time.time()

        segments_by_audio_path = [
            list(g) for transcript in transcripts_read
            for k, g in itertools.groupby(sorted(transcript,
                                                 key=transcripts.sort_key),
                                          key=transcripts.group_key)
        ]
        speaker_names_filtered = set()
        examples_filtered = []
        examples_lens = []
        transcript = []

        duration = lambda example: sum(
            map(transcripts.compute_duration, example))
        segments_by_audio_path.sort(key=duration)
        # TODO: not segmented mode may fail if several examples have same audio_path
        for example in segments_by_audio_path:
            exclude_ok = ((not exclude) or (
                transcripts.audio_name(example[0]) not in exclude))
            duration_ok = (
                (not duration_filter) or
                (min_duration is None or min_duration <= duration(example)) and
                (max_duration is None or duration(example) <= max_duration))

            if duration_ok and exclude_ok:
                b = bucket(example) if bucket is not None else 0
                for t in example:
                    #t['meta'] = t.copy()
                    t['bucket'] = b
                    t['ref'] = t.get('ref', self.ref_missing)
                    t['begin'] = t.get('begin', self.time_missing)
                    t['end'] = t.get('end', self.time_missing)
                    t['channel'] = t.get('channel', self.channel_missing)

                examples_filtered.append(example)
                transcript.extend(example)
                speaker_names_filtered.update(
                    str(t['speaker']) for t in example if t.get('speaker'))
                examples_lens.append(len(example))

        if speaker_names:
            self.speaker_names = speaker_names
        else:
            speaker_names = list(sorted(speaker_names_filtered)) or [
                f'channel{1 + c}' for c in range(max_num_channels)
            ]
            self.speaker_names = [self.speaker_name_missing] + speaker_names
        self.speaker_names_index = {
            speaker_name: i
            for i, speaker_name in enumerate(self.speaker_names)
        }
        assert self.speaker_names_index.get(
            self.speaker_name_missing) == self.speaker_missing

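        # Resolve each segment's speaker to an integer id: keep ints as-is, map names
        # via speaker_names_index, fall back to 1 + channel, otherwise mark missing.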
        for t in transcript:
            t['speaker'] = t['speaker'] if isinstance(
                t.get('speaker'), int
            ) else self.speaker_names_index.get(
                t['speaker'], self.speaker_missing
            ) if isinstance(
                t.get('speaker'), str
            ) else 1 + t['channel'] if 'channel' in t else self.speaker_missing
            t['speaker_name'] = self.speaker_names[t['speaker']]

        _print('Dataset construction time: ',
               time.time() - tic)
        tic = time.time()

        self.bucket = torch.ShortTensor(
            [e[0]['bucket'] for e in examples_filtered])
        self.audio_path = TensorBackedStringArray(
            [e[0]['audio_path'] for e in examples_filtered],
            encoding=string_array_encoding)
        self.ref = TensorBackedStringArray([t['ref'] for t in transcript],
                                           encoding=string_array_encoding)
        self.begin = torch.FloatTensor([t['begin'] for t in transcript])
        self.end = torch.FloatTensor([t['end'] for t in transcript])
        self.channel = torch.CharTensor([t['channel'] for t in transcript])
        self.speaker = torch.LongTensor([t['speaker'] for t in transcript])
        self.cumlen = torch.ShortTensor(examples_lens).cumsum(
            dim=0, dtype=torch.int64)
        self.meta = {self.example_id(t): t
                     for t in transcript} if not pop_meta else {}
        _print('Dataset tensors creation time: ', time.time() - tic)
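
A minimal standalone sketch (not from the source) of how a cumulative-length tensor like self.cumlen can map a flat example index back to its segment slice; presumably this is how the dataset's __getitem__ consumes it.

import torch

examples_lens = [2, 1, 3]  # hypothetical segments-per-example counts
cumlen = torch.ShortTensor(examples_lens).cumsum(dim=0, dtype=torch.int64)

def segment_slice(index):
    # segments of example `index` occupy [cumlen[index - 1], cumlen[index])
    begin = int(cumlen[index - 1]) if index > 0 else 0
    return slice(begin, int(cumlen[index]))

print(segment_slice(0), segment_slice(2))  # slice(0, 2, None) slice(3, 6, None)
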