def compile(self, texts: Iterable[str],
            P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.
    Returns:
      A tuple (num_graph, den_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.

        - `den_graph` is the denominator graph. It is an FsaVec with the
          same shape as `num_graph`.
    '''
    assert P.is_cpu()

    ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
    ctc_topo_P = k2.connect(ctc_topo_P)

    num_graphs = k2.create_fsa_vec(
        [self.compile_one_and_cache(text) for text in texts])

    num = k2.compose(ctc_topo_P, num_graphs)
    num = k2.connect(num)
    num = k2.arc_sort(num)

    den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

    return num, den
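# A minimal sketch (not part of this class) of how the (num, den) pair returned
# above is typically consumed for an MMI-style objective. `dense_fsa_vec`, the
# beam value and `den_scale` are assumptions for illustration; `k2.intersect_dense`
# and `Fsa.get_tot_scores` are standard k2 API, but check their exact signatures
# against the k2 version used by this codebase.
def _example_mmi_objective(num: k2.Fsa, den: k2.Fsa,
                           dense_fsa_vec: k2.DenseFsaVec,
                           den_scale: float = 1.0):
    num_lats = k2.intersect_dense(num, dense_fsa_vec, output_beam=10.0)
    den_lats = k2.intersect_dense(den, dense_fsa_vec, output_beam=10.0)
    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    # Negated MMI objective: make the numerator score large relative to the
    # denominator score.
    return -(num_tot_scores - den_scale * den_tot_scores).sum()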
def __init__(self,
             lexicon: Lexicon,
             P: k2.Fsa,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      lexicon:
        It contains `L_inv` (whose labels are words and whose aux_labels
        are phones), the phone symbol table and the word symbol table.
      P:
        A phone bigram LM if the pronunciations in the lexicon are in
        phones; a word piece bigram if the pronunciations in the lexicon
        are word pieces.
      device:
        The device on which all FSAs are stored.
      oov:
        Out of vocabulary word.
    '''
    self.lexicon = lexicon
    L_inv = self.lexicon.L_inv.to(device)
    P = P.to(device)

    if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
        # Arc-sort L_inv only if it is not arc-sorted yet.
        L_inv = k2.arc_sort(L_inv)

    assert L_inv.requires_grad is False

    assert oov in self.lexicon.words

    self.L_inv = L_inv
    self.oov_id = self.lexicon.words[oov]
    self.oov = oov
    self.device = device

    phone_symbols = get_phone_symbols(self.lexicon.phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
    assert ctc_topo.requires_grad is False

    ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()

    self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['features']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)
        # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get an empty `target_graph`,
        # thus `tot_scores` will be `inf`. We definitely need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0,
                                             30, 300)
        best_paths = k2.shortest_path(lattices, use_float_scores=True)
        best_paths = best_paths.to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i in range(len(texts)):
            hyp_words = [
                symbols.get(x) for x in best_paths[i].aux_labels if x > 0
            ]
            results.append((texts[i].split(' '), hyp_words))

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))
    return results
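# A runnable toy illustration (synthetic values, not from any recipe) of how
# `supervision_segments` is assembled in `decode` above: one int32 row per
# supervision with (sequence_idx, start_frame, num_frames), where the frame
# quantities are divided by the model's subsampling factor.
def _example_supervision_segments():
    import torch

    supervisions = {
        'sequence_idx': torch.tensor([0, 1]),
        'start_frame': torch.tensor([0, 8]),
        'num_frames': torch.tensor([100, 92]),
    }
    subsampling_factor = 4
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    # tensor([[ 0,  0, 25],
    #         [ 1,  2, 23]], dtype=torch.int32)
    return supervision_segments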
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa
    ``G``. Involves arc sorting, intersection, determinization, removal of
    disambiguation symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.
    :return: An ``Fsa`` representing the composed decoding graph ``LG``.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0

    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
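# A runnable toy illustration (synthetic IDs) of the disambiguation-symbol
# removal above: every label at or beyond the first disambiguation ID is
# mapped to 0 (epsilon), exactly as done on `LG.labels` / `LG.aux_labels`.
def _example_remove_disambig_symbols():
    import torch

    labels = torch.tensor([3, 7, 90, 5, 91, 2], dtype=torch.int32)
    first_disambig_id = 90  # assumed ID of the first '#k' symbol in this toy
    labels[labels >= first_disambig_id] = 0
    # tensor([3, 7, 0, 5, 0, 2], dtype=torch.int32)
    return labels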
def format_output(self, num_frames: List[int]) -> Fsa:
    """
    Generate the lattice Fsa from the decoding results obtained so far.

    Note:
      The attributes of the generated lattice are the union of the
      attributes of all the decoding graphs. For example, if `self`
      contains three individual streams and each stream has its own
      decoding graph, where graph[0] has attributes attr1, attr2;
      graph[1] has attributes attr1, attr3; graph[2] has attributes
      attr3, attr4; then the generated lattice has attributes attr1,
      attr2, attr3, attr4.

    Args:
      num_frames:
        A list containing the number of frames we want to gather for each
        stream (note: it should not exceed the number of frames already
        received for the corresponding stream). It MUST satisfy
        `len(num_frames) == self.num_streams`.

    Returns:
      Return the lattice Fsa with all the attributes propagated.
      The returned Fsa has 3 axes with `fsa.dim0 == self.num_streams`.
    """
    assert len(num_frames) == self.num_streams

    ragged_arcs, out_map = self.streams.format_output(num_frames)
    fsa = Fsa(ragged_arcs)

    # propagate attributes
    tensor_attr_info = dict()
    # gather the attribute info of all the decoding graphs
    for i in range(self.num_streams):
        src = self.src_streams[i].fsa
        for name, value in src.named_tensor_attr(include_scores=False):
            if name not in tensor_attr_info:
                filler = 0
                if isinstance(value, Tensor):
                    filler = float(src.get_filler(name))
                    dtype = value.dtype
                    tensor_type = "Tensor"
                else:
                    assert isinstance(value, k2.RaggedTensor)
                    # Only ragged attributes of integer type are supported now
                    assert value.dtype == torch.int32
                    assert value.num_axes == 2
                    dtype = torch.int32
                    tensor_type = "RaggedTensor"
                tensor_attr_info[name] = {
                    "filler": filler,
                    "dtype": dtype,
                    "tensor_type": tensor_type,
                }
    # combine the attributes propagated from different decoding graphs
    for name, info in tensor_attr_info.items():
        values = list()
        start = 0
        for i in range(self.num_streams):
            src = self.src_streams[i].fsa
            device = self.device
            num_arcs = fsa[i].num_arcs
            arc_map = out_map[start:start + num_arcs]
            start = start + num_arcs
            if hasattr(src, name):
                value = getattr(src, name)
                if info["tensor_type"] == "Tensor":
                    assert isinstance(value, Tensor)
                    new_value = index_select(value,
                                             arc_map,
                                             default_value=info["filler"])
                else:
                    assert isinstance(value, RaggedTensor)
                    # Only ragged attributes of integer type are supported now
                    assert value.num_axes == 2
                    assert value.dtype == torch.int32
                    new_value, _ = value.index(arc_map,
                                               axis=0,
                                               need_value_indexes=False)
            else:
                if info["tensor_type"] == "Tensor":
                    # fill with the filler value
                    new_value = torch.tensor(
                        [info["filler"]] * num_arcs,
                        dtype=info["dtype"],
                        device=device,
                    )
                else:
                    # fill with an empty RaggedTensor
                    new_value = RaggedTensor(
                        torch.empty(
                            (num_arcs, 0),
                            dtype=info["dtype"],
                            device=device,
                        ))
            values.append(new_value)
        if info["tensor_type"] == "Tensor":
            new_value = torch.cat(values)
        else:
            new_value = k2.ragged.cat(values, axis=0)
        setattr(fsa, name, new_value)

    # set non_tensor_attrs
    for i in range(self.num_streams):
        src = self.src_streams[i].fsa
        for name, value in src.named_non_tensor_attr():
            setattr(fsa, name, value)

    return fsa
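# A minimal sketch, assuming a working k2 installation, of how a per-arc tensor
# attribute follows an arc map in the propagation loop above: entries of -1 in
# the arc map mean "no source arc" and receive the attribute's filler value.
def _example_attribute_propagation():
    import torch
    import k2

    src_attr = torch.tensor([10.0, 20.0, 30.0])
    arc_map = torch.tensor([2, -1, 0], dtype=torch.int32)  # -1: no source arc
    new_attr = k2.index_select(src_attr, arc_map, default_value=-1.0)
    # expected: tensor([30., -1., 10.])
    return new_attr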
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiMbrTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int):
    total_loss, total_mmi_loss, total_mbr_loss, total_frames, total_all_frames = 0., 0., 0., 0., 0.
    valid_average_loss = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.requires_grad is True

        curr_batch_mmi_loss, curr_batch_mbr_loss, curr_batch_frames, curr_batch_all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            optimizer=optimizer)

        total_mmi_loss += curr_batch_mmi_loss
        total_mbr_loss += curr_batch_mbr_loss
        curr_batch_loss = curr_batch_mmi_loss + curr_batch_mbr_loss
        total_loss += curr_batch_loss
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info('batch {}, epoch {}/{} '
                         'global average loss: {:.6f}, '
                         'global average mmi loss: {:.6f}, '
                         'global average mbr loss: {:.6f} over {} '
                         'frames ({:.1f}% kept), '
                         'current batch average loss: {:.6f}, '
                         'current batch average mmi loss: {:.6f}, '
                         'current batch average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept) '
                         'avg time waiting for batch {:.3f}s'.format(
                             batch_idx, current_epoch, num_epochs,
                             total_loss / total_frames,
                             total_mmi_loss / total_frames,
                             total_mbr_loss / total_frames, total_frames,
                             100.0 * total_frames / total_all_frames,
                             curr_batch_loss / (curr_batch_frames + 0.001),
                             curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                             curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                             curr_batch_frames,
                             100.0 * curr_batch_frames / curr_batch_all_frames,
                             time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_loss',
                                 total_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mmi_loss',
                                 total_mmi_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_average_mbr_loss',
                                 total_mbr_loss / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_loss',
                                 curr_batch_loss / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mmi_loss',
                curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)

            tb_writer.add_scalar(
                'train/current_batch_average_mbr_loss',
                curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 3000 == 0:
            total_valid_loss, total_valid_mmi_loss, total_valid_mbr_loss, \
                total_valid_frames, total_valid_all_frames = get_validation_loss(
                    dataloader=valid_dataloader,
                    model=model,
                    P=P,
                    device=device,
                    graph_compiler=graph_compiler)
            valid_average_loss = total_valid_loss / total_valid_frames
            model.train()
            logging.info('Validation average loss: {:.6f}, '
                         'Validation average mmi loss: {:.6f}, '
                         'Validation average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept)'.format(
                             total_valid_loss / total_valid_frames,
                             total_valid_mmi_loss / total_valid_frames,
                             total_valid_mbr_loss / total_valid_frames,
                             total_valid_frames, 100.0 *
                             total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_loss',
                                 total_valid_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mmi_loss',
                                 total_valid_mmi_loss / total_valid_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/global_valid_average_mbr_loss',
                                 total_valid_mbr_loss / total_valid_frames,
                                 global_batch_idx_train)

        prev_timestamp = datetime.now()
    return total_loss / total_frames, valid_average_loss, global_batch_idx_train
def __init__(self,
             L_inv: k2.Fsa,
             L_disambig: k2.Fsa,
             G: k2.Fsa,
             phones: k2.SymbolTable,
             words: k2.SymbolTable,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      L_inv:
        Its labels are words, while its aux_labels are phones.
      L_disambig:
        L with disambig symbols. Its labels are phones and aux_labels
        are words.
      G:
        The language model.
      phones:
        The phone symbol table.
      words:
        The word symbol table.
      device:
        The target device that all FSAs should be moved to.
      oov:
        Out of vocabulary word.
    '''
    L_inv = L_inv.to(device)
    G = G.to(device)

    if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
        # Arc-sort L_inv only if it is not arc-sorted yet.
        L_inv = k2.arc_sort(L_inv)

    if G.properties & k2.fsa_properties.ARC_SORTED == 0:
        # Arc-sort G only if it is not arc-sorted yet.
        G = k2.arc_sort(G)

    assert L_inv.requires_grad is False
    assert G.requires_grad is False

    assert oov in words

    L = L_inv.invert()
    L = k2.arc_sort(L)

    self.L_inv = L_inv
    self.L = L
    self.phones = phones
    self.words = words
    self.device = device
    self.oov_id = self.words[oov]

    phone_symbols = get_phone_symbols(phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = k2.arc_sort(
        build_ctc_topo(phone_symbols_with_blank).to(device))
    assert ctc_topo.requires_grad is False

    self.ctc_topo = ctc_topo
    self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert())

    lang_dir = Path('data/lang_nosp')
    if not (lang_dir / 'HLG_uni.pt').exists():
        logging.info("Composing (ctc_topo, L_disambig, G)")
        first_phone_disambig_id = find_first_disambig_symbol(phones)
        first_word_disambig_id = find_first_disambig_symbol(words)
        # decoding_graph is the result of composing (ctc_topo, L_disambig, G)
        decoding_graph = compile_HLG(
            L=L_disambig.to('cpu'),
            G=G.to('cpu'),
            H=ctc_topo.to('cpu'),
            labels_disambig_id_start=first_phone_disambig_id,
            aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(decoding_graph.as_dict(), lang_dir / 'HLG_uni.pt')
    else:
        logging.info("Loading pre-compiled HLG")
        decoding_graph = k2.Fsa.from_dict(
            torch.load(lang_dir / 'HLG_uni.pt'))

    assert hasattr(decoding_graph, 'phones')

    self.decoding_graph = decoding_graph.to(device)
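# A minimal sketch, assuming a working k2 installation, of the caching round
# trip used above: a compiled graph is stored via `as_dict()` and restored
# with `k2.Fsa.from_dict()`. The path and the tiny linear FSA are
# illustrations only, not the real HLG.
def _example_fsa_cache_roundtrip(path: str = '/tmp/toy_graph.pt'):
    import torch
    import k2

    graph = k2.arc_sort(k2.linear_fsa([1, 2, 3]))
    torch.save(graph.as_dict(), path)
    restored = k2.Fsa.from_dict(torch.load(path))
    assert restored.num_arcs == graph.num_arcs
    return restored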
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    ali_model: Optional[AcousticModel], P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, accum_grad: int,
                    den_scale: float, att_rate: float, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int, world_size: int,
                    scaler: GradScaler):
    """One epoch training and validation.

    Args:
      dataloader: Training dataloader
      valid_dataloader: Validation dataloader
      model: Acoustic model to be trained
      ali_model: Optional auxiliary acoustic model; may be None
      P: An FSA representing the bigram phone LM
      device: Training device, torch.device("cpu") or torch.device("cuda", device_id)
      graph_compiler: MMI training graph compiler
      optimizer: Training optimizer
      accum_grad: Number of gradient accumulation steps
      den_scale: Denominator scale in the MMI loss
      att_rate: Attention loss rate; the final loss is
        att_rate * att_loss + (1 - att_rate) * other_loss
      current_epoch: Current training epoch, for logging only
      tb_writer: Tensorboard SummaryWriter
      num_epochs: Total number of training epochs, for logging only
      global_batch_idx_train: Global training batch index before this epoch,
        for logging only
      world_size: Number of processes participating in distributed training
      scaler: GradScaler for automatic mixed precision (AMP) training

    Returns:
      A tuple of 3 scalars:
        (total_objf / total_frames, valid_average_objf, global_batch_idx_train)

      - `total_objf / total_frames` is the average training loss
      - `valid_average_objf` is the average validation loss
      - `global_batch_idx_train` is the global training batch index after
        this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if forward_count == 1 or accum_grad == 1:
            P.set_scores_stochastic_(model.module.P_scores)
            assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer,
            scaler=scaler)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)

                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                ali_model=ali_model,
                P=P,
                device=device,
                graph_compiler=graph_compiler,
                scaler=scaler)

            if world_size > 1:
                s = torch.tensor([
                    total_valid_objf, total_valid_frames,
                    total_valid_all_frames
                ]).to(device)
                dist.all_reduce(s, op=dist.ReduceOp.SUM)
                total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu().tolist()

            valid_average_objf = total_valid_objf / total_valid_frames

            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.module.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)

        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
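# A runnable toy sketch (no model involved) of the gradient-accumulation
# toggle used in the loop above: `is_update` becomes True only on every
# `accum_grad`-th forward pass, which is when `get_objf` is expected to step
# the optimizer.
def _example_accum_grad_schedule(num_batches: int = 7, accum_grad: int = 3):
    schedule = []
    forward_count = 0
    for _ in range(num_batches):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False
        schedule.append(is_update)
    # [False, False, True, False, False, True, False] for the defaults
    return schedule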
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use the whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
    Returns:
      An FsaVec containing the best decoding path of each sequence in the
      rescored lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The threshold used below is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted_rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                               ) -> Dict[str, k2.Fsa]:
    '''Use the whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict of FsaVec, whose keys are lm_scale values and whose values
      contain the best decoding path for each sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The threshold used below is arbitrary here
        # to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(
        k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    #
    # The following implements
    #  scores = (scores - lm_scores)/lm_scale + lm_scores
    #         = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores
        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
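# A runnable toy check (synthetic scores) of the identity used in the
# lm_scale loop above:
#   (scores - lm_scores) / lm_scale + lm_scores
#     == scores / lm_scale + lm_scores * (1 - 1 / lm_scale)
def _example_lm_scale_identity():
    import torch

    scores = torch.tensor([1.0, -2.0, 0.5])    # am + lm, as stored on the arcs
    lm_scores = torch.tensor([0.4, -1.0, 0.2])
    lm_scale = 1.2
    lhs = (scores - lm_scores) / lm_scale + lm_scores
    rhs = scores / lm_scale + lm_scores * (1 - 1 / lm_scale)
    assert torch.allclose(lhs, rhs)
    return lhs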
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: Optional[SummaryWriter], num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if isinstance(model, DDP):
            P.set_scores_stochastic_(model.module.P_scores)
        else:
            P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0 and dist.get_rank() == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 1000 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)

            # Synchronize the loss to the master node so that we display it correctly.
            # dist.reduce performs sum reduction by default.

            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            if dist.get_rank() == 0:
                logging.info(
                    'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(
                        valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)

                (model.module if isinstance(model, DDP) else
                 model).write_tensorboard_diagnostics(
                     tb_writer, global_step=global_batch_idx_train)

        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, global_batch_idx_train
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a topology fst ``H``, a lexicon fst ``L``
    and a language model fsa ``G``. Involves arc sorting, composition,
    determinization, removal of disambiguation symbols and epsilon removal.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        H:
            An ``Fsa`` that represents a specific topology used to convert
            the network outputs to a sequence of phones.
            Typically, it's a CTC topology fst, in which when 0 appears on
            the left side, it represents the blank symbol; when it appears
            on the right side, it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.
    :return: An ``Fsa`` representing the composed decoding graph ``HLG``.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consist of two parts:
    #  scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    G.lm_scores = G.scores.clone()

    logging.info("Composing L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0

    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'HLG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return HLG
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel, P: k2.Fsa, device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer, current_epoch: int,
                    tb_writer: SummaryWriter, num_epochs: int,
                    global_batch_idx_train: int, global_batch_idx_valid: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    ragged_shape = P.arcs.shape().to(device)
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, P, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)

            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler)
            global_batch_idx_valid += 1
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))

            tb_writer.add_scalar('train/global_valid_average_objf',
                                 total_valid_objf / total_valid_frames,
                                 global_batch_idx_valid)

        prev_timestamp = datetime.now()
    return total_objf / total_frames
def compile_LG(L: Fsa, G: Fsa, ctc_topo_inv: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa
    ``G``. Involves arc sorting, intersection, determinization, removal of
    disambiguation symbols, epsilon removal and composition with
    ``ctc_topo_inv``.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        ctc_topo_inv:
            Epsilons are in `aux_labels` and `labels` contain phone IDs.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.
    :return: An ``Fsa`` representing the composed decoding graph ``LG``.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0

    logging.debug("Removing epsilons")
    LG = k2.remove_epsilons_iterative_tropical(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.debug("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.debug("Composing ctc_topo_inv with LG")
    LG = k2.compose(ctc_topo_inv, LG)

    logging.debug("Connecting the composition result")
    LG = k2.connect(LG)

    logging.debug("Arc sorting the composition result")
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
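# A minimal sketch, assuming a working k2 installation, of the property-bit
# check used throughout these graph builders; `k2.linear_fsa` is only used
# here to obtain a small FSA to inspect.
def _example_arc_sorted_check():
    import k2

    fsa = k2.arc_sort(k2.linear_fsa([2, 5, 3]))
    is_arc_sorted = (fsa.properties & k2.fsa_properties.ARC_SORTED) != 0
    # expected: True
    return is_arc_sorted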