def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
    """Decode a whole dataset with a CTC acoustic model and an HLG graph.

    Args:
        dataloader: Yields batches with key 'inputs' (features of shape
            [N, T, C]) and key 'supervisions' (a dict with 'sequence_idx',
            'start_frame', 'num_frames' and 'text').
        model: Acoustic model producing [N, C, T] output; it subsamples the
            input frames by ``model.subsampling_factor``.
        device: Device on which features are placed; must match HLG's device.
        HLG: The decoding graph.
        symbols: Word symbol table mapping output IDs to words.

    Returns:
        A list of (ref_words, hyp_words) pairs, one per utterance.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['inputs']
        supervisions = batch['supervisions']
        # Convert frame indices from input frames to subsampled frames.
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        # k2 expects segments sorted by decreasing duration; keep the
        # permutation so hypotheses can be restored to the original order.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]; the model expects [N, C, T]
        feature = feature.permute(0, 2, 1)
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]; DenseFsaVec expects [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # Fixed: the message used to say "LG.device" although HLG is checked.
        assert HLG.device == nnet_output.device, \
            f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))
        if batch_idx % 10 == 0:
            # num_cuts is incremented after logging, so the figure reflects
            # cuts from fully processed batches only.
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))
        num_cuts += len(texts)
    return results
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           HLG: Fsa, symbols: SymbolTable, num_paths: int, G: k2.Fsa,
           use_whole_lattice: bool, output_beam_size: float):
    """Decode a dataset batch by batch, optionally rescoring with G.

    Returns:
        A dict whose keys label the LM scale used (e.g. 'lm_scale_1.2', or
        the literal string 'no_rescore' when no rescoring is applied) and
        whose values are lists of (ref_words, hyp_words) tuples.
    """
    try:
        num_batches = len(dataloader)
    except TypeError:
        # Iterable-style dataloaders do not expose a length.
        num_batches = None

    results = defaultdict(list)
    num_cuts = 0
    for batch_idx, batch in enumerate(dataloader):
        # Popping the non-tensor 'text' values keeps the batch compatible
        # with TorchScript models.
        texts = batch['supervisions'].pop('text')

        hyps_dict = decode_one_batch(batch=batch,
                                     model=model,
                                     HLG=HLG,
                                     output_beam_size=output_beam_size,
                                     num_paths=num_paths,
                                     use_whole_lattice=use_whole_lattice,
                                     G=G)

        for lm_scale, hyps in hyps_dict.items():
            assert len(hyps) == len(texts)
            pairs = [(ref.split(' '), [symbols.get(i) for i in hyp])
                     for ref, hyp in zip(texts, hyps)]
            results[lm_scale].extend(pairs)

        num_cuts += len(texts)
        if batch_idx % 10 == 0:
            if num_batches is None:
                batch_str = f"{batch_idx}"
            else:
                batch_str = f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )
    return results
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    """Decode the dataset with an LG graph.

    Returns:
        A list of (ref_words, hyp_words) pairs, one per utterance.
    """
    results = []
    for batch_idx, batch in enumerate(dataloader):
        feats = batch['features']
        sups = batch['supervisions']

        # Rows are [sequence_idx, start, duration] in subsampled frames.
        segments = torch.stack(
            (sups['sequence_idx'],
             torch.floor_divide(sups['start_frame'], model.subsampling_factor),
             torch.floor_divide(sups['num_frames'], model.subsampling_factor)),
            1).to(torch.int32)
        texts = sups['text']

        assert feats.ndim == 3
        feats = feats.to(device)
        feats = feats.permute(0, 2, 1)  # [N, T, C] -> [N, C, T] for the model
        with torch.no_grad():
            nnet_output = model(feats)
        nnet_output = nnet_output.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, segments)
        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0,
                                             30, 300)
        best_paths = k2.shortest_path(lattices, use_float_scores=True)
        best_paths = best_paths.to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i in range(len(texts)):
            # Drop epsilons/padding (non-positive labels) before mapping.
            hyp_words = [
                symbols.get(x) for x in best_paths[i].aux_labels if x > 0
            ]
            results.append((texts[i].split(' '), hyp_words))

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))
    return results
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           HLG: Fsa, symbols: SymbolTable, num_paths: int, G: k2.Fsa,
           use_whole_lattice: bool, output_beam_size: float):
    """Decode the dataset, optionally rescoring lattices with G.

    Returns:
        A dict keyed by the lm_scale label (e.g. 'lm_scale_1.2'; the
        literal string 'no_rescore' when rescoring is disabled), whose
        values are lists of (ref_words, hyp_words) tuples.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = defaultdict(list)
    for batch_idx, batch in enumerate(dataloader):
        texts = batch['supervisions']['text']

        hyps_dict = decode_one_batch(batch=batch,
                                     model=model,
                                     HLG=HLG,
                                     output_beam_size=output_beam_size,
                                     num_paths=num_paths,
                                     use_whole_lattice=use_whole_lattice,
                                     G=G)

        for lm_scale, hyps in hyps_dict.items():
            assert len(hyps) == len(texts)
            batch_pairs = []
            for ref_text, hyp_ids in zip(texts, hyps):
                batch_pairs.append(
                    (ref_text.split(' '), [symbols.get(i) for i in hyp_ids]))
            results[lm_scale].extend(batch_pairs)

        # Log before counting this batch so the figure covers only fully
        # processed batches.
        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))
        num_cuts += len(texts)
    return results
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    """Decode every batch yielded by ``dataloader`` using the HLG graph.

    Returns:
        A list of (ref_words, hyp_words) pairs, one per utterance.
    """
    try:
        num_batches = len(dataloader)
    except TypeError:
        # Iterable-style dataloaders have no __len__.
        num_batches = None

    num_cuts = 0
    results = []
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        supervisions = batch["supervisions"]

        # Rows are [sequence_idx, start, duration] in subsampled frames.
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"], model.subsampling_factor),
                torch.floor_divide(supervisions["num_frames"], model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        # k2 expects segments ordered by decreasing duration; remember the
        # permutation so hypotheses can be mapped back afterwards.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        texts = supervisions["text"]
        assert feature.ndim == 3

        feature = feature.to(device)
        feature = feature.permute(0, 2, 1)  # [N, T, C] -> [N, C, T] for the model
        with torch.no_grad():
            nnet_output = model(feature)
        nnet_output = nnet_output.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]

        # Bias the blank (symbol 0) scores down before the search.
        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30, 10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)

        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)
        for ref_text, hyp_ids in zip(texts, hyps):
            results.append((ref_text.split(" "), [symbols.get(i) for i in hyp_ids]))

        if batch_idx % 10 == 0:
            batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )
        num_cuts += len(texts)
    return results
def _ids_to_symbols(ids: List[int], symbol_table: k2.SymbolTable) -> List[str]:
    """Map each integer ID in ``ids`` to its symbol via ``symbol_table``."""
    return list(map(symbol_table.get, ids))