target_lengths, blank=0) ce_ctc = -ce_alignment_targets * log_probs ce_ctc_grad, = torch.autograd.grad(ce_ctc.sum(), logits, retain_graph=True) print('Custom loss matches:', torch.allclose(builtin_ctc, custom_ctc, atol=atol)) print('Grad matches:', torch.allclose(builtin_ctc_grad, custom_ctc_grad, atol=atol)) print('CE grad matches:', torch.allclose(builtin_ctc_grad, ce_ctc_grad, atol=atol)) alignment = ctc.alignment(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='none') a = alignment[:, 0, :target_lengths[0]] plt.subplot(211) plt.title('Input-Output Viterbi alignment') plt.imshow(a.t().cpu(), origin='lower', aspect='auto') plt.xlabel('Input steps') plt.ylabel('Output steps') plt.subplot(212) plt.title('CTC alignment targets') a = ce_alignment_targets[:, 0, :] plt.imshow(a.t().cpu(), origin='lower', aspect='auto') plt.xlabel('Input steps') plt.ylabel(f'Output symbols, blank {blank}') plt.subplots_adjust(hspace=0.5)
def main(args, ext_json=['.json', '.json.gz']):
    """Transcribe the inputs listed in ``args.input_path`` and write the results.

    For every item yielded by the data loader the function runs the model,
    generates hypothesis segments, optionally CTC-aligns the reference text,
    re-segments and prunes the transcript, and writes the requested
    ``.json`` / ``.html`` / ``.txt`` / ``.pt`` files (named after the audio)
    into ``args.output_path``. When ``args.output_csv`` is set a single
    ``transcripts.csv`` is written after the loop.

    Args:
        args: options namespace. The attributes read here are exactly the
            ``args.*`` names used in the body (input/output paths, output
            format flags, dataset, alignment, pruning and logits options).
        ext_json: suffixes that mark an input path as a transcript manifest
            rather than an audio file. The default list is only read, never
            mutated.

    Raises:
        AssertionError: if none of the output formats is requested, or if the
            pruned transcript of one item refers to more than one audio path
            while CSV output is enabled.
        Exception: anything raised during inference that the OOM handler
            declines to recover from is re-raised unchanged.
    """
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
        'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)

    # Audio inputs: explicit files plus the direct children of every input
    # directory, kept only when the name ends with one of args.ext.
    audio_data_paths = set(
        p for f in args.input_path
        for p in ([os.path.join(f, g) for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext)))
    # Manifests passed explicitly are used unless the same path with the
    # manifest suffix stripped is already among the audio inputs.
    json_data_paths = set(
        p for p in args.input_path
        if any(map(p.endswith, ext_json))
        and utils.strip_suffixes(p, ext_json) not in audio_data_paths)

    data_paths = list(audio_data_paths | json_data_paths)

    # With --skip_processed, names of already written .json outputs (extension
    # removed) are excluded; None means "exclude nothing".
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ]) if args.skip_processed else None

    data_paths = [
        path for path in data_paths
        if exclude is None or os.path.basename(path) not in exclude
    ]

    text_pipeline, frontend, model, generator = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [text_pipeline],
        args.sample_rate,
        frontend=frontend if not args.frontend_in_model else None,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        mode='batched_channels' if args.join_transcript else 'batched_transcript',
        string_array_encoding=args.dataset_string_array_encoding,
        debug_short_long_records_features_from_whole_normalized_signal=args.
        debug_short_long_records_features_from_whole_normalized_signal)
    print('Examples count: ', len(val_dataset))
    val_meta = val_dataset.pop_meta()
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    csv_lines = []  # only used if args.output_csv is True

    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{len(val_dataset)}')
        # Swap the records coming from the loader for the metadata popped
        # from the dataset, looked up by example_id.
        meta = [val_meta[t['example_id']] for t in meta]

        audio_path = meta[0]['audio_path']
        audio_name = transcripts.audio_name(audio_path)
        begin_end = [dict(begin=t['begin'], end=t['end']) for t in meta]
        begin = torch.tensor([t['begin'] for t in begin_end], dtype=torch.float)
        end = torch.tensor([t['end'] for t in begin_end], dtype=torch.float)
        #TODO WARNING assumes frontend not in dataset
        if not args.frontend_in_model:
            print('\n' * 10 + 'WARNING\n' * 5)
            print(
                'transcribe.py assumes frontend in model, in other case time alignment was incorrect'
            )
            print('WARNING\n' * 5 + '\n')

        duration = x.shape[-1] / args.sample_rate
        channel = [t['channel'] for t in meta]
        speaker = [t['speaker'] for t in meta]
        speaker_name = [t['speaker_name'] for t in meta]

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, logits, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(map(transcripts.compute_duration, meta)),
                        processing=time.time() - tic))

            # One evenly spaced time stamp per output frame, spanning
            # [0, duration], repeated for every row of x.
            ts: shaping.Bt = duration * torch.linspace(
                0, 1, steps=log_probs.shape[-1],
                device=log_probs.device).unsqueeze(0).expand(x.shape[0], -1)

            ref_segments = [[
                dict(channel=channel[i],
                     begin=begin_end[i]['begin'],
                     end=begin_end[i]['end'],
                     ref=text_pipeline.postprocess(
                         text_pipeline.preprocess(meta[i]['ref'])))
            ] for i in range(len(meta))]
            hyp_segments = [
                alternatives[0] for alternatives in generator.generate(
                    tokenizer=text_pipeline.tokenizer,
                    log_probs=log_probs,
                    begin=begin,
                    end=end,
                    output_lengths=olen,
                    time_stamps=ts,
                    segment_text_key='hyp',
                    segment_extra_info=[
                        dict(speaker=s, speaker_name=sn, channel=c)
                        for s, sn, c in zip(speaker, speaker_name, channel)
                    ])
            ]
            hyp_segments = [
                transcripts.map_text(text_pipeline.postprocess, hyp=hyp)
                for hyp in hyp_segments
            ]
            hyp, ref = '\n'.join(
                transcripts.join(hyp=h) for h in hyp_segments).strip(), '\n'.join(
                    transcripts.join(ref=r) for r in ref_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                alignment: shaping.BY = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y[:, 0, :],  # assumed that 0 channel is char labels
                    olen,
                    ylen[:, 0],
                    blank=text_pipeline.tokenizer.eps_id,
                    pack_backpointers=args.pack_backpointers)
                # Time stamp of the frame each reference label was aligned to.
                aligned_ts: shaping.Bt = ts.gather(1, alignment)

                # Re-run the generator on a one-hot encoding of the reference
                # labels so that reference segments carry aligned time stamps.
                ref_segments = [
                    alternatives[0] for alternatives in generator.generate(
                        tokenizer=text_pipeline.tokenizer,
                        log_probs=torch.nn.functional.one_hot(
                            y[:, 0, :],
                            num_classes=log_probs.shape[1]).permute(0, 2, 1),
                        begin=begin,
                        end=end,
                        output_lengths=ylen,
                        time_stamps=aligned_ts,
                        segment_text_key='ref',
                        segment_extra_info=[
                            dict(speaker=s, speaker_name=sn, channel=c)
                            for s, sn, c in zip(speaker, speaker_name, channel)
                        ])
                ]
                # NOTE(review): reference segments are passed through the
                # `hyp=` keyword here, exactly as in the original - confirm
                # against transcripts.map_text that this is intended.
                ref_segments = [
                    transcripts.map_text(text_pipeline.postprocess, hyp=ref)
                    for ref in ref_segments
                ]
            oom_handler.reset()
        except Exception:
            # Was a bare `except:`; narrowed so that KeyboardInterrupt and
            # SystemExit propagate immediately instead of going through the
            # OOM retry logic. Anything the handler does not recover from is
            # re-raised unchanged.
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {len(val_dataset)}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() - tic_alignment))

        ref_transcript, hyp_transcript = [
            sorted(transcripts.flatten(segments), key=transcripts.sort_key)
            for segments in [ref_segments, hyp_segments]
        ]

        if args.max_segment_duration:
            if ref:
                ref_segments = list(
                    transcripts.segment_by_time(ref_transcript,
                                                args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment_by_ref(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment_by_time(hyp_transcript,
                                                args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        #### HACK for diarization
        elif args.ref_transcript_path and args.join_transcript:
            audio_name_hack = audio_name.split('.')[0]
            #TODO: normalize ref field
            ref_segments = [[t] for t in sorted(transcripts.load(
                os.path.join(args.ref_transcript_path, audio_name_hack +
                             '.json')),
                                                key=transcripts.sort_key)]
            hyp_segments = list(
                transcripts.segment_by_ref(hyp_transcript,
                                           ref_segments,
                                           set_speaker=True,
                                           soft=False))
        #### END OF HACK

        has_ref = bool(transcripts.join(ref=transcripts.flatten(ref_segments)))

        transcript = []
        # NOTE: this loop rebinds hyp / ref / hyp_transcript / ref_transcript;
        # the logits dump further down reads the values left by it.
        for hyp_transcript, ref_transcript in zip(hyp_segments, ref_segments):
            hyp, ref = transcripts.join(hyp=hyp_transcript), transcripts.join(
                ref=ref_transcript)

            transcript.append(
                dict(audio_path=audio_path,
                     ref=ref,
                     hyp=hyp,
                     speaker_name=transcripts.speaker_name(ref=ref_transcript,
                                                           hyp=hyp_transcript),
                     words=metrics.align_words(
                         *metrics.align_strings(hyp=hyp, ref=ref))
                     if args.align_words else [],
                     words_ref=ref_transcript,
                     words_hyp=hyp_transcript,
                     **transcripts.summary(hyp_transcript),
                     **(dict(cer=metrics.cer(hyp=hyp, ref=ref))
                        if has_ref else {})))

        transcripts.collect_speaker_names(transcript,
                                          set_speaker_data=True,
                                          num_speakers=2)

        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.prune_cer,
                              duration=args.prune_duration,
                              gap=args.prune_gap,
                              allowed_unk_count=args.prune_unk,
                              num_speakers=args.prune_num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcripts.save(transcript_path, filtered_transcript))

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(
                vis.transcript(transcript_path, args.sample_rate, args.mono,
                               transcript, filtered_transcript))

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            # Explicit UTF-8 so the transcript text does not depend on (or
            # fail under) the locale's default encoding.
            with open(transcript_path, 'w', encoding='utf-8') as f:
                f.write(' '.join(t['hyp'].strip() for t in filtered_transcript))
            print(transcript_path)

        if args.output_csv:
            if filtered_transcript:
                assert len({t['audio_path'] for t in filtered_transcript}) == 1
                audio_path = filtered_transcript[0]['audio_path']
                hyp = ' '.join(t['hyp'].strip() for t in filtered_transcript)
                begin = min(t['begin'] for t in filtered_transcript)
                end = max(t['end'] for t in filtered_transcript)
                csv_lines.append(
                    csv_sep.join([audio_path, hyp, str(begin), str(end)]))
            else:
                # Pruning removed every segment: there is no begin/end to
                # report, so skip the row instead of failing on the assert
                # (or on min()/[0] of an empty list under `python -O`).
                print(f'No segments left after pruning, no CSV row for [{audio_path}].')

        if args.logits:
            logits_file_path = os.path.join(args.output_path,
                                            audio_name + '.pt')
            if args.logits_crop:
                # Map the crop bounds (in output frames) onto each record's
                # [begin, end] time range, proportionally to its length.
                begin_end = [
                    dict(
                        zip(['begin', 'end'], [
                            t['begin'] + c / float(o) * (t['end'] - t['begin'])
                            for c in args.logits_crop
                        ])) for o, t in zip(olen, begin_end)
                ]
                logits_crop = [slice(*args.logits_crop) for o in olen]
            else:
                logits_crop = [slice(int(o)) for o in olen]

            # TODO: filter ref / hyp by channel?
            torch.save([
                dict(audio_path=audio_path,
                     logits=l[..., logits_crop[i]],
                     **begin_end[i],
                     ref=ref,
                     hyp=hyp) for i, l in enumerate(logits.cpu())
            ], logits_file_path)
            print(logits_file_path)

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        transcript_path = os.path.join(args.output_path, 'transcripts.csv')
        with open(transcript_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(csv_lines))
        print(transcript_path)
def logits(lang, logits, audio_name=None, MAX_ENTROPY=1.0):
    """Render an HTML report with one diagnostic plot per saved logits record.

    Loads the list of records stored at path ``logits`` and, for each record
    that passes the name filter, draws the logits image, the top-1 / top-2
    probabilities and the per-frame entropy (with and without the last
    class), marks frames whose entropy exceeds ``MAX_ENTROPY``, and - when
    the record has labels ``y`` - puts the CTC-aligned reference on a
    secondary axis. The figure is embedded as a base64 JPEG.

    Args:
        lang: language identifier passed to ``datasets.Language``.
        logits: path of the file to load with ``torch.load``; the report is
            written next to it as ``logits + '.html'``.
        audio_name: optional filter. ``None`` or an empty sequence disables
            filtering. If its first element is an existing file, the filter
            is the set of stripped lines of that file; otherwise the sequence
            itself is the set of allowed audio names.
        MAX_ENTROPY: entropy threshold above which frames are shaded red.

    Returns:
        Path of the written HTML file.
    """
    # Build the allow-list. The original left the names file open and raised
    # IndexError on an empty sequence; both are handled here.
    if not audio_name:
        good_audio_name = []
    elif os.path.exists(audio_name[0]):
        with open(audio_name[0]) as names_file:
            good_audio_name = set(map(str.strip, names_file))
    else:
        good_audio_name = set(audio_name)

    labels = datasets.Labels(datasets.Language(lang))
    decoder = decoders.GreedyDecoder()

    def tick_params(ax, labelsize=2.5, length=0, **kwargs):
        # Shrink the tick labels, hide the tick marks and the axis frame.
        ax.tick_params(axis='both', which='both', labelsize=labelsize, length=length, **kwargs)
        for spine in ax.spines.values():
            spine.set_linewidth(0)

    logits_path = logits + '.html'
    # NOTE(review): the file is opened with the default encoding, as before;
    # if `meta_charset` declares UTF-8, passing encoding='utf-8' here would
    # make the two agree - confirm against its definition.
    # `with` guarantees the report is flushed and closed even if plotting
    # fails part-way (the original never closed the handle).
    with open(logits_path, 'w') as html:
        html.write('<html><head>' + meta_charset + f'</head><body><script>{play_script}{onclick_img_script}</script>')

        # NOTE: torch.load unpickles the file - only open dumps from a
        # trusted source.
        for i, t in enumerate(torch.load(logits)):
            # Locals are named differently from the `logits` / `audio_name`
            # parameters so the arguments are no longer shadowed in the loop.
            audio_path, item_logits = t['audio_path'], t['logits']
            words = t.get('words', [t])
            y = t.get('y', torch.zeros(1, 0, dtype=torch.long))
            begin = t.get('begin', '')
            end = t.get('end', '')
            name = transcripts.audio_name(audio_path)
            extra_metrics = dict(cer=t['cer']) if 'cer' in t else {}

            if good_audio_name and name not in good_audio_name:
                continue

            log_probs = F.log_softmax(item_logits, dim=0)
            entropy = models.entropy(log_probs, dim=0, sum=False)
            # Same statistics with the last class dropped; the legend calls
            # this curve 'entropy, no blank'.
            log_probs_ = F.log_softmax(item_logits[:-1], dim=0)
            entropy_ = models.entropy(log_probs_, dim=0, sum=False)
            margin = models.margin(log_probs, dim=0)  # currently not plotted
            #energy = features.exp().sum(dim = 0)[::2]

            plt.figure(figsize=(6, 2))
            ax = plt.subplot(211)
            plt.imshow(item_logits, aspect='auto')
            plt.xlim(0, item_logits.shape[-1] - 1)
            #plt.yticks([])
            plt.axis('off')
            tick_params(plt.gca())
            #plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

            plt.subplot(212, sharex=ax)
            prob_top1, prob_top2 = log_probs.exp().topk(2, dim=0).values
            plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth=0.2)
            artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth=0.3)
            artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth=0.3)
            artist_entropy, = plt.plot(entropy, 'r', linewidth=0.3)
            artist_entropy_, = plt.plot(entropy_, 'yellow', linewidth=0.3)
            plt.legend(
                [artist_entropy, artist_entropy_, artist_prob_top1, artist_prob_top2],
                ['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
                loc=1, fontsize='xx-small', frameon=False)

            # Shade every run of frames whose entropy is above the threshold.
            for b, e, v in zip(*models.rle1d(entropy > MAX_ENTROPY)):
                if bool(v):
                    plt.axvspan(int(b), int(e), color='red', alpha=0.2)

            plt.ylim(0, 3.0)
            plt.xlim(0, entropy.shape[-1] - 1)

            # One x tick label per frame: the K=5 decodings stacked vertically.
            decoded = decoder.decode(log_probs.unsqueeze(0), K=5)[0]
            xlabels = list(
                map('\n'.join,
                    zip(*[
                        labels.decode(d, replace_blank='.', replace_space='_',
                                      replace_repeat=False, strip=False)
                        for d in decoded
                    ])))
            plt.xticks(torch.arange(entropy.shape[-1]), xlabels, fontfamily='monospace')
            tick_params(plt.gca())

            if y.numel() > 0:
                # The last class index is passed as the blank symbol.
                alignment = ctc.alignment(
                    log_probs.unsqueeze(0).permute(2, 0, 1),
                    y.unsqueeze(0).long(),
                    torch.LongTensor([log_probs.shape[-1]]),
                    torch.LongTensor([len(y)]),
                    blank=len(log_probs) - 1).squeeze(0)

                ax = plt.gca().secondary_xaxis('top')
                ref, ref_ = labels.decode(y.tolist(), replace_blank='.', replace_space='_',
                                          replace_repeat=False, strip=False), alignment
                ax.set_xticklabels(ref)
                ax.set_xticks(ref_)
                tick_params(ax, colors='red')
            #k = 0
            #for i, c in enumerate(ref + ' '):
            #    if c == ' ':
            #        plt.axvspan(ref_[k] - 1, ref_[i - 1] + 1, facecolor = 'gray', alpha = 0.2)
            #        k = i + 1
            plt.subplots_adjust(left=0, right=1, bottom=0.12, top=0.95)

            buf = io.BytesIO()
            plt.savefig(buf, format='jpg', dpi=600)
            plt.close()

            # The original wrote the literal text '{k}: {v:.02f}' (the f
            # prefix was missing) glued directly to the audio name. Format
            # the metrics and separate everything with ' | '.
            header = ' | '.join([name] + [f'{k}: {v:.02f}' for k, v in extra_metrics.items()])
            html.write(f'<h4>{header}</h4>')
            html.write(fmt_alignment(words))
            html.write('<img data-begin="{begin}" data-end="{end}" data-channel="{channel}" onclick="onclick_img(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>\n'.format(channel=i, begin=begin, end=end, encoded=base64.b64encode(buf.getvalue()).decode()))
            html.write(fmt_audio(audio_path=audio_path, channel=i))
            html.write('<hr/>')

        html.write('</body></html>')
    return logits_path
def main(args):
    """Transcribe the inputs listed in ``args.input_path`` and write the results.

    For every item yielded by the data loader the function runs the model,
    decodes hypothesis segments, optionally CTC-aligns the reference labels,
    optionally re-segments by ``args.max_segment_duration``, prunes the
    transcript and writes the requested ``.json`` / ``.html`` / ``.txt``
    files (named after the audio) into ``args.output_path``. When
    ``args.output_csv`` is set a single ``transcripts.csv`` is written after
    the loop.

    Args:
        args: options namespace. The attributes read here are exactly the
            ``args.*`` names used in the body.

    Raises:
        AssertionError: if none of the output formats is requested.
        Exception: anything raised during inference that the OOM handler
            declines to recover from is re-raised unchanged.
    """
    utils.enable_jit_fusion()

    assert args.output_json or args.output_html or args.output_txt or args.output_csv, \
        'at least one of the output formats must be provided'
    os.makedirs(args.output_path, exist_ok=True)

    # Audio files (given explicitly or found directly inside an input
    # directory) followed by explicitly given transcript manifests.
    data_paths = [
        p for f in args.input_path
        for p in ([os.path.join(f, g) for g in os.listdir(f)] if os.path.isdir(f) else [f])
        if os.path.isfile(p) and any(map(p.endswith, args.ext))
    ] + [
        p for p in args.input_path
        if any(map(p.endswith, ['.json', '.json.gz']))
    ]
    # With --skip_processed, names of already written .json outputs
    # (extension removed) are excluded.
    exclude = set([
        os.path.splitext(basename)[0]
        for basename in os.listdir(args.output_path)
        if basename.endswith('.json')
    ] if args.skip_processed else [])
    data_paths = [
        path for path in data_paths if os.path.basename(path) not in exclude
    ]

    labels, frontend, model, decoder = setup(args)
    val_dataset = datasets.AudioTextDataset(
        data_paths, [labels],
        args.sample_rate,
        frontend=None,
        segmented=True,
        mono=args.mono,
        time_padding_multiple=args.batch_time_padding_multiple,
        audio_backend=args.audio_backend,
        exclude=exclude,
        max_duration=args.transcribe_first_n_sec,
        join_transcript=args.join_transcript,
        string_array_encoding=args.dataset_string_array_encoding)
    num_examples = len(val_dataset)
    print('Examples count: ', num_examples)
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=None,
        collate_fn=val_dataset.collate_fn,
        num_workers=args.num_workers)
    csv_sep = dict(tab='\t', comma=',')[args.csv_sep]
    output_lines = []  # only used if args.output_csv is True
    oom_handler = utils.OomHandler(max_retries=args.oom_retries)
    for i, (meta, s, x, xlen, y, ylen) in enumerate(val_data_loader):
        print(f'Processing: {i}/{num_examples}')

        # Swap the records coming from the loader for the dataset's own
        # metadata, looked up by example_id.
        meta = [val_dataset.meta.get(m['example_id']) for m in meta]
        audio_path = meta[0]['audio_path']

        if x.numel() == 0:
            print(f'Skipping empty [{audio_path}].')
            continue

        begin = meta[0]['begin']
        end = meta[0]['end']
        audio_name = transcripts.audio_name(audio_path)

        try:
            tic = time.time()
            y, ylen = y.to(args.device), ylen.to(args.device)
            log_probs, olen = model(
                x.squeeze(1).to(args.device), xlen.to(args.device))

            #speech = vad.detect_speech(x.squeeze(1), args.sample_rate, args.window_size, aggressiveness = args.vad, window_size_dilate = args.window_size_dilate)
            #speech = vad.upsample(speech, log_probs)
            #log_probs.masked_fill_(models.silence_space_mask(log_probs, speech, space_idx = labels.space_idx, blank_idx = labels.blank_idx), float('-inf'))

            decoded = decoder.decode(log_probs, olen)

            print('Input:', audio_name)
            print('Input time steps:', log_probs.shape[-1],
                  '| target time steps:', y.shape[-1])
            print(
                'Time: audio {audio:.02f} sec | processing {processing:.02f} sec'
                .format(audio=sum(
                    transcripts.compute_duration(t) for t in meta),
                        processing=time.time() - tic))

            # Per-frame time stamps: evenly spaced over the input duration
            # and shifted by each record's own begin time.
            ts = (x.shape[-1] / args.sample_rate) * torch.linspace(
                0, 1, steps=log_probs.shape[-1]).unsqueeze(0) + torch.FloatTensor(
                    [t['begin'] for t in meta]).unsqueeze(1)
            channel = [t['channel'] for t in meta]
            speaker = [t['speaker'] for t in meta]
            ref_segments = [[
                dict(channel=channel[i],
                     begin=meta[i]['begin'],
                     end=meta[i]['end'],
                     ref=labels.decode(y[i, 0, :ylen[i]].tolist()))
            ] for i in range(len(decoded))]
            hyp_segments = [
                labels.decode(decoded[i],
                              ts[i],
                              channel=channel[i],
                              replace_blank=True,
                              replace_blank_series=args.replace_blank_series,
                              replace_repeat=True,
                              replace_space=False,
                              speaker=speaker[i] if isinstance(
                                  speaker[i], str) else None)
                for i in range(len(decoded))
            ]

            ref, hyp = '\n'.join(
                transcripts.join(ref=r) for r in ref_segments).strip(), '\n'.join(
                    transcripts.join(hyp=h) for h in hyp_segments).strip()
            if args.verbose:
                print('HYP:', hyp)
            print('CER: {cer:.02%}'.format(cer=metrics.cer(hyp=hyp, ref=ref)))

            tic_alignment = time.time()
            if args.align and y.numel() > 0:
                #if ref_full:# and not ref:
                #    #assert len(set(t['channel'] for t in meta)) == 1 or all(t['type'] != 'channel' for t in meta)
                #    #TODO: add space at the end
                #    channel = torch.ByteTensor(channel).repeat_interleave(log_probs.shape[-1]).reshape(1, -1)
                #    ts = ts.reshape(1, -1)
                #    log_probs = log_probs.transpose(0, 1).unsqueeze(0).flatten(start_dim = -2)
                #    olen = torch.tensor([log_probs.shape[-1]], device = log_probs.device, dtype = torch.long)
                #    y = y_full[None, None, :].to(y.device)
                #    ylen = torch.tensor([[y.shape[-1]]], device = log_probs.device, dtype = torch.long)
                #    segments = [([], sum([h for r, h in segments], []))]

                alignment = ctc.alignment(
                    log_probs.permute(2, 0, 1),
                    y.squeeze(1),
                    olen,
                    ylen.squeeze(1),
                    blank=labels.blank_idx,
                    pack_backpointers=args.pack_backpointers)
                # Decode the reference again, this time with time stamps and
                # the alignment, so reference segments carry timing.
                ref_segments = [
                    labels.decode(y[i, 0, :ylen[i]].tolist(),
                                  ts[i],
                                  alignment[i],
                                  channel=channel[i],
                                  speaker=speaker[i],
                                  key='ref',
                                  speakers=val_dataset.speakers)
                    for i in range(len(decoded))
                ]
            oom_handler.reset()
        except Exception:
            # Was a bare `except:`; narrowed so that KeyboardInterrupt and
            # SystemExit propagate immediately instead of going through the
            # OOM retry logic. Anything the handler does not recover from is
            # re-raised unchanged.
            if oom_handler.try_recover(model.parameters()):
                print(f'Skipping {i} / {num_examples}')
                continue
            else:
                raise

        print('Alignment time: {:.02f} sec'.format(time.time() - tic_alignment))

        if args.max_segment_duration:
            ref_transcript, hyp_transcript = [
                list(sorted(sum(segments, []), key=transcripts.sort_key))
                for segments in [ref_segments, hyp_segments]
            ]
            if ref:
                ref_segments = list(
                    transcripts.segment(ref_transcript,
                                        args.max_segment_duration))
                hyp_segments = list(
                    transcripts.segment(hyp_transcript, ref_segments))
            else:
                hyp_segments = list(
                    transcripts.segment(hyp_transcript,
                                        args.max_segment_duration))
                ref_segments = [[] for _ in hyp_segments]

        # The comprehension's ref / hyp are local to it: the batch-level
        # `hyp` computed above is what the .txt and .csv outputs use below.
        transcript = [
            dict(audio_path=audio_path,
                 ref=ref,
                 hyp=hyp,
                 speaker=transcripts.speaker(ref=ref_transcript,
                                             hyp=hyp_transcript),
                 cer=metrics.cer(hyp=hyp, ref=ref),
                 words=metrics.align_words(hyp=hyp, ref=ref)[-1]
                 if args.align_words else [],
                 alignment=dict(ref=ref_transcript, hyp=hyp_transcript),
                 **transcripts.summary(hyp_transcript))
            for ref_transcript, hyp_transcript in zip(ref_segments, hyp_segments)
            for ref, hyp in [(transcripts.join(ref=ref_transcript),
                              transcripts.join(hyp=hyp_transcript))]
        ]
        filtered_transcript = list(
            transcripts.prune(transcript,
                              align_boundary_words=args.align_boundary_words,
                              cer=args.cer,
                              duration=args.duration,
                              gap=args.gap,
                              unk=args.unk,
                              num_speakers=args.num_speakers))

        print('Filtered segments:', len(filtered_transcript), 'out of',
              len(transcript))

        if args.output_json:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.json')
            print(transcript_path)
            # ensure_ascii=False emits raw non-ASCII characters, so the file
            # must be opened as UTF-8 explicitly; with the locale default the
            # dump can raise UnicodeEncodeError or produce non-UTF-8 JSON.
            with open(transcript_path, 'w', encoding='utf-8') as f:
                json.dump(filtered_transcript,
                          f,
                          ensure_ascii=False,
                          sort_keys=True,
                          indent=2)

        if args.output_html:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.html')
            print(transcript_path)
            vis.transcript(transcript_path, args.sample_rate, args.mono,
                           transcript, filtered_transcript)

        if args.output_txt:
            transcript_path = os.path.join(args.output_path,
                                           audio_name + '.txt')
            print(transcript_path)
            with open(transcript_path, 'w', encoding='utf-8') as f:
                f.write(hyp)

        if args.output_csv:
            output_lines.append(
                csv_sep.join((audio_path, hyp, str(begin), str(end))) + '\n')

        print('Done: {:.02f} sec\n'.format(time.time() - tic))

    if args.output_csv:
        with open(os.path.join(args.output_path, 'transcripts.csv'), 'w',
                  encoding='utf-8') as f:
            f.writelines(output_lines)
def logits(logits, audio_name, MAX_ENTROPY=1.0):
    """Render an HTML report with one diagnostic plot per saved logits record.

    Loads the list of records stored at path ``logits`` and, for each record
    that passes the name filter, plots the top-1 / top-2 probabilities and
    the per-frame entropy (with and without the last class), puts the
    CTC-aligned reference labels on a secondary axis, and embeds the figure
    as a base64 JPEG followed by an audio player. The report is written to
    ``logits + '.html'`` and its path is printed.

    Args:
        logits: path of the file to load with ``torch.load``. Each record is
            indexed with the keys 'logits', 'audio_name', 'alignment', 'y',
            'cer', 'words' and 'audio_path'.
        audio_name: name filter. An empty value disables filtering. If its
            first element is an existing file, the filter is the set of
            stripped lines of that file; otherwise the sequence itself is
            the set of allowed audio names.
        MAX_ENTROPY: entropy threshold used to compute the ``bad`` frame
            mask (the shading that consumed it is commented out).
    """
    # Build the allow-list. The original left the names file open and raised
    # on an empty value; both are handled here.
    if not audio_name:
        good_audio_name = set()
    elif os.path.exists(audio_name[0]):
        with open(audio_name[0]) as names_file:
            good_audio_name = set(map(str.strip, names_file))
    else:
        good_audio_name = set(audio_name)

    labels = datasets.Labels(ru)
    decoder = decoders.GreedyDecoder()

    def tick_params(ax, labelsize=2.5, length=0, **kwargs):
        # Shrink the tick labels, hide the tick marks and the axis frame.
        ax.tick_params(axis='both', which='both', labelsize=labelsize, length=length, **kwargs)
        for spine in ax.spines.values():
            spine.set_linewidth(0)

    logits_path = logits + '.html'
    # The page declares charset utf-8, so write it as UTF-8 regardless of the
    # locale; `with` closes the handle the original leaked.
    with open(logits_path, 'w', encoding='utf-8') as html:
        # Clicking the image seeks the <audio> element that directly follows
        # it to the clicked horizontal position.
        html.write('<html><meta charset="utf-8"/><body><script> '
                   'function onclick_(evt) { '
                   'const img = evt.target; '
                   'const dim = img.getBoundingClientRect(); '
                   'const t = (evt.clientX - dim.left) / dim.width; '
                   'const audio = img.nextSibling; '
                   'audio.currentTime = t * audio.duration; '
                   'audio.play(); '
                   '} </script>')

        # NOTE: torch.load unpickles the file - only open dumps from a
        # trusted source.
        for r in torch.load(logits):
            # Named differently from the `logits` parameter so the argument
            # is no longer shadowed inside the loop.
            item_logits = r['logits']
            if good_audio_name and r['audio_name'] not in good_audio_name:
                continue

            # Not used below; kept so a record without these keys still fails
            # here, exactly as before.
            ref_aligned, hyp_aligned = r['alignment']['ref'], r['alignment']['hyp']

            log_probs = F.log_softmax(item_logits, dim=0)
            entropy = models.entropy(log_probs, dim=0, sum=False)
            # Same statistics with the last class dropped; the legend calls
            # this curve 'entropy, no blank'.
            log_probs_ = F.log_softmax(item_logits[:-1], dim=0)
            entropy_ = models.entropy(log_probs_, dim=0, sum=False)
            margin = models.margin(log_probs, dim=0)  # currently not plotted
            #energy = features.exp().sum(dim = 0)[::2]

            # The last class index is passed as the blank symbol.
            alignment = ctc.alignment(
                log_probs.unsqueeze(0).permute(2, 0, 1),
                r['y'].unsqueeze(0).long(),
                torch.LongTensor([log_probs.shape[-1]]),
                torch.LongTensor([len(r['y'])]),
                blank=len(log_probs) - 1).squeeze(0)

            plt.figure(figsize=(6, 2))

            prob_top1, prob_top2 = log_probs.exp().topk(2, dim=0).values
            plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth=0.2)
            artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth=0.3)
            artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth=0.3)
            artist_entropy, = plt.plot(entropy, 'r', linewidth=0.3)
            artist_entropy_, = plt.plot(entropy_, 'yellow', linewidth=0.3)
            plt.legend(
                [artist_entropy, artist_entropy_, artist_prob_top1, artist_prob_top2],
                ['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
                loc=1, fontsize='xx-small', frameon=False)
            bad = (entropy > MAX_ENTROPY).tolist()
            #runs = []
            #for i, b in enumerate(bad):
            #    if b:
            #        if not runs or not bad[i - 1]:
            #            runs.append([i, i])
            #        else:
            #            runs[-1][1] += 1
            #for begin, end in runs:
            #    plt.axvspan(begin, end, color='red', alpha=0.2)

            plt.ylim(0, 3.0)
            plt.xlim(0, entropy.shape[-1] - 1)

            # One x tick label per frame: the K=5 decodings stacked vertically.
            decoded = decoder.decode(log_probs.unsqueeze(0), K=5)[0]
            xlabels = list(
                map('\n'.join,
                    zip(*[
                        labels.decode(d, replace_blank='.', replace_space='_',
                                      replace_repeat=False)
                        for d in decoded
                    ])))
            #xlabels_ = labels.decode(log_probs.argmax(dim = 0).tolist(), blank = '.', space = '_', replace2 = False)
            plt.xticks(torch.arange(entropy.shape[-1]), xlabels, fontfamily='monospace')
            tick_params(plt.gca())

            ax = plt.gca().secondary_xaxis('top')
            ref, ref_ = labels.decode(r['y'].tolist(), replace_blank='.', replace_space='_',
                                      replace_repeat=False), alignment
            ax.set_xticklabels(ref)
            ax.set_xticks(ref_)
            tick_params(ax, colors='red')

            #k = 0
            #for i, c in enumerate(ref + ' '):
            #    if c == ' ':
            #        plt.axvspan(ref_[k] - 1, ref_[i - 1] + 1, facecolor = 'gray', alpha = 0.2)
            #        k = i + 1

            plt.subplots_adjust(left=0, right=1, bottom=0.12, top=0.95)

            buf = io.BytesIO()
            plt.savefig(buf, format='jpg', dpi=600)
            plt.close()

            html.write('<h4>{audio_name} | cer: {cer:.02f}</h4>'.format(**r))
            html.write(word_alignment(r['words']))
            html.write(
                '<img onclick="onclick_(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>'
                .format(encoded=base64.b64encode(buf.getvalue()).decode()))
            # The original string had no f prefix, so the literal text
            # '{audio_data_uri(r["audio_path"])}' ended up as the player's
            # src. Evaluate the call the placeholder names instead.
            # NOTE(review): relies on a module-level `audio_data_uri` helper,
            # which is not visible in this chunk - confirm it exists.
            audio_src = audio_data_uri(r['audio_path'])
            html.write(
                f'<audio style="width:100%" controls src="{audio_src}"></audio><hr/>'
            )
        html.write('</body></html>')
    print('\n', logits_path)