def _compute_mmi_loss_exact_non_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Readable MMI loss variant that runs k2.intersect_dense twice, once for
    the numerator graphs and once for the denominator graphs.

    See :func:`_compute_mmi_loss_exact_optimized` for the meaning of the
    arguments.

    Note:
      It uses less memory at the cost of speed. It is slower.

    Returns:
      The three values produced by get_tot_objf_and_num_frames, in order:
      (tot_score, tot_frames, all_frames).
    '''
    # NOTE(review): compile() is called without a bigram LM `P` here, unlike
    # the other MMI loss helpers -- confirm this matches the compiler API.
    numerator_graphs, denominator_graphs = graph_compiler.compile(
        texts, replicate_den=True)
    emissions = k2.DenseFsaVec(nnet_output, supervision_segments)

    # Intersect both graph sets first, then reduce both lattices to scores.
    numerator_lats, denominator_lats = (
        k2.intersect_dense(graphs, emissions, output_beam=10.0)
        for graphs in (numerator_graphs, denominator_graphs))
    numerator_scores, denominator_scores = (
        lats.get_tot_scores(log_semiring=True, use_double_scores=True)
        for lats in (numerator_lats, denominator_lats))

    mmi_scores = numerator_scores - den_scale * denominator_scores
    frames_per_segment = supervision_segments[:, 2]
    return get_tot_objf_and_num_frames(mmi_scores, frames_per_segment)
def _compute_mmi_loss_pruned(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    MMI loss where the denominator lattice is built with
    k2.intersect_dense_pruned (hence `pruned`); the numerator still uses
    the unpruned k2.intersect_dense.

    See :func:`_compute_mmi_loss_exact_optimized` for the meaning of the
    arguments.

    Note:
      It uses the least amount of memory, but the loss is not exact due
      to pruning.

    Returns:
      The three values produced by get_tot_objf_and_num_frames, in order:
      (tot_score, tot_frames, all_frames).
    '''
    numerator_graphs, denominator_graphs = graph_compiler.compile(
        texts, P, replicate_den=False)
    emissions = k2.DenseFsaVec(nnet_output, supervision_segments)

    numerator_lats = k2.intersect_dense(numerator_graphs,
                                        emissions,
                                        output_beam=10.0)

    # The pruning parameters below are not tuned. You may want to tune them.
    pruning_options = dict(search_beam=20.0,
                           output_beam=7.0,
                           min_active_states=30,
                           max_active_states=10000)
    denominator_lats = k2.intersect_dense_pruned(denominator_graphs,
                                                 emissions,
                                                 **pruning_options)

    score_options = dict(log_semiring=True, use_double_scores=True)
    numerator_scores = numerator_lats.get_tot_scores(**score_options)
    denominator_scores = denominator_lats.get_tot_scores(**score_options)

    mmi_scores = numerator_scores - den_scale * denominator_scores
    frames_per_segment = supervision_segments[:, 2]
    return get_tot_objf_and_num_frames(mmi_scores, frames_per_segment)
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    MMI loss computed with a single, unpruned call to k2.intersect_dense.

    `exact` in the name means the intersection is not pruned; `optimized`
    means numerator and denominator graphs are batched together so that
    k2.intersect_dense is invoked only once.

    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C]
      texts:
        The transcript. Each element consists of space(s) separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs
      P:
        Represents a bigram Fsa.
      den_scale:
        The scale applied to the denominator tot_scores.

    Returns:
      The three values produced by get_tot_objf_and_num_frames, in order:
      (tot_score, tot_frames, all_frames).
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device
    batch_size = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == batch_size
    # With replicate_den=False a single shared denominator graph is expected.
    assert den_graphs.shape[0] == 1

    # aux_labels is a k2.RaggedInt on num_graphs but a torch.Tensor on
    # den_graphs; convert the latter so that both can be passed to k2.cat().
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # Concatenating numerator and denominator graphs lets us get away with
    # a single call to k2.intersect_dense. After concatenation the
    # denominator graph sits at index `batch_size`.
    combined_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted, so the
    # graphs are interleaved as [num_0, den, num_1, den, ...].
    num_positions = torch.arange(batch_size, dtype=torch.int32)
    den_positions = torch.full((batch_size,), batch_size, dtype=torch.int32)
    # [0, batch_size, 1, batch_size, 2, batch_size, ...]
    interleaved_positions = torch.stack([num_positions, den_positions],
                                        dim=1).reshape(-1).to(device)
    interleaved_graphs = k2.index(combined_graphs, interleaved_positions)

    # Graphs 2*i and 2*i+1 both map to utterance i: [0, 0, 1, 1, 2, 2, ...]
    utterance_ids = torch.arange(batch_size, dtype=torch.int32)
    a_to_b_map = torch.stack([utterance_ids, utterance_ids],
                             dim=1).reshape(-1).to(device)

    lattices = k2.intersect_dense(interleaved_graphs,
                                  dense_fsa_vec,
                                  output_beam=10.0,
                                  a_to_b_map=a_to_b_map)
    interleaved_scores = lattices.get_tot_scores(log_semiring=True,
                                                 use_double_scores=True)

    # Even entries belong to numerators, odd entries to denominators.
    numerator_scores = interleaved_scores[::2]
    denominator_scores = interleaved_scores[1::2]

    mmi_scores = numerator_scores - den_scale * denominator_scores
    frames_per_segment = supervision_segments[:, 2]
    return get_tot_objf_and_num_frames(mmi_scores, frames_per_segment)
class ASR:
    """
    This class is a high-level wrapper for K2 acoustic models that simplifies
    inference: reading models, computing posteriors, decoding, alignments, etc.

    Currently it will only work with the Conformer model with a very specific
    HMM topology. It could be the basis for a more generic entry point to
    Snow(Ice?)fall.
    """

    def __init__(
            self,
            lang_dir: Pathlike,
            scripted_model_path: Optional[Pathlike] = None,
            model_dir: Optional[Pathlike] = None,
            average_epochs: Sequence[int] = (7, 8, 9),
            device: torch.device = 'cpu',
            sampling_rate: int = 16000,
    ):
        """
        :param lang_dir: directory passed to ``Lexicon``; ``HLG.pt`` is also
            read from it.
        :param scripted_model_path: path to a serialized TorchScript module;
            used only when ``model_dir`` is None.
        :param model_dir: directory with ``epoch-{n}.pt`` checkpoints that
            are averaged into a Conformer model; takes precedence over
            ``scripted_model_path``.
        :param average_epochs: epoch numbers of the checkpoints to average.
        :param device: a device name or a ``torch.device`` on which the
            model, the graph compiler and HLG are placed.
        :param sampling_rate: expected sampling rate of the input cuts.
        :raises ValueError: when neither ``model_dir`` nor
            ``scripted_model_path`` is given.
        """
        # Local import so that this class does not rely on a module-level one.
        from pathlib import Path

        # torch.device() accepts both strings and torch.device objects, so
        # self.device is always set (previously only for str inputs).
        self.device = torch.device(device)
        self.sampling_rate = sampling_rate
        self.extractor = Fbank(FbankConfig(num_mel_bins=80))
        self.lexicon = Lexicon(lang_dir)
        phone_ids = self.lexicon.phone_symbols()
        self.P = create_bigram_phone_lm(phone_ids)

        if model_dir is not None:
            # Read model from regular checkpoints, assume it's a Conformer
            checkpoint_dir = Path(model_dir)  # allow plain-string paths
            self.model = Conformer(num_features=80,
                                   num_classes=len(phone_ids) + 1,
                                   num_decoder_layers=0)
            self.P.scores = torch.zeros_like(self.P.scores)
            self.model.P_scores = torch.nn.Parameter(self.P.scores.clone(),
                                                     requires_grad=False)
            average_checkpoint(filenames=[
                checkpoint_dir / f'epoch-{n}.pt' for n in average_epochs
            ],
                               model=self.model)
        elif scripted_model_path is not None:
            # Read model from a serialized TorchScript module, no assumptions needed
            self.model = torch.jit.load(scripted_model_path)
        else:
            raise ValueError(
                "One of scripted_model_path or model_dir needs to be provided."
            )

        # This wrapper is inference-only: keep the model on the same device
        # as the inputs/graphs and disable training-time behaviour.
        self.model.to(self.device)
        self.model.eval()

        # Freeze the params by default.
        for p in self.model.parameters():
            p.requires_grad_(False)

        self.compiler = MmiTrainingGraphCompiler(lexicon=self.lexicon,
                                                 device=self.device)
        # Load on CPU first so that a graph saved from another device can
        # still be read; it is moved to the target device right after.
        self.HLG = k2.Fsa.from_dict(
            torch.load(Path(lang_dir) / 'HLG.pt',
                       map_location='cpu')).to(self.device)

    @staticmethod
    def _as_cut_set(cuts: Union[AnyCut, CutSet]) -> CutSet:
        """Wrap a single cut into a one-element CutSet; pass CutSets through."""
        if isinstance(cuts, (Cut, MixedCut)):
            return CutSet.from_cuts([cuts])
        return cuts

    def compute_features(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Extract features for a cut or a CutSet with the on-the-fly extractor.

        :return: the feature tensor produced by ``OnTheFlyFeatures``,
            laid out as ``(batch, seq_len, n_feats)``.
        """
        cuts = self._as_cut_set(cuts)
        assert cuts[
            0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}'
        otf = OnTheFlyFeatures(self.extractor)
        # feats: (batch, seq_len, n_feats)
        feats, _ = otf(cuts)
        return feats

    def compute_posteriors(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Run the forward pass of the acoustic model and return a tensor
        representing a batch of phone posteriorgrams, laid out as
        ``(batch, ~seq_len / 4, n_phones)`` and located on ``self.device``.
        """
        # feats: (batch, seq_len, n_feats) -> (batch, n_feats, seq_len)
        feats = self.compute_features(cuts).permute(0, 2, 1).to(self.device)
        # Compute AM posteriors
        # posteriors: (batch, n_phones, ~seq_len / 4)
        posteriors, _, _ = self.model(feats)
        # returns: (batch, ~seq_len / 4, n_phones)
        return posteriors.permute(0, 2, 1)

    def decode(
            self,
            cuts: Union[AnyCut, CutSet]) -> List[Tuple[List[str], List[str]]]:
        """
        Perform decoding with an n-gram language model (HLG graph).
        Doesn't support rescoring at this time.

        :return: one ``(reference_words, hypothesis_words)`` pair per
            supervision text returned by ``encode_supervisions``.
        """
        cuts = self._as_cut_set(cuts)
        # Hacky way to get batch quickly... we may need to improve on this.
        batch = K2SpeechRecognitionDataset(cuts,
                                           input_strategy=OnTheFlyFeatures(
                                               self.extractor),
                                           check_inputs=False)[list(cuts.ids)]
        features = batch['inputs'].permute(0, 2, 1).to(
            self.device)  # (B, T, F) -> (B, F, T)
        supervision_segments, texts = encode_supervisions(
            batch['supervisions'])

        # Forward pass through the acoustic model
        posteriors, _, _ = self.model(features)
        posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

        # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
        dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

        # The actual decoding starts here:
        # First, we intersect the HLG and the PPG
        # with default pruning/beam search params from snowfall
        # The result is a batch of graphs (lattices)
        lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8,
                                             30, 10000)
        # ... then we find the shortest paths in the lattices ...
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        # ... and convert them to words with a convenience wrapper from snowfall
        hyps = get_texts(best_paths, torch.arange(len(texts)))

        # Here we read out the words from the best path graphs
        return [(ref_text.split(' '),
                 [self.lexicon.words.get(word_id) for word_id in hyp])
                for ref_text, hyp in zip(texts, hyps)]

    def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Perform forced alignment and return a tensor that represents a batch
        of frame-level alignments:

        >>> alignments = torch.tensor([
        ...     [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
        ...     [...],
        ...     ...
        ... ])

        :return: an int32 tensor with shape ``(batch_size, num_frames)``.
        """
        cuts = self._as_cut_set(cuts)
        cuts = cuts.map_supervisions(self.normalize_text)
        # (batch, seq_len, num_feats) -> (batch, num_feats, seq_len)
        feats = self.compute_features(cuts).permute(0, 2, 1).to(self.device)
        texts = [' '.join(s.text for s in cut.supervisions) for cut in cuts]

        # Compute AM posteriors
        # (batch, seq_len ~/ 4, num_phones)
        posteriors, _, _ = self.model(feats)

        # Note: we are using "dummy" supervisions so that the aligner also considers
        # the padding area. We can adjust that behaviour if needed by passing actual
        # supervision segments, but then we will have a ragged tensor (will need to
        # pad the alignments themselves).
        sups = self.dummy_supervisions(feats)
        posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups)

        # Intersection with ground truth transcript graphs
        # (the denominator graph is not needed for alignment)
        num, _ = self.compiler.compile(texts, self.P)
        alignment = k2.intersect_dense(num, posteriors_fsa, output_beam=10.0)
        best_path = k2.shortest_path(alignment, use_double_scores=True)

        # Retrieve sequences of phone IDs per frame; the last label of each
        # best path is dropped.
        # (batch, seq_len ~/ 4) -- dtype int32 (num phone labels)
        frame_labels = torch.stack(
            [best_path[i].labels[:-1] for i in range(best_path.shape[0])])
        return frame_labels

    @staticmethod
    def _speech_runs(alignment: Sequence[int],
                     eps: int = 0,
                     sil: int = 1) -> List[Tuple[int, int, int]]:
        """
        Collapse the speech region of a frame-level alignment into runs.

        The speech region spans from the first to the last frame (inclusive)
        whose label is neither ``eps`` nor ``sil``. Inside that region an
        ``eps`` frame is treated as a continuation of the previous symbol.

        :return: a list of ``(start_frame, num_frames, symbol_id)`` tuples in
            time order; empty when the alignment has no speech frame.
        """
        non_speech = {eps, sil}
        speech_frames = [
            idx for idx, s in enumerate(alignment) if s not in non_speech
        ]
        if not speech_frames:
            return []
        first, last = speech_frames[0], speech_frames[-1]

        runs = []
        run_start = first
        run_symbol = alignment[first]
        # `last` is a speech frame itself, so it belongs to the speech region.
        for idx in range(first + 1, last + 1):
            s = alignment[idx]
            symbol = run_symbol if s == eps else s
            if symbol != run_symbol:
                runs.append((run_start, idx - run_start, run_symbol))
                run_start, run_symbol = idx, symbol
        runs.append((run_start, last + 1 - run_start, run_symbol))
        return runs

    def align_ctm(
        self, cuts: Union[CutSet, AnyCut]
    ) -> List[Optional[List[AlignmentItem]]]:
        """
        Perform forced alignment and parse the phones into a CTM-like format:

            >>> [[0.0, 0.12, 'SIL'], [0.12, 0.2, 'AH0'], ...]

        :return: one entry per cut: a list of ``AlignmentItem`` (leading
            silence, phone runs, trailing silence), or ``None`` when the end
            of the aligned speech exceeds the cut duration. A cut whose
            alignment contains no speech frame yields a single silence item
            spanning the whole cut.
        """
        # TODO: I am not sure that this method is extracting the alignment 100% correctly:
        #       need to revise...
        # TODO: when K2/Snowfall has a standard way of indicating what is silence,
        #       or we update the model, update the constants below.
        EPS = 0
        SIL = 1
        FRAME_SHIFT = 0.04  # 0.01 * 4 subsampling

        def to_s(n: int) -> float:
            return round(n * FRAME_SHIFT, ndigits=3)

        cuts = self._as_cut_set(cuts)

        # Uppercase and remove punctuation
        cuts = cuts.map_supervisions(self.normalize_text)
        alignments = self.align(cuts).tolist()
        silence = self.lexicon.phones[SIL]

        ctm_alis = []
        for cut, alignment in zip(cuts, alignments):
            # Every SIL and <eps> before the first phone, and after the last
            # phone, is assumed to represent silence. Inside the speech
            # region, since the K2 model uses 2-state HMM topology that
            # allows blank (<eps>) to follow a phone symbol, <eps> is treated
            # as continuation of the "previous" phone.
            # TODO: I think this implementation is wrong in that it merges repeating phones...
            #       Will fix.
            # TODO: I think it could be simplified by using some smart semi-ring and FSA operations...
            runs = self._speech_runs(alignment, eps=EPS, sil=SIL)
            if not runs:
                # No phone was aligned at all: the whole cut is silence.
                ctm_alis.append([
                    AlignmentItem(start=0.0,
                                  duration=round(cut.duration, ndigits=8),
                                  symbol=silence)
                ])
                continue

            first_speech_idx = runs[0][0]
            # One past the last speech frame, so that the last phone frame
            # is attributed to speech and not to the trailing silence.
            speech_end_timestamp = to_s(runs[-1][0] + runs[-1][1])
            if speech_end_timestamp > cut.duration:
                logging.warning(
                    f"speech_end_timestamp > cut.duration. Skipping cut {cut.id}"
                )
                ctm_alis.append(None)
                continue

            ctm_ali = [
                AlignmentItem(start=0.0,
                              duration=to_s(first_speech_idx),
                              symbol=silence)
            ]
            ctm_ali.extend(
                AlignmentItem(start=to_s(start),
                              duration=to_s(num_frames),
                              symbol=self.lexicon.phones[symbol])
                for start, num_frames, symbol in runs)
            ctm_ali.append(
                AlignmentItem(start=speech_end_timestamp,
                              duration=round(cut.duration -
                                             speech_end_timestamp,
                                             ndigits=8),
                              symbol=silence))
            ctm_alis.append(ctm_ali)
        return ctm_alis

    def plot_alignments(self, cut: AnyCut):
        """Plot the features of ``cut`` above its one-hot frame alignment."""
        import matplotlib.pyplot as plt
        feats = self.compute_features(cut)
        # Bring the labels back to CPU for plotting.
        phone_ids = self.align(cut).cpu()
        fig, axes = plt.subplots(2,
                                 squeeze=True,
                                 sharey=True,
                                 figsize=(10, 14))
        axes[0].imshow(np.flipud(feats[0].T))
        # Each label is repeated 4 times to undo the 4x subsampling.
        axes[1].imshow(
            torch.nn.functional.one_hot(
                phone_ids.repeat_interleave(4).to(torch.int64)).T)
        return fig, axes

    def plot_posteriors(self, cut: AnyCut):
        """Plot the features of ``cut`` above its phone posteriors."""
        import matplotlib.pyplot as plt
        feats = self.compute_features(cut)
        # Bring the posteriors back to CPU for plotting.
        posteriors = self.compute_posteriors(cut).cpu()
        fig, axes = plt.subplots(2,
                                 squeeze=True,
                                 sharey=True,
                                 figsize=(10, 14))
        axes[0].imshow(np.flipud(feats[0].T))
        # NOTE(review): posteriors[0] is (frames, phones), so this repeats
        # the phone axis, whereas plot_alignments repeats the time axis --
        # confirm which layout is intended.
        axes[1].imshow(posteriors[0].exp().repeat_interleave(4, 1))
        return fig, axes

    @staticmethod
    def dummy_supervisions(feats):
        """
        Build one supervision segment per batch item from the shape of
        ``feats`` only (batch on dim 0, frames on dim 2).

        :return: an int32 tensor of ``[index, start, length]`` rows, with
            negative values clamped to 0.
        """

        def size_after_conv(size, num_layers=2):
            for _ in range(num_layers):
                size = (size - 1) // 2
            return size

        return torch.tensor([[
            i,
            size_after_conv(2, num_layers=2),
            size_after_conv(feats.shape[2] - 2, num_layers=2)
        ] for i in range(feats.size(0))],
                            dtype=torch.int32).clamp(min=0)

    @staticmethod
    def normalize_text(supervision):
        """Return a copy of ``supervision`` with uppercased, punctuation-free text."""
        text = re.sub(r'[^\w\s]', '', supervision.text.upper())
        return fastcopy(supervision, text=text)
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """
    Compute the MMI objective for one batch and, when training, take one
    optimizer step.

    Args:
      batch:
        A dict with a 'features' tensor ([N, T, C]) and a 'supervisions'
        dict providing 'sequence_idx', 'start_frame', 'num_frames' and 'text'.
      model:
        The acoustic model, optionally wrapped in DDP.
      P:
        The bigram Fsa handed to the graph compiler.
      device:
        The device on which the features and graphs are placed.
      graph_compiler:
        Used to build the numerator and denominator graphs.
      is_training:
        When True, gradients are computed and the optimizer is stepped;
        otherwise the forward pass runs under torch.no_grad().
      tb_writer:
        Optional tensorboard writer for gradient/parameter diagnostics.
      global_batch_idx_train:
        Index of the training batch; diagnostics are written every 200
        batches, and only when both this and tb_writer are provided.
      optimizer:
        Required when is_training is True.

    Returns:
      A tuple of Python numbers: the negated total objective, followed by
      the tot_frames and all_frames values reported by
      get_tot_objf_and_num_frames.

    Raises:
      ValueError: if is_training is True but no optimizer is given.
    """
    if is_training and optimizer is None:
        # Fail early instead of with an AttributeError after the forward pass.
        raise ValueError('optimizer must be provided when is_training is True')

    feature = batch['features']
    supervisions = batch['supervisions']
    subsampling_factor = model.module.subsampling_factor if isinstance(
        model, DDP) else model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    # Order the segments (and their texts) by decreasing number of frames.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    # NOTE(review): den_scale is not defined inside this function;
    # presumably it is a module-level constant -- confirm.
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        # Diagnostics need both a writer and a batch index; checking the
        # index for None here avoids `None % 200` when only tb_writer is set.
        log_diagnostics = (tb_writer is not None
                           and global_batch_idx_train is not None
                           and global_batch_idx_train % 200 == 0)

        def maybe_log_gradients(tag: str):
            if log_diagnostics:
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model,
                                                             norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if log_diagnostics:
            # Once in a time we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = -tot_score.detach().cpu().item(), tot_frames.cpu().item(
    ), all_frames.cpu().item()
    return ans
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """
    Compute the MMI objective for one batch and, when is_training is True,
    back-propagate, clip the gradients and step the optimizer.

    Returns:
      A tuple of Python numbers: the negated total objective, followed by
      the tot_frames and all_frames values reported by
      get_tot_objf_and_num_frames.
    """

    def run(fn, *args):
        # Outside of training nothing may be recorded for autograd.
        if is_training:
            return fn(*args)
        with torch.no_grad():
            return fn(*args)

    feature = batch['features']
    supervisions = batch['supervisions']

    # Rows are (sequence index, start frame, number of frames), with the
    # frame quantities expressed in subsampled frames.
    segment_columns = (
        supervisions['sequence_idx'],
        torch.floor_divide(supervisions['start_frame'],
                           model.subsampling_factor),
        torch.floor_divide(supervisions['num_frames'],
                           model.subsampling_factor),
    )
    supervision_segments = torch.stack(segment_columns, 1).to(torch.int32)

    # Order the segments (and their texts) by decreasing number of frames.
    order = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[order]
    all_texts = supervisions['text']
    texts = [all_texts[idx] for idx in order]

    assert feature.ndim == 3
    # [N, T, C] on entry -> [N, C, T] for the model
    model_input = feature.to(device).permute(0, 2, 1)
    # the model emits [N, C, T]; bring it back to [N, T, C]
    nnet_output = run(model, model_input).permute(0, 2, 1)

    num, den = run(graph_compiler.compile, texts, P)
    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den_lats = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    score_options = dict(log_semiring=True, use_double_scores=True)
    num_tot_scores = num_lats.get_tot_scores(**score_options)
    den_tot_scores = den_lats.get_tot_scores(**score_options)
    # NOTE(review): den_scale is not defined inside this function;
    # presumably it is a module-level constant -- confirm.
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])

    if is_training:
        optimizer.zero_grad()
        loss = -tot_score
        loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    return (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
            all_frames.cpu().item())