Code example #1
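This example begins mid-call, so the tensors being compared are not defined in the snippet. A minimal, self-contained sketch of the setup it appears to assume follows; the shapes, seed, and atol value are illustrative, and the companion ctc module (which provides ctc.alignment and the custom loss values custom_ctc / custom_ctc_grad) is assumed to be importable from the surrounding project.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# illustrative dimensions: input time steps, batch, classes, target length
T, B, C, S = 50, 4, 10, 12
blank, atol = 0, 1e-4
torch.manual_seed(0)

logits = torch.randn(T, B, C, requires_grad=True)
log_probs = F.log_softmax(logits, dim=-1)
targets = torch.randint(1, C, (B, S), dtype=torch.long)
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), S, dtype=torch.long)

# reference loss and gradient from PyTorch's built-in CTC implementation
builtin_ctc = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                         blank=blank, reduction='sum')
builtin_ctc_grad, = torch.autograd.grad(builtin_ctc, logits, retain_graph=True)
# custom_ctc and custom_ctc_grad are computed analogously with the companion
# ctc module and are not reproduced in this sketch.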
# assumed completion of the truncated call: a helper in the companion ctc module
# that returns per-frame CTC alignment targets
ce_alignment_targets = ctc.ctc_alignment_targets(log_probs, targets, input_lengths,
                                                 target_lengths,
                                                 blank=0)
ce_ctc = -ce_alignment_targets * log_probs
ce_ctc_grad, = torch.autograd.grad(ce_ctc.sum(), logits, retain_graph=True)

print('Custom loss matches:', torch.allclose(builtin_ctc,
                                             custom_ctc,
                                             atol=atol))
print('Grad matches:',
      torch.allclose(builtin_ctc_grad, custom_ctc_grad, atol=atol))
print('CE grad matches:',
      torch.allclose(builtin_ctc_grad, ce_ctc_grad, atol=atol))

alignment = ctc.alignment(log_probs,
                          targets,
                          input_lengths,
                          target_lengths,
                          blank=0,
                          reduction='none')
a = alignment[:, 0, :target_lengths[0]]
plt.subplot(211)
plt.title('Input-Output Viterbi alignment')
plt.imshow(a.t().cpu(), origin='lower', aspect='auto')
plt.xlabel('Input steps')
plt.ylabel('Output steps')
plt.subplot(212)
plt.title('CTC alignment targets')
a = ce_alignment_targets[:, 0, :]
plt.imshow(a.t().cpu(), origin='lower', aspect='auto')
plt.xlabel('Input steps')
plt.ylabel(f'Output symbols, blank {blank}')
plt.subplots_adjust(hspace=0.5)
Code example #2
def main(args, ext_json=['.json', '.json.gz']):
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)

    audio_data_paths = set(
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext)))
    json_data_paths = set(
        p for p in args.input_path if any(map(p.endswith, ext_json))
        and not utils.strip_suffixes(p, ext_json) in audio_data_paths)

    data_paths = list(audio_data_paths | json_data_paths)

    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ]) if args.skip_processed else None

    data_paths = [
        path for path in data_paths
        if exclude is None or os.path.basename(path) not in exclude
    ]

    text_pipeline, frontend, model, generator = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [text_pipeline],
        args.sample_rate,
        frontend=frontend if not args.frontend_in_model else None,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        mode='batched_channels'
        if args.join_transcript else 'batched_transcript',
        string_array_encoding=args.dataset_string_array_encoding,
        debug_short_long_records_features_from_whole_normalized_signal=(
            args.debug_short_long_records_features_from_whole_normalized_signal))
    print('Examples count: ', len(val_dataset))
    val_meta = val_dataset.pop_meta()
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    csv_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{len(val_dataset)}')
        meta = [val_meta[t['example_id']] for t in meta]

        audio_path = meta[0]['audio_path']
        audio_name = transcripts.audio_name(audio_path)
        begin_end = [dict(begin=t['begin'], end=t['end']) for t in meta]
        begin = torch.tensor([t['begin'] for t in begin_end],
                             dtype=torch.float)
        end = torch.tensor([t['end'] for t in begin_end], dtype=torch.float)
        # TODO/WARNING: assumes the frontend is not applied in the dataset
        if not args.frontend_in_model:
            print('\n' * 10 + 'WARNING\n' * 5)
            print(
                'transcribe.py assumes the frontend is in the model; otherwise time alignment will be incorrect'
            )
            print('WARNING\n' * 5 + '\n')

        duration = x.shape[-1] / args.sample_rate
        channel = [t['channel'] for t in meta]
        speaker = [t['speaker'] for t in meta]
        speaker_name = [t['speaker_name'] for t in meta]

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, logits, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(map(transcripts.compute_duration, meta)),
                        processing=time.time() - tic))

            ts: shaping.Bt = duration * torch.linspace(
                0, 1, steps=log_probs.shape[-1],
                device=log_probs.device).unsqueeze(0).expand(x.shape[0], -1)

            ref_segments = [[
                dict(channel=channel[i],
                     begin=begin_end[i]['begin'],
                     end=begin_end[i]['end'],
                     ref=text_pipeline.postprocess(
                         text_pipeline.preprocess(meta[i]['ref'])))
            ] for i in range(len(meta))]
            hyp_segments = [
                alternatives[0] for alternatives in generator.generate(
                    tokenizer=text_pipeline.tokenizer,
                    log_probs=log_probs,
                    begin=begin,
                    end=end,
                    output_lengths=olen,
                    time_stamps=ts,
                    segment_text_key='hyp',
                    segment_extra_info=[
                        dict(speaker=s, speaker_name=sn, channel=c)
                        for s, sn, c in zip(speaker, speaker_name, channel)
                    ])
            ]
            hyp_segments = [
                transcripts.map_text(text_pipeline.postprocess, hyp=hyp)
                for hyp in hyp_segments
            ]
            hyp, ref = '\n'.join(
                transcripts.join(hyp=h)
                for h in hyp_segments).strip(), '\n'.join(
                    transcripts.join(ref=r) for r in ref_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                alignment: shaping.BY = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y[:, 0, :],  # assumes channel 0 holds the character labels
                    olen,
                    ylen[:, 0],
                    blank=text_pipeline.tokenizer.eps_id,
                    pack_backpointers=args.pack_backpointers)
                aligned_ts: shaping.Bt = ts.gather(1, alignment)

                ref_segments = [
                    alternatives[0] for alternatives in generator.generate(
                        tokenizer=text_pipeline.tokenizer,
                        log_probs=torch.nn.functional.one_hot(
                            y[:,
                              0, :], num_classes=log_probs.shape[1]).permute(
                                  0, 2, 1),
                        begin=begin,
                        end=end,
                        output_lengths=ylen,
                        time_stamps=aligned_ts,
                        segment_text_key='ref',
                        segment_extra_info=[
                            dict(speaker=s, speaker_name=sn, channel=c)
                            for s, sn, c in zip(speaker, speaker_name, channel)
                        ])
                ]
                ref_segments = [
                    transcripts.map_text(text_pipeline.postprocess, hyp=ref)
                    for ref in ref_segments
                ]
            oom_handler.reset()
        except:
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {len(val_dataset)}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        ref_transcript, hyp_transcript = [
            sorted(transcripts.flatten(segments), key=transcripts.sort_key)
            for segments in [ref_segments, hyp_segments]
        ]

        if args.max_segment_duration:
            if ref:
                ref_segments = list(
                    transcripts.segment_by_time(ref_transcript,
                                                args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment_by_ref(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment_by_time(hyp_transcript,
                                                args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        #### HACK for diarization
        elif args.ref_transcript_path and args.join_transcript:
            audio_name_hack = audio_name.split('.')[0]
            #TODO: normalize ref field
            ref_segments = [[t] for t in sorted(transcripts.load(
                os.path.join(args.ref_transcript_path, audio_name_hack +
                             '.json')),
                                                key=transcripts.sort_key)]
            hyp_segments = list(
                transcripts.segment_by_ref(hyp_transcript,
                                           ref_segments,
                                           set_speaker=True,
                                           soft=False))
        #### END OF HACK

        has_ref = bool(transcripts.join(ref=transcripts.flatten(ref_segments)))

        transcript = []
        for hyp_transcript, ref_transcript in zip(hyp_segments, ref_segments):
            hyp, ref = transcripts.join(hyp=hyp_transcript), transcripts.join(
                ref=ref_transcript)

            transcript.append(
                dict(audio_path=audio_path,
                     ref=ref,
                     hyp=hyp,
                     speaker_name=transcripts.speaker_name(ref=ref_transcript,
                                                           hyp=hyp_transcript),
                     words=metrics.align_words(
                         *metrics.align_strings(hyp=hyp, ref=ref))
                     if args.align_words else [],
                     words_ref=ref_transcript,
                     words_hyp=hyp_transcript,
                     **transcripts.summary(hyp_transcript),
                     **(dict(cer=metrics.cer(hyp=hyp, ref=ref))
                        if has_ref else {})))

        transcripts.collect_speaker_names(transcript,
                                          set_speaker_data=True,
                                          num_speakers=2)

        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.prune_cer,
                              duration=args.prune_duration,
                              gap=args.prune_gap,
                              allowed_unk_count=args.prune_unk,
                              num_speakers=args.prune_num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcripts.save(transcript_path, filtered_transcript))

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(
                vis.transcript(transcript_path, args.sample_rate, args.mono,
                               transcript, filtered_transcript))

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            with open(transcript_path, 'w') as f:
                f.write(' '.join(t['hyp'].strip()
                                 for t in filtered_transcript))
            print(transcript_path)

        if args.output_csv:
            assert len({t['audio_path'] for t in filtered_transcript}) == 1
            audio_path = filtered_transcript[0]['audio_path']
            hyp = ' '.join(t['hyp'].strip() for t in filtered_transcript)
            begin = min(t['begin'] for t in filtered_transcript)
            end = max(t['end'] for t in filtered_transcript)
            csv_lines.append(
                csv_sep.join([audio_path, hyp,
                              str(begin),
                              str(end)]))

        if args.logits:
            logits_file_path = os.path.join(args.output_path,
                                            audio_name + '.pt')
            if args.logits_crop:
                begin_end = [
                    dict(
                        zip(['begin', 'end'], [
                            t['begin'] + c / float(o) * (t['end'] - t['begin'])
                            for c in args.logits_crop
                        ])) for o, t in zip(olen, begin_end)
                ]
                logits_crop = [slice(*args.logits_crop) for o in olen]
            else:
                logits_crop = [slice(int(o)) for o in olen]

            # TODO: filter ref / hyp by channel?
            torch.save([
                dict(audio_path=audio_path,
                     logits=l[..., logits_crop[i]],
                     **begin_end[i],
                     ref=ref,
                     hyp=hyp) for i, l in enumerate(logits.cpu())
            ], logits_file_path)
            print(logits_file_path)

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        transcript_path = os.path.join(args.output_path, 'transcripts.csv')
        with open(transcript_path, 'w') as f:
            f.write('\n'.join(csv_lines))
        print(transcript_path)
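The loop above guards each batch with utils.OomHandler so that an out-of-memory failure skips the batch instead of killing the whole run. Below is a minimal sketch of what such a helper might look like; the real utils.OomHandler in convasr may differ, this only illustrates the retry pattern the loop relies on.

import torch

class OomHandlerSketch:
    """Hypothetical stand-in for utils.OomHandler: count failures, free what
    can be freed, and tell the caller whether to skip the batch and continue."""

    def __init__(self, max_retries=2):
        self.max_retries = max_retries
        self.failures = 0

    def reset(self):
        # called after a batch succeeds
        self.failures = 0

    def try_recover(self, parameters):
        self.failures += 1
        if self.failures > self.max_retries:
            return False  # give up: let the caller re-raise
        for p in parameters:
            p.grad = None  # drop stale gradients
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached CUDA blocks
        return True  # caller may skip this batch and continue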
Code example #3
def logits(lang, logits, audio_name = None, MAX_ENTROPY = 1.0):
	good_audio_name = set(map(str.strip, open(audio_name[0])) if os.path.exists(audio_name[0]) else audio_name) if audio_name is not None else []
	labels = datasets.Labels(datasets.Language(lang))
	decoder = decoders.GreedyDecoder()
	tick_params = lambda ax, labelsize = 2.5, length = 0, **kwargs: ax.tick_params(axis = 'both', which = 'both', labelsize = labelsize, length = length, **kwargs) or [ax.set_linewidth(0) for ax in ax.spines.values()]
	logits_path = logits + '.html'
	html = open(logits_path, 'w')
	html.write('<html><head>' + meta_charset + f'</head><body><script>{play_script}{onclick_img_script}</script>')

	for i, t in enumerate(torch.load(logits)):
		audio_path, logits = t['audio_path'], t['logits']
		words = t.get('words', [t])
		y = t.get('y', torch.zeros(1, 0, dtype = torch.long))
		begin = t.get('begin', '')
		end = t.get('end', '')
		audio_name = transcripts.audio_name(audio_path)
		extra_metrics = dict(cer = t['cer']) if 'cer' in t else {}

		if good_audio_name and audio_name not in good_audio_name:
			continue

		log_probs = F.log_softmax(logits, dim = 0)
		entropy = models.entropy(log_probs, dim = 0, sum = False)
		log_probs_ = F.log_softmax(logits[:-1], dim = 0)
		entropy_ = models.entropy(log_probs_, dim = 0, sum = False)
		margin = models.margin(log_probs, dim = 0)
		#energy = features.exp().sum(dim = 0)[::2]

		plt.figure(figsize = (6, 2))
		ax = plt.subplot(211)
		plt.imshow(logits, aspect = 'auto')
		plt.xlim(0, logits.shape[-1] - 1)
		#plt.yticks([])
		plt.axis('off')
		tick_params(plt.gca())
		#plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

		plt.subplot(212, sharex = ax)
		prob_top1, prob_top2 = log_probs.exp().topk(2, dim = 0).values
		plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth = 0.2)
		artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth = 0.3)
		artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth = 0.3)
		artist_entropy, = plt.plot(entropy, 'r', linewidth = 0.3)
		artist_entropy_, = plt.plot(entropy_, 'yellow', linewidth = 0.3)
		plt.legend([artist_entropy, artist_entropy_, artist_prob_top1, artist_prob_top2],
					['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
					loc = 1,
					fontsize = 'xx-small',
					frameon = False)
		
		for b, e, v in zip(*models.rle1d(entropy > MAX_ENTROPY)):
			if bool(v):
				plt.axvspan(int(b), int(e), color='red', alpha=0.2)

		plt.ylim(0, 3.0)
		plt.xlim(0, entropy.shape[-1] - 1)

		decoded = decoder.decode(log_probs.unsqueeze(0), K = 5)[0]
		xlabels = list(
			map(
				'\n'.join,
				zip(
					*[
						labels.decode(d, replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False)
						for d in decoded
					]
				)
			)
		)
		plt.xticks(torch.arange(entropy.shape[-1]), xlabels, fontfamily = 'monospace')
		tick_params(plt.gca())

		if y.numel() > 0:
			alignment = ctc.alignment(
				log_probs.unsqueeze(0).permute(2, 0, 1),
				y.unsqueeze(0).long(),
				torch.LongTensor([log_probs.shape[-1]]),
				torch.LongTensor([len(y)]),
				blank = len(log_probs) - 1
			).squeeze(0)
			
			ax = plt.gca().secondary_xaxis('top')
			ref, ref_ = labels.decode(y.tolist(), replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False), alignment
			ax.set_xticks(ref_)
			ax.set_xticklabels(ref)
			tick_params(ax, colors = 'red')

		#k = 0
		#for i, c in enumerate(ref + ' '):
		#	if c == ' ':
		#		plt.axvspan(ref_[k] - 1, ref_[i - 1] + 1, facecolor = 'gray', alpha = 0.2)
		#		k = i + 1

		plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

		buf = io.BytesIO()
		plt.savefig(buf, format = 'jpg', dpi = 600)
		plt.close()

		html.write(f'<h4>{audio_name}')
		html.write(' | '.join(f'{k}: {v:.02f}' for k, v in extra_metrics.items()))
		html.write('</h4>')
		html.write(fmt_alignment(words))
		html.write('<img data-begin="{begin}" data-end="{end}" data-channel="{channel}" onclick="onclick_img(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>\n'.format(channel = i, begin = begin, end = end, encoded = base64.b64encode(buf.getvalue()).decode()))
		html.write(fmt_audio(audio_path = audio_path, channel = i))
		html.write('<hr/>')
	html.write('</body></html>')
	return logits_path
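The high-entropy spans above are shaded via models.rle1d, a run-length encoding of a 1-D boolean tensor. A small sketch of such a helper is shown below; the actual models.rle1d may differ, and the name and exclusive-end convention here are assumptions.

import torch

def rle1d_sketch(x: torch.Tensor):
	"""Run-length encode a 1-D tensor: returns (begins, ends, values),
	where each run covers x[begin:end] and values holds the run's value."""
	assert x.dim() == 1 and len(x) > 0
	change = torch.ones(len(x), dtype=torch.bool)
	change[1:] = x[1:] != x[:-1]  # True where a new run starts
	begins = change.nonzero(as_tuple=False).squeeze(1)
	ends = torch.cat([begins[1:], torch.tensor([len(x)])])
	return begins, ends, x[begins]

# usage mirroring the loop above:
# for b, e, v in zip(*rle1d_sketch(entropy > MAX_ENTROPY)):
#     if bool(v):
#         plt.axvspan(int(b), int(e), color='red', alpha=0.2)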
Code example #4
File: transcribe.py  Project: MayMeta/convasr
def main(args):
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
     'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)
    data_paths = [
        p for f in args.input_path
        for p in ([os.path.join(f, g)
                   for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext))
    ] + [
        p
        for p in args.input_path if any(map(p.endswith, ['.json', '.json.gz']))
    ]
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ] if args.skip_processed else [])
    data_paths = [
        path for path in data_paths if os.path.basename(path) not in exclude
    ]

    labels, frontend, model, decoder = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [labels],
        args.sample_rate,
        frontend=None,
        segmented=True,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        join_transcript=args.join_transcript,
        string_array_encoding=args.dataset_string_array_encoding)
    num_examples = len(val_dataset)
    print('Examples count: ', num_examples)
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    output_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{num_examples}')

        meta = [val_dataset.meta.get(m['example_id']) for m in meta]
        audio_path = meta[0]['audio_path']

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        begin = meta[0]['begin']
        end = meta[0]['end']
        audio_name = transcripts.audio_name(audio_path)

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            #speech = vad.detect_speech(x.squeeze(1), args.sample_rate, args.window_size, aggressiveness = args.vad, window_size_dilate = args.window_size_dilate)
            #speech = vad.upsample(speech, log_probs)
            #log_probs.masked_fill_(models.silence_space_mask(log_probs, speech, space_idx = labels.space_idx, blank_idx = labels.blank_idx), float('-inf'))

            decoded = decoder.decode(log_probs, olen)

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(
                    transcripts.compute_duration(t) for t in meta),
                        processing=time.time() - tic))

            ts = (x.shape[-1] / args.sample_rate) * torch.linspace(
                0, 1,
                steps=log_probs.shape[-1]).unsqueeze(0) + torch.FloatTensor(
                    [t['begin'] for t in meta]).unsqueeze(1)
            channel = [t['channel'] for t in meta]
            speaker = [t['speaker'] for t in meta]
            ref_segments = [[
                dict(channel=channel[i],
                     begin=meta[i]['begin'],
                     end=meta[i]['end'],
                     ref=labels.decode(y[i, 0, :ylen[i]].tolist()))
            ] for i in range(len(decoded))]
            hyp_segments = [
                labels.decode(decoded[i],
                              ts[i],
                              channel=channel[i],
                              replace_blank=True,
                              replace_blank_series=args.replace_blank_series,
                              replace_repeat=True,
                              replace_space=False,
                              speaker=speaker[i] if isinstance(
                                  speaker[i], str) else None)
                for i in range(len(decoded))
            ]

            ref, hyp = '\n'.join(
                transcripts.join(ref=r)
                for r in ref_segments).strip(), '\n'.join(
                    transcripts.join(hyp=h) for h in hyp_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                #if ref_full:# and not ref:
                #	#assert len(set(t['channel'] for t in meta)) == 1 or all(t['type'] != 'channel' for t in meta)
                #	#TODO: add space at the end
                #	channel = torch.ByteTensor(channel).repeat_interleave(log_probs.shape[-1]).reshape(1, -1)
                #	ts = ts.reshape(1, -1)
                #	log_probs = log_probs.transpose(0, 1).unsqueeze(0).flatten(start_dim = -2)
                #	olen = torch.tensor([log_probs.shape[-1]], device = log_probs.device, dtype = torch.long)
                #	y = y_full[None, None, :].to(y.device)
                #	ylen = torch.tensor([[y.shape[-1]]], device = log_probs.device, dtype = torch.long)
                #	segments = [([], sum([h for r, h in segments], []))]

                alignment = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y.squeeze(1),
                    olen,
                    ylen.squeeze(1),
                    blank=labels.blank_idx,
                    pack_backpointers=args.pack_backpointers)
                ref_segments = [
                    labels.decode(y[i, 0, :ylen[i]].tolist(),
                                  ts[i],
                                  alignment[i],
                                  channel=channel[i],
                                  speaker=speaker[i],
                                  key='ref',
                                  speakers=val_dataset.speakers)
                    for i in range(len(decoded))
                ]
            oom_handler.reset()
        except:
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {num_examples}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() -
                                                   tic_alignment))

        if args.max_segment_duration:
            ref_transcript, hyp_transcript = [
                list(sorted(sum(segments, []), key=transcripts.sort_key))
                for segments in [ref_segments, hyp_segments]
            ]
            if ref:
                ref_segments = list(
                    transcripts.segment(ref_transcript,
                                        args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment(hyp_transcript,
                                        args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        transcript = [
            dict(audio_path=audio_path,
                 ref=ref,
                 hyp=hyp,
                 speaker=transcripts.speaker(ref=ref_transcript,
                                             hyp=hyp_transcript),
                 cer=metrics.cer(hyp=hyp, ref=ref),
                 words=metrics.align_words(hyp=hyp, ref=ref)[-1]
                 if args.align_words else [],
                 alignment=dict(ref=ref_transcript, hyp=hyp_transcript),
                 **transcripts.summary(hyp_transcript)) for ref_transcript,
            hyp_transcript in zip(ref_segments, hyp_segments)
            for ref, hyp in [(transcripts.join(ref=ref_transcript),
                              transcripts.join(hyp=hyp_transcript))]
        ]
        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.cer,
                              duration=args.duration,
                              gap=args.gap,
                              unk=args.unk,
                              num_speakers=args.num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                json.dump(filtered_transcript,
                          f,
                          ensure_ascii=False,
                          sort_keys=True,
                          indent=2)

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(transcript_path)
            vis.transcript(transcript_path, args.sample_rate, args.mono,
                           transcript, filtered_transcript)

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            print(transcript_path)
            with open(transcript_path, 'w') as f:
                f.write(hyp)

        if args.output_csv:
            output_lines.append(
                csv_sep.join((audio_path, hyp, str(begin), str(end))) + '\n')

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        with open(os.path.join(args.output_path, 'transcripts.csv'), 'w') as f:
            f.writelines(output_lines)
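The per-frame time stamps in this version are built by spreading the audio duration linearly over the model's output frames and shifting each row by that segment's begin offset. A small stand-alone illustration, with hypothetical sizes, is below.

import torch

sample_rate = 8000
num_samples = 4 * sample_rate        # 4 seconds of audio in the batch
num_output_frames = 100              # model output length for this batch
begin = torch.tensor([0.0, 2.0])     # per-example segment offsets in seconds

duration = num_samples / sample_rate
ts = duration * torch.linspace(0, 1, steps=num_output_frames).unsqueeze(0) \
     + begin.unsqueeze(1)            # shape: (batch, num_output_frames)

print(ts[:, 0])    # tensor([0., 2.])  -> each row starts at its begin offset
print(ts[:, -1])   # tensor([4., 6.])  -> and ends begin + duration seconds later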
Code example #5
File: vis.py  Project: MayMeta/convasr
def logits(logits, audio_name, MAX_ENTROPY=1.0):
    good_audio_name = set(
        map(str.strip, open(audio_name[0]))
        if os.path.exists(audio_name[0]) else audio_name)
    labels = datasets.Labels(ru)
    decoder = decoders.GreedyDecoder()
    tick_params = lambda ax, labelsize=2.5, length=0, **kwargs: ax.tick_params(
        axis='both',
        which='both',
        labelsize=labelsize,
        length=length,
        **kwargs) or [ax.set_linewidth(0) for ax in ax.spines.values()]
    logits_path = logits + '.html'
    html = open(logits_path, 'w')
    html.write('''<html><meta charset="utf-8"/><body><script>
		function onclick_(evt)
		{
			const img = evt.target;
			const dim = img.getBoundingClientRect();
			const t = (evt.clientX - dim.left) / dim.width;
			const audio = img.nextSibling;
			audio.currentTime = t * audio.duration;
			audio.play();
		}
	</script>''')
    for r in torch.load(logits):
        logits = r['logits']
        if good_audio_name and r['audio_name'] not in good_audio_name:
            continue

        ref_aligned, hyp_aligned = r['alignment']['ref'], r['alignment']['hyp']

        log_probs = F.log_softmax(logits, dim=0)
        entropy = models.entropy(log_probs, dim=0, sum=False)
        log_probs_ = F.log_softmax(logits[:-1], dim=0)
        entropy_ = models.entropy(log_probs_, dim=0, sum=False)
        margin = models.margin(log_probs, dim=0)
        #energy = features.exp().sum(dim = 0)[::2]

        alignment = ctc.alignment(log_probs.unsqueeze(0).permute(2, 0, 1),
                                  r['y'].unsqueeze(0).long(),
                                  torch.LongTensor([log_probs.shape[-1]]),
                                  torch.LongTensor([len(r['y'])]),
                                  blank=len(log_probs) - 1).squeeze(0)

        plt.figure(figsize=(6, 2))

        prob_top1, prob_top2 = log_probs.exp().topk(2, dim=0).values
        plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth=0.2)
        artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth=0.3)
        artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth=0.3)
        artist_entropy, = plt.plot(entropy, 'r', linewidth=0.3)
        artist_entropy_, = plt.plot(entropy_, 'yellow', linewidth=0.3)
        plt.legend([
            artist_entropy, artist_entropy_, artist_prob_top1, artist_prob_top2
        ], ['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
                   loc=1,
                   fontsize='xx-small',
                   frameon=False)
        bad = (entropy > MAX_ENTROPY).tolist()
        #runs = []
        #for i, b in enumerate(bad):
        #	if b:
        #		if not runs or not bad[i - 1]:
        #			runs.append([i, i])
        #		else:
        #			runs[-1][1] += 1
        #for begin, end in runs:
        #	plt.axvspan(begin, end, color='red', alpha=0.2)

        plt.ylim(0, 3.0)
        plt.xlim(0, entropy.shape[-1] - 1)

        decoded = decoder.decode(log_probs.unsqueeze(0), K=5)[0]
        xlabels = list(
            map(
                '\n'.join,
                zip(*[
                    labels.decode(d,
                                  replace_blank='.',
                                  replace_space='_',
                                  replace_repeat=False) for d in decoded
                ])))
        #xlabels_ = labels.decode(log_probs.argmax(dim = 0).tolist(), blank = '.', space = '_', replace2 = False)
        plt.xticks(torch.arange(entropy.shape[-1]),
                   xlabels,
                   fontfamily='monospace')
        tick_params(plt.gca())

        ax = plt.gca().secondary_xaxis('top')
        ref, ref_ = labels.decode(r['y'].tolist(),
                                  replace_blank='.',
                                  replace_space='_',
                                  replace_repeat=False), alignment
        ax.set_xticks(ref_)
        ax.set_xticklabels(ref)
        tick_params(ax, colors='red')

        #k = 0
        #for i, c in enumerate(ref + ' '):
        #	if c == ' ':
        #		plt.axvspan(ref_[k] - 1, ref_[i - 1] + 1, facecolor = 'gray', alpha = 0.2)
        #		k = i + 1

        plt.subplots_adjust(left=0, right=1, bottom=0.12, top=0.95)

        buf = io.BytesIO()
        plt.savefig(buf, format='jpg', dpi=600)
        plt.close()

        html.write('<h4>{audio_name} | cer: {cer:.02f}</h4>'.format(**r))
        html.write(word_alignment(r['words']))
        html.write(
            '<img onclick="onclick_(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>'
            .format(encoded=base64.b64encode(buf.getvalue()).decode()))
        html.write(
            f'<audio style="width:100%" controls src="{audio_data_uri(r["audio_path"])}"></audio><hr/>'
        )  # audio_data_uri is assumed to be defined elsewhere in vis.py
    html.write('</body></html>')
    print('\n', logits_path)
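The curves plotted above rely on models.entropy for a per-frame entropy over the class dimension. A minimal sketch of such a helper is given below; the real models.entropy may additionally normalize or support other reductions, so treat this as an assumption.

import torch

def entropy_sketch(log_probs: torch.Tensor, dim: int = 0, sum: bool = True) -> torch.Tensor:
    """Shannon entropy (in nats) of the distributions along dim.
    With sum=False, returns one entropy value per remaining position,
    e.g. per time frame when log_probs has shape (num_classes, time)."""
    h = -(log_probs.exp() * log_probs).sum(dim=dim)
    return h.sum() if sum else h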