def Recognize(self, req, ctx):
    assert req.config.encoding == pb2.RecognitionConfig.LINEAR16
    signal, sample_rate = audio.read_audio(
        None,
        raw_bytes=req.audio.content,
        raw_sample_rate=req.config.sample_rate_hertz,
        raw_num_channels=req.config.audio_channel_count,
        dtype='int16',
        sample_rate=self.frontend.sample_rate,
        mono=True)
    x = signal
    logits, olen = self.model(x.to(self.device))
    decoded = self.decoder.decode(logits, olen)
    # map frame indices to seconds: frames are spread uniformly over the signal duration
    ts = (x.shape[-1] / sample_rate) * torch.linspace(0, 1, steps=logits.shape[-1])
    transcript = self.labels.decode(decoded[0], ts)
    hyp = transcripts.join(hyp=transcript)
    # split fractional seconds into (seconds, nanos) as expected by protobuf durations
    mktime = lambda t: dict(seconds=int(t), nanos=int((t - int(t)) * 1e9))
    return pb2.RecognizeResponse(results=[
        dict(
            alternatives=[
                dict(
                    transcript=hyp,
                    confidence=1.0,
                    words=[
                        dict(
                            word=t['hyp'],
                            start_time=mktime(t['begin']),
                            end_time=mktime(t['end']),
                            speaker_tag=0) for t in transcript
                    ])
            ],
            channel_tag=1)
    ])

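# A minimal standalone check of the (seconds, nanos) split done by mktime above;
# the helper name and values are illustrative only, not part of the service code.
def _mktime(t):
    return dict(seconds=int(t), nanos=int((t - int(t)) * 1e9))

assert _mktime(1.5) == dict(seconds=1, nanos=500000000)
assert _mktime(2.0) == dict(seconds=2, nanos=0)
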
def cut_audio(output_path, sample_rate, mono, dilate, strip_prefix, audio_backend, add_sub_paths, audio_transcripts):
    audio_path_res = []
    prev_audio_path = ''
    signal = None
    for t in audio_transcripts:
        audio_path = t['audio_path']
        # re-read only when the audio path changes; transcripts are assumed grouped by file
        if audio_path != prev_audio_path:
            signal = audio.read_audio(audio_path, sample_rate, backend=audio_backend)[0]
        if signal.numel() == 0:
            # bug with empty audio files which produce an empty cut file
            print('Empty audio_path ', audio_path)
            return []
        t['channel'] = 0 if len(signal) == 1 else None if mono else t.get('channel')
        segment = signal[
            slice(t['channel'], 1 + t['channel']) if t['channel'] is not None else ...,
            int(max(t['begin'] - dilate, 0) * sample_rate):int((t['end'] + dilate) * sample_rate)]
        segment_file_name = os.path.basename(audio_path) + '.{channel}-{begin:.06f}-{end:.06f}.wav'.format(**t)
        digest = hashlib.md5(segment_file_name.encode('utf-8')).hexdigest()
        # optionally shard output files into sub-directories keyed by the name digest
        sub_path = [digest[-1:], digest[:2], segment_file_name] if add_sub_paths else [segment_file_name]
        segment_path = os.path.join(output_path, *sub_path)
        os.makedirs(os.path.dirname(segment_path), exist_ok=True)
        audio.write_audio(segment_path, segment, sample_rate, mono=True)
        if strip_prefix:
            segment_path = segment_path[len(strip_prefix):] if segment_path.startswith(strip_prefix) else segment_path
            t['audio_path'] = t['audio_path'][len(strip_prefix):] if t['audio_path'].startswith(strip_prefix) else t['audio_path']
        t = dict(
            audio_path=segment_path,
            audio_name=os.path.basename(segment_path),
            channel=0 if len(signal) == 1 else None,
            begin=0.0,
            end=segment.shape[-1] / sample_rate,
            speaker=t.pop('speaker', None),
            ref=t.pop('ref', None),
            hyp=t.pop('hyp', None),
            cer=t.pop('cer', None),
            wer=t.pop('wer', None),
            alignment=t.pop('alignment', []),
            words=t.pop('words', []),
            meta=t)
        prev_audio_path = audio_path
        audio_path_res.append(t)
    return audio_path_res

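# Illustrative check of the digest-based sharding above: with add_sub_paths=True a
# segment lands under '<last hex char>/<first two hex chars>/<file name>'.
# The segment file name below is hypothetical.
import hashlib, os
name = 'a.wav.0-0.000000-1.000000.wav'
digest = hashlib.md5(name.encode('utf-8')).hexdigest()
print(os.path.join(digest[-1:], digest[:2], name))  # e.g. 'f/3c/a.wav...' depending on the hash
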
def generate_utterances(audio_path: str, output_path: str, sample_rate: int, vad, utterance_duration: float, stride: int, min_utterance_score: float):
    signal, _ = audio.read_audio(
        audio_path,
        sample_rate=sample_rate,
        mono=False,
        dtype=vad.input_dtype,
        __array_wrap__=vad.input_type)
    speaker_masks = vad.detect(signal, allow_overlap=True)
    if vad.input_type == torch.tensor:
        speaker_masks = speaker_masks.cpu().numpy()
    utterance_duration = math.ceil(utterance_duration * sample_rate)
    assert utterance_duration % stride == 0
    # sliding window trick, see https://habr.com/ru/post/489734/#1d
    sliding_window = np.lib.stride_tricks.as_strided(
        speaker_masks,
        shape=(
            speaker_masks.shape[0],
            int((speaker_masks.shape[-1] - utterance_duration) / stride) + 1,
            utterance_duration,
        ),
        strides=(speaker_masks.strides[0], stride, 1))
    n_samples_by_speaker = sliding_window.sum(-1)
    # score = speaker-balance ratio in [0;1] minus silence ratio in [0;1]
    utterance_scores = n_samples_by_speaker[1:].min(0) / (n_samples_by_speaker[1:].max(0) + 1) - n_samples_by_speaker[0] / utterance_duration
    n = 0
    audio_name, extension = os.path.splitext(os.path.basename(audio_path))
    # greedily take the best-scoring window, then suppress all overlapping windows
    while utterance_scores.max() > min_utterance_score:
        i = np.argmax(utterance_scores)
        utterance = signal[:, i * stride:i * stride + utterance_duration]
        utterance_scores[max(0, i - int(utterance_duration / stride) + 1):i + int(utterance_duration / stride)] = 0.0
        audio.write_audio(os.path.join(output_path, 'mix', f'{audio_name}.{n}{extension}'), utterance.T, sample_rate, mono=True)
        audio.write_audio(os.path.join(output_path, 'spk1', f'{audio_name}.{n}{extension}'), utterance[0:1, :].T, sample_rate, mono=True)
        audio.write_audio(os.path.join(output_path, 'spk2', f'{audio_name}.{n}{extension}'), utterance[1:2, :].T, sample_rate, mono=True)
        n += 1

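# Minimal numpy sketch of the as_strided sliding window used above. Strides are
# given in bytes (the code above assumes a 1-byte mask dtype); names here are
# illustrative stand-ins, not the repo's own variables.
import numpy as np

mask = np.arange(10, dtype=np.uint8)  # stand-in for one speaker-mask row
window, stride = 4, 2
n_windows = (mask.shape[-1] - window) // stride + 1
view = np.lib.stride_tricks.as_strided(
    mask,
    shape=(n_windows, window),
    strides=(stride * mask.strides[-1], mask.strides[-1]))  # byte strides
print(view)          # rows are overlapping windows, no data is copied
print(view.sum(-1))  # per-window sums, as in n_samples_by_speaker
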
def main(args):
    with open(args.diarization_dataset) as data_file:
        for line in tqdm(data_file):
            example = json.loads(line)
            mask = transcripts.intervals_to_mask(example.pop('intervals'), example['sample_rate'], example['duration']).numpy()
            path, ext = os.path.splitext(example['audio_path'])
            signal, sample_rate = audio.read_audio(path + '_mix' + ext, mono=True)
            # mask rows are assumed to be: 0 = silence, 1 = speaker 1, 2 = speaker 2
            speaker_1 = signal[:, :mask.shape[-1]] * mask[1, :signal.shape[-1]]
            audio.write_audio(path + '_s1' + ext, speaker_1.T, sample_rate)
            speaker_2 = signal[:, :mask.shape[-1]] * mask[2, :signal.shape[-1]]
            audio.write_audio(path + '_s2' + ext, speaker_2.T, sample_rate)

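# Hedged sketch of what transcripts.intervals_to_mask is assumed to do here: turn
# per-speaker (begin, end) intervals in seconds into a (num_speakers + 1, T) sample
# mask with silence in row 0. An illustration, not the library implementation.
import torch

def intervals_to_mask_sketch(intervals, sample_rate, duration, num_speakers=2):
    mask = torch.zeros(1 + num_speakers, int(duration * sample_rate))
    for speaker, begin, end in intervals:  # speaker in {1, 2}
        mask[speaker, int(begin * sample_rate):int(end * sample_rate)] = 1.0
    mask[0] = 1.0 - mask[1:].amax(dim=0)  # row 0 marks silence
    return mask

print(intervals_to_mask_sketch([(1, 0.0, 0.5), (2, 0.25, 1.0)], sample_rate=8, duration=1.0))
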
def generate_transcript(audio_path: str, sample_rate: int, vad, allow_overlap: bool):
    signal, _ = audio.read_audio(
        audio_path,
        sample_rate=sample_rate,
        mono=False,
        dtype=vad.input_dtype,
        __array_wrap__=vad.input_type)
    speaker_masks = vad.detect(signal, allow_overlap)
    if vad.input_type == torch.tensor:
        speaker_masks = speaker_masks.cpu().numpy()
    transcript = transcripts.mask_to_transcript(speaker_masks, sample_rate)
    return dict(
        audio_path=audio_path,
        audio_name=os.path.basename(audio_path),
        transcript=transcript,
        sample_rate=sample_rate,
        duration=signal.shape[-1] / sample_rate)

def parse_line(line):
    # parse a tab-separated line: file name, sentence, duration
    filename, sentence, duration = line.decode('ascii').split('\t')
    # audio file
    wav_path = os.path.join(hyperparams.dataset_path, filename + '.wav')
    wave = audio.read_audio(wav_path, hyperparams.sample_rate)
    audio_length = wave.shape[0] / hyperparams.sample_rate
    # compute spectrograms
    mel, linear = audio.spectrogram(hyperparams, wave)
    # encode the sentence into token ids
    tokens = text.encode(sentence)
    return mel.T, linear.T, tokens, np.int32(tokens.size), np.float32(audio_length)

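# Illustrative input for parse_line above: a bytes TSV row of (file name without
# extension, transcript, duration). The values are made up; note that duration
# stays a string after the split.
line = b'LJ001-0001\tPrinting, in the only sense...\t9.65'
filename, sentence, duration = line.decode('ascii').split('\t')
assert filename == 'LJ001-0001' and duration == '9.65'
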
def hyp(input_path, output_path, device, batch_size, html, ext, sample_rate, max_duration):
    os.makedirs(output_path, exist_ok=True)
    audio_source = (
        [(input_path, audio_name) for audio_name in os.listdir(input_path)]
        if os.path.isdir(input_path) else
        [(os.path.dirname(input_path), os.path.basename(input_path))])
    model = PyannoteDiarizationModel(device=device, batch_size=batch_size)
    for i, (input_path, audio_name) in enumerate(audio_source):
        print(i, '/', len(audio_source), audio_name)
        audio_path = os.path.join(input_path, audio_name)
        noextname = audio_name[:-len(ext)]
        transcript_path = os.path.join(output_path, noextname + '.json')
        rttm_path = os.path.join(output_path, noextname + '.rttm')
        signal, sample_rate = audio.read_audio(audio_path, sample_rate=sample_rate, mono=True, dtype='float32', duration=max_duration)
        transcript = model(signal, sample_rate=sample_rate, extra=dict(audio_path=audio_path))
        transcripts.collect_speaker_names(transcript, set_speaker_data=True)
        transcripts.save(transcript_path, transcript)
        print(transcript_path)
        transcripts.save(rttm_path, transcript)
        print(rttm_path)
        if html:
            html_path = os.path.join(output_path, audio_name + '.html')
            vis.transcript(html_path, sample_rate=sample_rate, mono=True, transcript=transcript, duration=max_duration)

def test_audio_conv(audio_path):
    # audio file
    wav_path = os.path.join(hyperparams.dataset_path, audio_path + '.wav')
    wave = audio.read_audio(wav_path, hyperparams.sample_rate)
    audio_length = wave.shape[0] / hyperparams.sample_rate
    # compute spectrograms
    mel, linear = audio.spectrogram(hyperparams, wave)
    #plt.imshow(mel)
    from_mel = audio.mel_to_linear(mel, (hyperparams.num_freq - 1) * 2, hyperparams.sample_rate, hyperparams.num_mels)
    plt.imshow(from_mel)
    plt.show()
    plt.imshow(linear)
    plt.show()
    signal = audio.reconstruct(hyperparams, linear)
    audio.write_audio('test.wav', signal, hyperparams.sample_rate)
    signal = audio.reconstruct(hyperparams, mel, from_mel=True)
    audio.write_audio('test_mel.wav', signal, hyperparams.sample_rate)

def __getitem__(self, index):
    waveform_transform_debug = (
        lambda audio_path, sample_rate, signal: audio.write_audio(
            os.path.join(self.waveform_transform_debug_dir, os.path.basename(audio_path) + '.wav'),
            signal,
            sample_rate)) if self.waveform_transform_debug_dir else None

    audio_path = self.audio_path[index]
    transcript = self.load_example(index)
    signal, sample_rate = audio.read_audio(
        audio_path,
        sample_rate=self.sample_rate,
        mono=self.mono,
        backend=self.audio_backend,
        duration=self.max_duration,
        dtype=self.audio_dtype) if self.frontend is None or self.frontend.read_audio else (audio_path, self.sample_rate)

    #TODO: support forced mono even if transcript is given
    #TODO: subsample speaker labels according to features
    some_segments_have_not_begin_end = any(
        t['begin'] == transcripts.time_missing and t['end'] == transcripts.time_missing for t in transcript)
    some_segments_have_ref = any(bool(t['ref']) for t in transcript)
    replace_transcript = self.join_transcript or (not transcript) or (some_segments_have_not_begin_end and some_segments_have_ref)

    if replace_transcript:
        assert len(signal) == 1, 'only mono supported for now'
        ref_full = [t['ref'] for t in transcript]
        # build a per-character speaker tensor: each segment contributes len(ref) labels
        # plus one speaker_missing separator; the trailing separator is dropped
        speaker = torch.cat([
            torch.full((len(ref) + 1, ), t['speaker'], dtype=torch.int64).scatter_(0, torch.tensor(len(ref)), transcripts.speaker_missing)
            for t, ref in zip(transcript, ref_full)
        ])[:-1].unsqueeze(0)
        transcript = [
            dict(
                audio_path=audio_path,
                ref=' '.join(ref_full),
                example_id=self.example_id(dict(audio_path=audio_path)),
                channel=0,
                begin_samples=0,
                end_samples=None)
        ]
    else:
        transcript = [
            dict(
                audio_path=audio_path,
                ref=t['ref'],
                example_id=self.example_id(t),
                channel=channel,
                begin_samples=int(t['begin'] * sample_rate) if t['begin'] != transcripts.time_missing else 0,
                end_samples=1 + int(t['end'] * sample_rate) if t['end'] != transcripts.time_missing else signal.shape[1],
                speaker=t['speaker'])
            for t in sorted(transcript, key=transcripts.sort_key)
            for channel in ([t['channel']] if t['channel'] != transcripts.channel_missing else range(len(signal)))
        ]
        speaker = torch.LongTensor([t.pop('speaker') for t in transcript]).unsqueeze(-1)  ## TODO: check logic

    features = []
    for t in transcript:
        channel = t.pop('channel')
        time_slice = slice(t.pop('begin_samples'), t.pop('end_samples'))  # pop is required whether segmented or not
        if self.segmented and not self.debug_short_long_records_features_from_whole_normalized_signal:
            segment = signal[None, channel, time_slice]
        else:
            segment = signal[None, channel, :]  # begin, end meta could be corrupted, that's why we don't use it here
        if self.frontend is not None:
            if self.debug_short_long_records_features_from_whole_normalized_signal:
                segment_features = self.frontend(segment)
                hop_length = self.frontend.hop_length
                segment_features = segment_features[:, :, time_slice.start // hop_length:time_slice.stop // hop_length]
                features.append(segment_features.squeeze(0))
            else:
                features.append(self.frontend(segment, waveform_transform_debug=waveform_transform_debug).squeeze(0))
        else:
            features.append(segment)

    targets = []
    for pipeline in self.text_pipelines:
        encoded_transcripts = []
        for t in transcript:
            processed = pipeline.preprocess(t['ref'])
            tokens = torch.tensor(pipeline.encode([processed])[0], dtype=torch.long, device='cpu')
            encoded_transcripts.append(tokens)
        targets.append(encoded_transcripts)

    # not batch mode
    if not self.segmented:
        transcript, speaker, features = transcript[0], speaker[0], features[0]
        targets = [target[0] for target in targets]
    return [transcript, speaker, features] + targets

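# Standalone illustration of the scatter_ trick above: label every character of a
# segment with its speaker id and insert one "missing" separator label between
# segments. The speaker ids and the missing value below are made up for the demo.
import torch

speaker_missing = 0
segments = [(1, 'hi'), (2, 'there')]  # (speaker, ref) pairs
labels = torch.cat([
    torch.full((len(ref) + 1,), spk, dtype=torch.int64).scatter_(0, torch.tensor([len(ref)]), speaker_missing)
    for spk, ref in segments
])[:-1]  # drop the trailing separator
print(labels)  # tensor([1, 1, 0, 2, 2, 2, 2, 2])
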
def transcript(html_path, sample_rate, mono, transcript, filtered_transcript=[], duration=None, NA='N/A', default_channel=0):
    if isinstance(transcript, str):
        transcript = json.load(open(transcript))

    audio_path = transcript[0]['audio_path']
    audio_name = transcripts.audio_name(audio_path)
    signal, sample_rate = audio.read_audio(audio_path, sample_rate=sample_rate, mono=mono, duration=duration)

    channel_or_default = lambda channel: default_channel if channel == transcripts.channel_missing else channel

    def fmt_link(ref='', hyp='', channel=default_channel, begin=transcripts.time_missing, end=transcripts.time_missing, speaker=transcripts.speaker_missing, i='', j='', audio_path='', special_begin=0, special_end=1, **kwargs):
        span = ref in [special_begin, special_end] or begin == transcripts.time_missing or end == transcripts.time_missing
        tag_open = '<span' if span else f'<a onclick="return play(event, {channel_or_default(channel)}, {begin}, {end})"'
        tag_attr = f' title="channel{channel}. speaker{speaker}: {begin:.04f} - {end:.04f} | {i} - {j}" href="#" target="_blank">'
        tag_contents = (
            (ref + hyp) if isinstance(ref, str)
            else (f'{begin:.02f}' if begin != transcripts.time_missing else NA) if ref == special_begin
            else (f'{end:.02f}' if end != transcripts.time_missing else NA) if ref == special_end
            else (f'{end - begin:.02f}' if begin != transcripts.time_missing and end != transcripts.time_missing else NA))
        tag_close = '</span>' if span else '</a>'
        return tag_open + tag_attr + tag_contents + tag_close

    fmt_words = lambda rh: ' '.join(fmt_link(**w) for w in rh)
    fmt_begin_end = 'data-begin="{begin}" data-end="{end}"'.format

    html = open(html_path, 'w')
    style = (
        ' '.join(f'.speaker{i} {{background-color : {c}; }}' for i, c in enumerate(speaker_colors)) +
        ' '.join(f'.channel{i} {{background-color : {c}; }}' for i, c in enumerate(channel_colors)) +
        ' a {text-decoration: none;} .reference{opacity:0.4} .channel{margin:0px} .ok{background-color:green} .m0{margin:0px} .top{vertical-align:top}')
    html.write('<html><head>' + meta_charset + f'<style>{style}</style></head><body>')
    html.write(f'<script>{play_script}{onclick_svg_script}</script>')
    html.write(f'<div style="overflow:auto"><h4 style="float:left">{audio_name}</h4><h5 style="float:right">0.000000</h5></div>')

    html_speaker_barcode = fmt_svg_speaker_barcode(transcript, begin=0.0, end=signal.shape[-1] / sample_rate)
    html.writelines(
        f'<figure class="m0"><figcaption><a href="#" download="channel{c}.{audio_name}" onclick="return download_audio(event, {c})">channel #{c}:</a></figcaption><audio ontimeupdate="update_span(ontimeupdate_(event), event)" onpause="onpause_(event)" id="audio{c}" style="width:100%" controls src="{uri_audio}"></audio>{html_speaker_barcode}</figure><hr />'
        for c, uri_audio in enumerate(
            audio_data_uri(signal[channel], sample_rate)
            for channel in ([0, 1] if len(signal) == 2 else []) + [...]))

    html.write('<pre class="channel"><h3 class="channel0 channel">hyp #0:<span class="subtitle"></span></h3></pre>')
    html.write('<pre class="channel"><h3 class="channel0 reference channel">ref #0:<span class="subtitle"></span></h3></pre>')
    html.write('<pre class="channel" style="margin-top: 10px"><h3 class="channel1 channel">hyp #1:<span class="subtitle"></span></h3></pre>')
    html.write('<pre class="channel"><h3 class="channel1 reference channel">ref #1:<span class="subtitle"></span></h3></pre>')

    def fmt_th():
        idx_th = '<th>#</th>'
        speaker_th = '<th>speaker</th>'
        begin_th = '<th>begin</th>'
        end_th = '<th>end</th>'
        duration_th = '<th>dur</th>'
        hyp_th = '<th style="width:50%">hyp</th>'
        ref_th = '<th style="width:50%">ref</th>' + begin_th + end_th + duration_th + '<th>cer</th>'
        return '<tr>' + idx_th + speaker_th + begin_th + end_th + duration_th + hyp_th + ref_th

    def fmt_tr(i, ok, t, words, hyp, ref, channel, speaker, speaker_name, cer):
        idx_td = f'''<td class="top {ok and 'ok'}">#{i}</td>'''
        speaker_td = f'<td class="speaker{speaker}" title="speaker{speaker}">{speaker_name}</td>'
        left_td = f'<td class="top">{fmt_link(0, **transcripts.summary(hyp, ij=True))}</td><td class="top">{fmt_link(1, **transcripts.summary(hyp, ij=True))}</td><td class="top">{fmt_link(2, **transcripts.summary(hyp, ij=True))}</td>'
        hyp_td = f'<td class="top hyp" data-channel="{channel}" data-speaker="{speaker}" {fmt_begin_end(**transcripts.summary(hyp, ij=True))}>{fmt_words(hyp)}{fmt_alignment(words, hyp=True, prefix="", tag="<template>")}</td>'
        ref_td = f'<td class="top reference ref" data-channel="{channel}" data-speaker="{speaker}" {fmt_begin_end(**transcripts.summary(ref, ij=True))}>{fmt_words(ref)}{fmt_alignment(words, ref=True, prefix="", tag="<template>")}</td>'
        right_td = f'<td class="top">{fmt_link(0, **transcripts.summary(ref, ij=True))}</td><td class="top">{fmt_link(1, **transcripts.summary(ref, ij=True))}</td><td class="top">{fmt_link(2, **transcripts.summary(ref, ij=True))}</td>'
        cer_td = '<td class="top">' + (f'{cer:.2%}' if cer != transcripts._er_missing else NA) + '</td>'
        return f'<tr class="channel{channel} speaker{speaker}">' + idx_td + speaker_td + left_td + hyp_td + ref_td + right_td + cer_td + '</tr>\n'

    html.write('<hr/><table style="width:100%">')
    html.write(fmt_th())
    html.writelines(
        fmt_tr(
            i,
            t in filtered_transcript,
            t,
            t.get('words', [t]),
            t.get('words_hyp', [t]),
            t.get('words_ref', [t]),
            t.get('channel', transcripts.channel_missing),
            t.get('speaker', transcripts.speaker_missing),
            t.get('speaker_name', 'speaker{}'.format(t.get('speaker', transcripts.speaker_missing))),
            t.get('cer', transcripts._er_missing)) for i, t in enumerate(transcripts.sort(transcript)))
    html.write(f'</tbody></table><script>{subtitle_script}</script></body></html>')
    return html_path

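# Hedged sketch of the audio_data_uri helper used above: encode one channel as a
# WAV data URI for the <audio src=...> attribute. The actual helper may differ;
# this version assumes a float32 signal in [-1, 1] and builds a minimal
# 16-bit mono PCM WAV header with the stdlib.
import base64, struct

def audio_data_uri_sketch(signal, sample_rate):
    import numpy as np
    pcm = (np.asarray(signal).clip(-1, 1) * 32767).astype('<i2').tobytes()
    header = (b'RIFF' + struct.pack('<I', 36 + len(pcm)) + b'WAVEfmt ' +
              struct.pack('<IHHIIHH', 16, 1, 1, sample_rate, sample_rate * 2, 2, 16) +
              b'data' + struct.pack('<I', len(pcm)))
    return 'data:audio/wav;base64,' + base64.b64encode(header + pcm).decode()
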
def ref(input_path, output_path, sample_rate, window_size, device, max_duration, debug_audio, html, ext):
    os.makedirs(output_path, exist_ok=True)
    audio_source = (
        [(input_path, audio_name) for audio_name in os.listdir(input_path)]
        if os.path.isdir(input_path) else
        [(os.path.dirname(input_path), os.path.basename(input_path))])
    for i, (input_path, audio_name) in enumerate(audio_source):
        print(i, '/', len(audio_source), audio_name)
        audio_path = os.path.join(input_path, audio_name)
        noextname = audio_name[:-len(ext)]
        transcript_path = os.path.join(output_path, noextname + '.json')
        rttm_path = os.path.join(output_path, noextname + '.rttm')
        signal, sample_rate = audio.read_audio(audio_path, sample_rate=sample_rate, mono=False, dtype='float32', duration=max_duration)
        speaker_id_ref, speaker_id_ref_ = select_speaker(
            signal.to(device),
            silence_absolute_threshold=0.05,
            silence_relative_threshold=0.2,
            kernel_size_smooth_signal=128,
            kernel_size_smooth_speaker=4096,
            kernel_size_smooth_silence=4096)
        transcript = [
            dict(
                audio_path=audio_path,
                begin=float(begin) / sample_rate,
                end=(float(begin) + float(duration)) / sample_rate,
                speaker=speaker,
                speaker_name=transcripts.default_speaker_names[speaker])
            for speaker in range(1, len(speaker_id_ref_))
            for begin, duration, mask in zip(*models.rle1d(speaker_id_ref_[speaker])) if mask == 1
        ]
        #transcript = [dict(audio_path = audio_path, begin = float(begin) / sample_rate, end = (float(begin) + float(duration)) / sample_rate, speaker_name = str(int(speaker)), speaker = int(speaker)) for begin, duration, speaker in zip(*models.rle1d(speaker_id_ref.cpu()))]
        transcript_without_speaker_missing = [t for t in transcript if t['speaker'] != transcripts.speaker_missing]
        transcripts.save(transcript_path, transcript_without_speaker_missing)
        print(transcript_path)
        transcripts.save(rttm_path, transcript_without_speaker_missing)
        print(rttm_path)
        if debug_audio:
            audio.write_audio(
                transcript_path + '.wav',
                torch.cat([
                    signal[..., :speaker_id_ref.shape[-1]],
                    convert_speaker_id(speaker_id_ref[..., :signal.shape[-1]], to_bipole=True).unsqueeze(0).cpu() * 0.5,
                    speaker_id_ref_[..., :signal.shape[-1]].cpu() * 0.5
                ]),
                sample_rate,
                mono=False)
            print(transcript_path + '.wav')
        if html:
            html_path = os.path.join(output_path, audio_name + '.html')
            vis.transcript(html_path, sample_rate=sample_rate, mono=True, transcript=transcript, duration=max_duration)

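# models.rle1d is assumed to run-length encode a 1D mask into parallel arrays of
# (begin, length, value), which the code above zips into speech segments. A minimal
# numpy sketch of that contract (illustrative, not the repo implementation):
import numpy as np

def rle1d_sketch(x):
    x = np.asarray(x)
    boundaries = np.flatnonzero(np.diff(x)) + 1       # indices where the value changes
    begins = np.concatenate([[0], boundaries])
    lengths = np.diff(np.concatenate([begins, [len(x)]]))
    return begins, lengths, x[begins]

print(rle1d_sketch([0, 0, 1, 1, 1, 0]))  # (array([0, 2, 5]), array([2, 3, 1]), array([0, 1, 0]))
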
def __getitem__(self, index):
    transcript = self.unpack_transcript(index)
    # signal shape here is shaping.CT
    signal: shaping.CT
    sample_rate: int
    signal, sample_rate = audio.read_audio(
        transcript[0]['audio_path'],
        sample_rate=self.sample_rate,
        mono=self.mono,
        backend=self.audio_backend,
        duration=self.max_duration,
        dtype=self.audio_dtype)
    transcript = [t for t in transcript if t['channel'] < len(signal)]

    features = []
    # slice in the time and channel dimensions
    for t in transcript:
        channel = t.pop('channel')
        time_slice = slice(
            int(t['begin'] * sample_rate) if t['begin'] != transcripts.time_missing else 0,
            1 + int(t['end'] * sample_rate) if t['end'] != transcripts.time_missing else signal.shape[1])
        # signal shaping.CT -> segment shaping.1T
        if self.mode == AudioTextDataset.DEFAULT_MODE:
            segment: shaping._T = signal[None, channel, :]  # begin, end meta could be corrupted, that's why we don't use it here
        else:
            segment = signal[None, channel, time_slice]
        if self.frontend is not None:
            # debug_short_long_records_features_from_whole_normalized_signal means: apply the frontend to the whole signal instead of the segment
            if self.debug_short_long_records_features_from_whole_normalized_signal:
                segment_features = self.frontend(segment)
                hop_length = self.frontend.hop_length
                segment_features = segment_features[:, :, time_slice.start // hop_length:time_slice.stop // hop_length]
                features.append(segment_features.squeeze(0))
            else:
                features.append(self.frontend(segment).squeeze(0))
        else:
            features.append(segment)

    # speaker alignment and ref processing
    targets = []
    speakers = []
    for pipeline in self.text_pipelines:
        encoded_refs, aligned_speakers = AudioTextDataset.encode_transcript(transcript, pipeline)
        targets.append(encoded_refs)
        speakers.append(aligned_speakers)

    # remove speaker separators from ref
    for t in transcript:
        t['ref'] = t['ref'].replace(transcripts.speaker_phrase_separator, ' ')

    # speakers are generated for all text pipelines, but for backward compatibility only the first pipeline's speakers are used
    speaker = speakers[0]

    # not batch mode
    if self.mode == AudioTextDataset.DEFAULT_MODE:
        transcript, speaker, features = transcript[0], speaker[0], features[0]
        targets = [target[0] for target in targets]
    return [transcript, speaker, features] + targets

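# Worked example of the hop_length slicing above: with hop_length = 160 at 16 kHz,
# samples [8000, 24000) map to spectrogram frames [50, 150). Values are illustrative.
hop_length = 160
time_slice = slice(8000, 24000)
frame_slice = slice(time_slice.start // hop_length, time_slice.stop // hop_length)
assert (frame_slice.start, frame_slice.stop) == (50, 150)
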
import sys

import scipy.signal as sig
import scipy.fftpack as fftp
from numpy import log10
from scipy.cluster.vq import kmeans2 as kmeans

import grainstream as gs
import audio as au
import grouping as grp
import generator as gen
import interface
import stats as st

stats = st.stats()
params = interface.parse_args()

[source_audio, sample_rate, source_length] = au.read_audio(params.infile)
# convert grain size and spacing from milliseconds to samples (integer division)
params.grain_size = (sample_rate * params.grain_size_ms) // 1000
params.grain_spacing = (sample_rate * params.grain_spacing_ms) // 1000

[grain_groups, event_list, event_groups, features] = grp.group_events(source_audio, params)
if params.debug > 0:
    stats.num_events = len(event_list)

if params.mode == 'loop':
    streams = gen.group_loop(sample_rate, params, grain_groups, features, stats)
elif params.mode == 'block':
    streams = gen.block_generator(sample_rate, params, grain_groups, features, stats)

print("Mixing down..")

def transcript(html_path, sample_rate, mono, transcript, filtered_transcript=[]):
    if isinstance(transcript, str):
        transcript = json.load(open(transcript))
    has_hyp = any(t.get('hyp') for t in transcript)
    has_ref = any(t.get('ref') for t in transcript)

    audio_path = transcript[0]['audio_path']
    signal, sample_rate = audio.read_audio(audio_path, sample_rate=sample_rate, mono=mono)

    fmt_link = lambda ref='', hyp='', channel=0, begin=0, end=0, speaker='', i='', j='', audio_path='', **kwargs: (
        f'<a onclick="return play({channel},{begin},{end})"' if ref not in [0, 1] else '<span'
    ) + f' title="#{channel}. {speaker}: {begin:.04f} - {end:.04f} | {i} - {j}" href="#" target="_blank">' + (
        (ref + hyp) if isinstance(ref, str) else f'{begin:.02f}' if ref == 0 else f'{end:.02f}' if ref == 1 else f'{end - begin:.02f}'
    ) + ('</a>' if ref not in [0, 1] else '</span>')
    fmt_words = lambda rh: ' '.join(fmt_link(**w) for w in rh)
    fmt_begin_end = 'data-begin="{begin}" data-end="{end}"'.format

    html = open(html_path, 'w')
    html.write(
        '<html><head><meta charset="UTF-8"><style>a {text-decoration: none;} .channel0 .hyp{padding-right:150px} .channel1 .hyp{padding-left:150px} .ok{background-color:green} .m0{margin:0px} .top{vertical-align:top} .channel0{background-color:violet} .channel1{background-color:lightblue;'
        + ('display:none' if len(signal) == 1 else '') +
        '} .reference{opacity:0.4} .channel{margin:0px}</style></head><body>')
    html.write(f'<div style="overflow:auto"><h4 style="float:left">{os.path.basename(audio_path)}</h4><h5 style="float:right">0.000000</h5></div>')
    uri_speaker_barcode = base64.b64encode(speaker_barcode(transcript, begin=0.0, end=signal.shape[-1] / sample_rate)).decode()
    html.writelines(
        f'<figure class="m0"><figcaption>channel #{c}:</figcaption><audio ontimeupdate="ontimeupdate_(event)" onpause="onpause_(event)" id="audio{c}" style="width:100%" controls src="{uri_audio}"></audio><img onclick="onclick_(event)" src="data:image/jpeg;base64,{uri_speaker_barcode}" style="width:100%"></img></figure>'
        for c, uri_audio in enumerate(
            audio_data_uri(signal[channel], sample_rate)
            for channel in ([0, 1] if len(signal) == 2 else []) + [...]))
    html.write(
        '<pre class="channel"><h3 class="channel0 channel">hyp #0:<span class="subtitle"></span></h3></pre><pre class="channel"><h3 class="channel0 reference channel">ref #0:<span class="subtitle"></span></h3></pre><pre class="channel" style="margin-top: 10px"><h3 class="channel1 channel">hyp #1:<span class="subtitle"></span></h3></pre><pre class="channel"><h3 class="channel1 reference channel">ref #1:<span class="subtitle"></span></h3></pre><hr/><table style="width:100%">'
    )

    def format_th(has_hyp, has_ref):
        speaker_th = '<th>speaker</th>'
        begin_th = '<th>begin</th>'
        end_th = '<th>end</th>'
        duration_th = '<th>dur</th>'
        hyp_th = '<th style="width:50%">hyp</th>' if has_hyp else ''
        ref_th = '<th style="width:50%">ref</th>' + begin_th + end_th + duration_th + '<th>cer</th>' if has_ref else ''
        return '<tr>' + speaker_th + begin_th + end_th + duration_th + hyp_th + ref_th

    html.write(format_th(has_hyp, has_ref))

    def format_tr(t, ok, has_hyp, has_ref, hyp, ref, channel, speaker):
        speaker_td = f'''<td class="top {ok and 'ok'}">{speaker}</td>'''
        begin_td = f'<td class="top">{fmt_link(0, **transcripts.summary(hyp, ij=True))}</td>'
        end_td = f'<td class="top">{fmt_link(1, **transcripts.summary(hyp, ij=True))}</td>'
        duration_td = f'<td class="top">{fmt_link(2, **transcripts.summary(hyp, ij=True))}</td>'
        hyp_td = f'<td class="top hyp" data-channel="{channel}" {fmt_begin_end(**transcripts.summary(hyp, ij=True))}>{fmt_words(hyp)}<template>{word_alignment(t.get("words", []), tag="", hyp=True)}</template></td>' if has_hyp else ''
        ref_td = f'<td class="top reference ref" data-channel="{channel}" {fmt_begin_end(**transcripts.summary(ref, ij=True))}>{fmt_words(ref)}<template>{word_alignment(t.get("words", []), tag="", ref=True)}</template></td><td class="top">{fmt_link(0, **transcripts.summary(ref, ij=True))}</td><td class="top">{fmt_link(1, **transcripts.summary(ref, ij=True))}</td><td class="top">{fmt_link(2, **transcripts.summary(ref, ij=True))}</td><td class="top">{t.get("cer", -1):.2%}</td>' if (has_ref and ref) else ('<td></td>' * 5 if has_ref else '')
        return f'<tr class="channel{channel}">' + speaker_td + begin_td + end_td + duration_td + hyp_td + ref_td + '</tr>'

    html.writelines(
        format_tr(t, ok, has_hyp, has_ref, hyp, ref, channel, speaker)
        for t in transcripts.sort(transcript)
        for ok in [t in filtered_transcript]
        for channel, speaker, ref, hyp in [(t.get('channel', 0), t.get('speaker', 0), t.get('words_ref', [t]), t.get('words_hyp', [t]))])

    html.write('''</tbody></table><script>
function play(channel, begin, end) {
    Array.from(document.querySelectorAll('audio')).map(audio => audio.pause());
    const audio = document.querySelector(`#audio${channel}`);
    audio.currentTime = begin;
    audio.dataset.endTime = end;
    audio.play();
    return false;
}

function onclick_(evt) {
    const img = evt.target;
    const dim = img.getBoundingClientRect();
    const t = (evt.clientX - dim.left) / dim.width;
    const audio = img.previousSibling;
    audio.currentTime = t * audio.duration;
    audio.play();
}

function subtitle(segments, time, channel) {
    return (segments.find(([rh, c, b, e]) => c == channel && b <= time && time <= e) || ['', channel, null, null])[0];
}

function onpause_(evt) {
    evt.target.dataset.endTime = null;
}

function ontimeupdate_(evt) {
    const time = evt.target.currentTime, endtime = evt.target.dataset.endTime;
    if(endtime && time > endtime)
        return evt.target.pause();
    document.querySelector('h5').innerText = time.toString();
    const [spanhyp0, spanref0, spanhyp1, spanref1] = document.querySelectorAll('span.subtitle');
    [spanhyp0.innerHTML, spanref0.innerHTML, spanhyp1.innerHTML, spanref1.innerHTML] = [subtitle(hyp_segments, time, 0), subtitle(ref_segments, time, 0), subtitle(hyp_segments, time, 1), subtitle(ref_segments, time, 1)];
}

const make_segment = td => [td.querySelector('template').innerHTML, td.dataset.channel, td.dataset.begin, td.dataset.end];
const hyp_segments = Array.from(document.querySelectorAll('.hyp')).map(make_segment),
      ref_segments = Array.from(document.querySelectorAll('.ref')).map(make_segment);
</script></body></html>''')
    print(html_path)

def main():
    filename = 'top_two_explanation_indices_genPredictions_RUN2.txt'  # output
    outFile = open(filename, 'w')

    # STEP 1: work on all the genuine TP files and find the top two components that are maximally activated!
    #fname = 'allGenIndexList_TP.txt'
    #with open(fname) as f:
    #    genuine_top_indexes_correct = [int(line.strip().split(' ')[0]) for line in f]
    #print(genuine_top_indexes_correct)
    # Total count = 736. These are all genuine files sorted by score (highest positive), with threshold 0.5 (all files score greater than 0.5)
    #file_idx_list = genuine_top_indexes_correct[0:len(genuine_top_indexes_correct)]
    #savePath = 'explanations_v2.0/top10GenuineCorrect/'

    #file_idx_list = spoofed_top10_indexes_correct
    #savePath = 'explanations/top10SpoofedCorrect/'
    #file_idx_list = genuine_top10_indexes_incorrect
    #savePath = 'explanations/top10Genuine_Incorrect/'
    #file_idx_list = spoofed_top10_indexes_incorrect
    #savePath = 'explanations/top10Spoofed_Incorrect/'

    # STEP 2: we find that component 4 (900-1200 ms) appears most often among the 766 genuine TP audio files
    # that scored > 0.5. We have extracted the corresponding list of audio files; now we try to find
    # similarity by analysing some of these files.
    with open('topExplanation_list_Genuine_TP.txt') as f:
        id4_fileList = [int(line.strip().split(' ')[0]) for line in f]
    file_idx_list = id4_fileList  #[0:10]
    savePath = 'explanations_v2.0/truePositive_Gen_736files_id4/'

    num_feat = 2      # number of components per explanation, e.g. select the top components out of the given 10
    n_samples = 5000  # number of synthetic samples generated by SLIME
    top_lab = 2       # in a multi-class scenario, how many labels to explain: we can keep it fixed
    label_idx = 0     # which class to explain: 0 - genuine, 1 - spoofed

    makeDirectory(savePath)
    runs = 5          # how many runs of the same setup, to ensure predictions are correct
    sampling_rate = 16000
    hop_size = 160
    pow_spects_file = 'dev_spec.npz'
    mean_std_file = 'mean_std.npz'
    model_path = 'model_3sec_relu_0.5_run8'
    dataType = 'test'
    trainSize = 3
    init_type = 'xavier'
    fftSize = 256
    f = 129
    t = 300
    padding = 'SAME'
    targets = 2
    act = tf.nn.relu
    plot = False
    debug_prints = True

    # Returns a list where each element is a normalized power spectrogram per file.
    print('Loading spectrograms...')
    norm_pow_spects = audio.read_audio(pow_spects_file, mean_std_file)
    data = norm_pow_spects
    '''
    if debug_prints:
        print('Input spectrogram shape: (%d, %d)' % (norm_pow_spects[file_idx].shape))  # file_idx-1
    if plot:
        # figure 1
        print('Plotting spectrogram for file number %d' % (file_idx + 1))
        plt.figure(1)
        plt.subplot(1, 1, 1)
        disp.specshow(norm_pow_spects[file_idx].T, sr=sampling_rate, y_axis='linear', x_axis='time', hop_length=hop_size, cmap='coolwarm')
        plt.title('Input spectrogram')
        plt.colorbar()
        plt.show()
    '''

    print('Reset TF graph and load trained models and session..')
    tf.reset_default_graph()
    input_data = tf.placeholder(tf.float32, [None, t, f, 1])
    # placeholders for dropout probability
    keep_prob1 = tf.placeholder(tf.float32)
    keep_prob2 = tf.placeholder(tf.float32)
    keep_prob3 = tf.placeholder(tf.float32)

    # get the model architecture that was used during training
    featureExtractor, model_prediction, network_weights, activations, biases = nn_architecture.cnnModel3(
        dataType, trainSize, input_data, act, init_type, targets, fftSize, padding, keep_prob1, keep_prob2, keep_prob3)
    modelScore = computeModelScore(model_prediction, apply_softmax=True)

    # load trained session and model parameters
    sess, saver = load_model(model_path)
    print('Model parameters loaded successfully!')

    for file_idx in file_idx_list:
        #print('Generating explanations for Run: ', run+1)
        print('Generating explanations for File index: ', file_idx + 1)
        explainer = lime_image.LimeImageExplainer(verbose=True)
        t1 = time.time()
        # generate the explanation
        explanation, seg = explainer.explain_instance(
            data[file_idx].reshape(t, f), prediction_fn, sess, modelScore, input_data,
            keep_prob1, keep_prob2, keep_prob3, hide_color=0, top_labels=top_lab, num_samples=n_samples)
        print("time taken for explanation generation: %f" % (time.time() - t1))

        # extract the information from the generated explanation
        # temp: masked spectrogram
        # mask: binary mask, 1: presence of the component, 0: absence of the component
        # fs: a tuple of two objects, one is a list of enabled components, the other is the prediction error
        # The weights assigned to each component are also printed; this is very useful in analysis, something not done in the SLIME paper.
        # At the moment, only component indices that positively influence a prediction are returned.
        temp, mask, fs = explanation.get_image_and_mask(label_idx, positive_only=True, hide_rest=True, num_features=num_feat)
        #print(temp)
        #print('mask: ', mask)
        #print(type(fs))
        print('*****************************')
        #outFile.write(str(fs[0][0]) + ' ' + str(fs[0][1]) + '\n')
        #print(fs[1])

        # plot the results
        plt.figure(2)
        plt.subplot(2, 1, 1)
        disp.specshow(normalise((data[file_idx].reshape(t, f))).T, y_axis='linear', x_axis='off', sr=sampling_rate, hop_length=hop_size, cmap='coolwarm')
        plt.title('Input Spectrogram')
        plt.subplot(2, 1, 2)
        disp.specshow(normalise(temp).T, y_axis='linear', x_axis='time', sr=sampling_rate, hop_length=hop_size, cmap='coolwarm')
        plt.title('Top %d explanations generated by LIME' % num_feat)
        # during final plots change the file extension to .pdf and keep dpi=300
        #filename = savePath + 'file_' + str(file_idx + 1) + '_run' + str(run + 1) + '.png'
        filename = savePath + 'file_' + str(file_idx + 1) + '.png'
        plt.savefig(filename)
        plt.close()

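# In this script audio.read_audio takes an .npz spectrogram archive plus a mean/std
# archive rather than a waveform. A hedged numpy sketch of the assumed normalization;
# the key names 'spec', 'mean', and 'std' are guesses, not confirmed by the source:
import numpy as np

def load_normalized_spectrograms_sketch(pow_spects_file, mean_std_file):
    spects = np.load(pow_spects_file, allow_pickle=True)
    stats = np.load(mean_std_file)
    mean, std = stats['mean'], stats['std']
    return [(s - mean) / (std + 1e-8) for s in spects['spec']]
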
def __getitem__(self, index):
    waveform_transform_debug = (
        lambda audio_path, sample_rate, signal: audio.write_audio(
            os.path.join(self.waveform_transform_debug_dir, os.path.basename(audio_path) + '.wav'),
            signal,
            sample_rate)) if self.waveform_transform_debug_dir else None

    audio_path = self.audio_path[index]
    transcript = self.load_example(index)
    signal, sample_rate = audio.read_audio(
        audio_path,
        sample_rate=self.sample_rate,
        mono=self.mono,
        backend=self.audio_backend,
        duration=self.max_duration) if self.frontend is None or self.frontend.read_audio else (audio_path, self.sample_rate)

    #TODO: support forced mono even if transcript is given
    #TODO: subsample speaker labels according to features
    some_segments_have_not_begin_end = any(
        t['begin'] == self.time_missing and t['end'] == self.time_missing for t in transcript)
    some_segments_have_ref = any(bool(t['ref']) for t in transcript)
    replace_transcript = self.join_transcript or (not transcript) or (some_segments_have_not_begin_end and some_segments_have_ref)

    if replace_transcript:
        assert len(signal) == 1, 'only mono supported for now'
        # replacing refs, normalizing only with the default preprocessor
        ref_full = [self.labels[0].normalize_text(t['ref']) for t in transcript]
        speaker = torch.cat([
            torch.full((len(ref) + 1, ), t['speaker'], dtype=torch.int64).scatter_(0, torch.tensor(len(ref)), self.speaker_missing)
            for t, ref in zip(transcript, ref_full)
        ])[:-1]
        transcript = [
            dict(
                audio_path=audio_path,
                ref=' '.join(ref_full),
                example_id=self.example_id(dict(audio_path=audio_path)),
                channel=0,
                begin_samples=0,
                end_samples=None)
        ]
        normalize_text = False
    else:
        transcript = [
            dict(
                audio_path=audio_path,
                ref=t['ref'],
                example_id=self.example_id(t),
                channel=channel,
                begin_samples=int(t['begin'] * sample_rate) if t['begin'] != self.time_missing else 0,
                end_samples=1 + int(t['end'] * sample_rate) if t['end'] != self.time_missing else signal.shape[1],
                speaker=t['speaker'])
            for t in sorted(transcript, key=transcripts.sort_key)
            for channel in ([t['channel']] if t['channel'] != self.channel_missing else range(len(signal)))
        ]
        speaker = torch.LongTensor([t.pop('speaker') for t in transcript]).unsqueeze(-1)
        normalize_text = True

    features = [
        self.frontend(segment.unsqueeze(0), waveform_transform_debug=waveform_transform_debug)
        if self.frontend is not None else segment.unsqueeze(0)
        for t in transcript
        for segment in [signal[t.pop('channel'), t.pop('begin_samples'):t.pop('end_samples')]]
    ]
    targets = [[labels.encode(t['ref'], normalize=normalize_text)[1] for t in transcript] for labels in self.labels]

    # not batch mode
    if not self.segmented:
        transcript, speaker, features = transcript[0], speaker[0], features[0][0]
        targets = [target[0] for target in targets]
    return [transcript, speaker, features] + targets