Example #1
def forward(self, signal, sample_rate, window_size=0.02, extra={}):
    assert sample_rate in [
        8_000, 16_000, 32_000, 48_000
    ] and signal.dtype == torch.int16 and window_size in [
        0.01, 0.02, 0.03
    ]
    frame_len = int(window_size * sample_rate)
    # per-channel, per-frame VAD decisions; a short trailing chunk counts as non-speech
    speech = torch.as_tensor([[
        len(chunk) == frame_len
        and self.vad.is_speech(bytearray(chunk.numpy()), sample_rate)
        for chunk in channel.split(frame_len)
    ] for channel in signal])
    # run-length encode each channel's speech mask and keep only the speech runs; times are in seconds
    transcript = [
        dict(begin=float(begin) * window_size,
             end=(float(begin) + float(duration)) * window_size,
             speaker=1 + channel,
             speaker_name=transcripts.default_speaker_names[1 + channel],
             **extra) for channel in range(len(signal))
        for begin, duration, mask in zip(*models.rle1d(speech[channel]))
        if mask == 1
    ]
    return transcript
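
The segmentation step above relies on the project-internal models.rle1d. Below is a minimal, self-contained sketch of the same run-length-encoding idea; rle1d_sketch and its (starts, lengths, values) return convention are assumptions for illustration, not the project's actual implementation.

import torch

def rle1d_sketch(mask):
    # run-length encode a 1-D mask into (starts, lengths, values) of constant runs
    changes = torch.nonzero(mask[1:] != mask[:-1]).flatten() + 1
    starts = torch.cat([torch.tensor([0]), changes])
    lengths = torch.cat([changes, torch.tensor([len(mask)])]) - starts
    return starts, lengths, mask[starts]

speech = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1], dtype=torch.bool)  # per-frame VAD decisions
window_size = 0.02  # seconds per frame, as in forward() above
segments = [(float(begin) * window_size, float(begin + duration) * window_size)
            for begin, duration, mask in zip(*rle1d_sketch(speech)) if mask]
print(segments)  # [(0.02, 0.08), (0.12, 0.16)]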
Example #2
def logits(lang, logits, audio_name = None, MAX_ENTROPY = 1.0):
	good_audio_name = set(map(str.strip, open(audio_name[0])) if os.path.exists(audio_name[0]) else audio_name) if audio_name is not None else []
	labels = datasets.Labels(datasets.Language(lang))
	decoder = decoders.GreedyDecoder()
	tick_params = lambda ax, labelsize = 2.5, length = 0, **kwargs: ax.tick_params(axis = 'both', which = 'both', labelsize = labelsize, length = length, **kwargs) or [spine.set_linewidth(0) for spine in ax.spines.values()]
	logits_path = logits + '.html'
	html = open(logits_path, 'w')
	html.write('<html><head>' + meta_charset + f'</head><body><script>{play_script}{onclick_img_script}</script>')

	for i, t in enumerate(torch.load(logits)):
		audio_path, logits = t['audio_path'], t['logits']
		words = t.get('words', [t])
		y = t.get('y', torch.zeros(1, 0, dtype = torch.long))
		begin = t.get('begin', '')
		end = t.get('end', '')
		audio_name = transcripts.audio_name(audio_path)
		extra_metrics = dict(cer = t['cer']) if 'cer' in t else {}

		if good_audio_name and audio_name not in good_audio_name:
			continue

		log_probs = F.log_softmax(logits, dim = 0)
		entropy = models.entropy(log_probs, dim = 0, sum = False)
		log_probs_ = F.log_softmax(logits[:-1], dim = 0)
		entropy_ = models.entropy(log_probs_, dim = 0, sum = False)
		margin = models.margin(log_probs, dim = 0)
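		# per-frame uncertainty diagnostics for the plot below: entropy over all classes,
		# entropy with the blank class excluded (logits[:-1]), and a per-frame margin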
		#energy = features.exp().sum(dim = 0)[::2]

		plt.figure(figsize = (6, 2))
		ax = plt.subplot(211)
		plt.imshow(logits, aspect = 'auto')
		plt.xlim(0, logits.shape[-1] - 1)
		#plt.yticks([])
		plt.axis('off')
		tick_params(plt.gca())
		#plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

		plt.subplot(212, sharex = ax)
		prob_top1, prob_top2 = log_probs.exp().topk(2, dim = 0).values
		plt.hlines(1.0, 0, entropy.shape[-1] - 1, linewidth = 0.2)
		artist_prob_top1, = plt.plot(prob_top1, 'b', linewidth = 0.3)
		artist_prob_top2, = plt.plot(prob_top2, 'g', linewidth = 0.3)
		artist_entropy, = plt.plot(entropy, 'r', linewidth = 0.3)
		artist_entropy_, = plt.plot(entropy_, 'yellow', linewidth = 0.3)
		plt.legend([artist_entropy, artist_entropy_, artist_prob_top1, artist_prob_top2],
					['entropy', 'entropy, no blank', 'top1 prob', 'top2 prob'],
					loc = 1,
					fontsize = 'xx-small',
					frameon = False)
		
		for b, e, v in zip(*models.rle1d(entropy > MAX_ENTROPY)):
			if bool(v):
				plt.axvspan(int(b), int(e), color='red', alpha=0.2)

		plt.ylim(0, 3.0)
		plt.xlim(0, entropy.shape[-1] - 1)

		decoded = decoder.decode(log_probs.unsqueeze(0), K = 5)[0]
		xlabels = list(
			map(
				'\n'.join,
				zip(
					*[
						labels.decode(d, replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False)
						for d in decoded
					]
				)
			)
		)
		plt.xticks(torch.arange(entropy.shape[-1]), xlabels, fontfamily = 'monospace')
		tick_params(plt.gca())

		if y.numel() > 0:
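			# force-align the reference labels y to the frame axis with CTC;
			# the alignment is drawn as red ticks on a secondary top axis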
			alignment = ctc.alignment(
				log_probs.unsqueeze(0).permute(2, 0, 1),
				y.unsqueeze(0).long(),
				torch.LongTensor([log_probs.shape[-1]]),
				torch.LongTensor([len(y)]),
				blank = len(log_probs) - 1
			).squeeze(0)
			
			ax = plt.gca().secondary_xaxis('top')
			ref, ref_ = labels.decode(y.tolist(), replace_blank = '.', replace_space = '_', replace_repeat = False, strip = False), alignment
			ax.set_xticks(ref_)
			ax.set_xticklabels(ref)
			tick_params(ax, colors = 'red')

		#k = 0
		#for i, c in enumerate(ref + ' '):
		#	if c == ' ':
		#		plt.axvspan(ref_[k] - 1, ref_[i - 1] + 1, facecolor = 'gray', alpha = 0.2)
		#		k = i + 1

		plt.subplots_adjust(left = 0, right = 1, bottom = 0.12, top = 0.95)

		buf = io.BytesIO()
		plt.savefig(buf, format = 'jpg', dpi = 600)
		plt.close()

		header = ' | '.join([audio_name] + [f'{k}: {v:.02f}' for k, v in extra_metrics.items()])
		html.write(f'<h4>{header}</h4>')
		html.write(fmt_alignment(words))
		html.write('<img data-begin="{begin}" data-end="{end}" data-channel="{channel}" onclick="onclick_img(event)" style="width:100%" src="data:image/jpeg;base64,{encoded}"></img>\n'.format(channel = i, begin = begin, end = end, encoded = base64.b64encode(buf.getvalue()).decode()))
		html.write(fmt_audio(audio_path = audio_path, channel = i))
		html.write('<hr/>')
	html.write('</body></html>')
	html.close()
	return logits_path
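
The uncertainty curves plotted above come from the project-internal models.entropy and models.margin. A minimal, self-contained sketch of what such per-frame diagnostics can look like, assuming plain Shannon entropy and a top1-top2 probability gap (the real helpers may differ):

import torch
import torch.nn.functional as F

logits = torch.randn(33, 100)               # [num_classes, time]; the last class is assumed to be the CTC blank
log_probs = F.log_softmax(logits, dim=0)
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(dim=0)   # per-frame Shannon entropy, shape [time]
prob_top1, prob_top2 = probs.topk(2, dim=0).values
margin = prob_top1 - prob_top2              # per-frame gap between the two most likely classes
MAX_ENTROPY = 1.0
uncertain = (entropy > MAX_ENTROPY).nonzero().flatten()
print(f'{len(uncertain)} of {entropy.shape[-1]} frames above the entropy threshold')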
Example #3
def ref(input_path, output_path, sample_rate, window_size, device,
        max_duration, debug_audio, html, ext):
    os.makedirs(output_path, exist_ok=True)
    audio_source = ([
        (input_path, audio_name) for audio_name in os.listdir(input_path)
    ] if os.path.isdir(input_path) else [(os.path.dirname(input_path),
                                          os.path.basename(input_path))])
    for i, (input_path, audio_name) in enumerate(audio_source):
        print(i, '/', len(audio_source), audio_name)
        audio_path = os.path.join(input_path, audio_name)
        noextname = audio_name[:-len(ext)]
        transcript_path = os.path.join(output_path, noextname + '.json')
        rttm_path = os.path.join(output_path, noextname + '.rttm')

        signal, sample_rate = audio.read_audio(audio_path,
                                               sample_rate=sample_rate,
                                               mono=False,
                                               dtype='float32',
                                               duration=max_duration)

        speaker_id_ref, speaker_id_ref_ = select_speaker(
            signal.to(device),
            silence_absolute_threshold=0.05,
            silence_relative_threshold=0.2,
            kernel_size_smooth_signal=128,
            kernel_size_smooth_speaker=4096,
            kernel_size_smooth_silence=4096)

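        # run-length encode each speaker's activity mask in speaker_id_ref_ (index 0 is skipped)
        # and convert begin/end from samples to seconds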
        transcript = [
            dict(audio_path=audio_path,
                 begin=float(begin) / sample_rate,
                 end=(float(begin) + float(duration)) / sample_rate,
                 speaker=speaker,
                 speaker_name=transcripts.default_speaker_names[speaker])
            for speaker in range(1, len(speaker_id_ref_))
            for begin, duration, mask in zip(
                *models.rle1d(speaker_id_ref_[speaker])) if mask == 1
        ]

        #transcript = [dict(audio_path = audio_path, begin = float(begin) / sample_rate, end = (float(begin) + float(duration)) / sample_rate, speaker_name = str(int(speaker)), speaker = int(speaker)) for begin, duration, speaker in zip(*models.rle1d(speaker_id_ref.cpu()))]

        transcript_without_speaker_missing = [
            t for t in transcript
            if t['speaker'] != transcripts.speaker_missing
        ]
        transcripts.save(transcript_path, transcript_without_speaker_missing)
        print(transcript_path)

        transcripts.save(rttm_path, transcript_without_speaker_missing)
        print(rttm_path)

        if debug_audio:
            audio.write_audio(
                transcript_path + '.wav',
                torch.cat([
                    signal[..., :speaker_id_ref.shape[-1]],
                    convert_speaker_id(speaker_id_ref[..., :signal.shape[-1]],
                                       to_bipole=True).unsqueeze(0).cpu() *
                    0.5, speaker_id_ref_[..., :signal.shape[-1]].cpu() * 0.5
                ]),
                sample_rate,
                mono=False)
            print(transcript_path + '.wav')

        if html:
            html_path = os.path.join(output_path, audio_name + '.html')
            vis.transcript(html_path,
                           sample_rate=sample_rate,
                           mono=True,
                           transcript=transcript,
                           duration=max_duration)
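
transcripts.save above is project-internal; how it serializes the .json and .rttm outputs is not shown here. A minimal, self-contained sketch of writing the same segment dicts as JSON and as standard RTTM SPEAKER records, under the assumption that one record per segment is what is wanted (file names and speaker_name values are hypothetical):

import json

transcript = [dict(audio_path='call.wav', begin=1.28, end=3.84,
                   speaker=1, speaker_name='speaker1')]  # hypothetical segment, as built by ref() above

with open('call.json', 'w') as f:
    json.dump(transcript, f, ensure_ascii=False, indent=2)

with open('call.rttm', 'w') as f:
    for t in transcript:
        # RTTM: SPEAKER <file> <chan> <tbeg> <tdur> <ortho> <stype> <name> <conf> <slat>
        f.write('SPEAKER call 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>\n'.format(
            t['begin'], t['end'] - t['begin'], t['speaker_name']))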