def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use an n-best list for rescoring
        # when n is less than 1
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "contextnet":
        model = ContextNet(num_features=80,
                           num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    assert P.requires_grad is False
    P.scores = model.P_scores.cpu()
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='model_P_scores.txt')

    P.set_scores_stochastic_(model.P_scores)
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='P_scores.txt')

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    if use_lm_rescoring:
        if use_whole_lattice:
            logging.info('Rescoring with the whole lattice')
        else:
            logging.info(f'Rescoring with n-best list, n is {num_paths}')
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        if not os.path.exists(lang_dir / 'G_4_gram.pt'):
            logging.debug('Loading G_4_gram.fst.txt')
            with open(lang_dir / 'G_4_gram.fst.txt') as f:
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION(fangjun): The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.create_fsa_vec([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
        else:
            logging.debug('Loading pre-compiled G_4_gram.pt')
            d = torch.load(lang_dir / 'G_4_gram.pt')
            G = k2.Fsa.from_dict(d).to(device)

        if use_whole_lattice:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)
    else:
        logging.debug('Decoding without LM rescoring')
        G = None

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    librispeech = LibriSpeechAsrDataModule(args)
    test_sets = ['test-clean', 'test-other']
    # test_sets = ['test-other']
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        logging.info(f'* DECODING: {test_set}')

        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         HLG=HLG,
                         symbols=symbol_table,
                         num_paths=num_paths,
                         G=G,
                         use_whole_lattice=use_whole_lattice,
                         output_beam_size=output_beam_size)

        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = exp_dir / f'errs-{test_set}.txt'
        with open(errs_filename, 'w') as f:
            write_error_stats(f, test_set, results)
        logging.info('Wrote detailed error stats to {}'.format(errs_filename))
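# A note on `average_checkpoint` above: snowfall averages the parameters of
# the last `avg` epoch checkpoints before decoding. The sketch below shows
# the core idea only; it assumes each checkpoint stores its parameters under
# a 'state_dict' key and ignores details (integer-dtype buffers, optimizer
# state) that the real helper has to handle.
def average_state_dicts_sketch(checkpoint_paths):
    avg_state = None
    for path in checkpoint_paths:
        state = torch.load(path, map_location='cpu')['state_dict']
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg_state:
                avg_state[k] += state[k].float()
    # Element-wise mean over all loaded checkpoints.
    return {k: v / len(checkpoint_paths) for k, v in avg_state.items()}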
def main():
    parser = get_parser()
    GigaSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use an n-best list for rescoring
        # when n is less than 1
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    suffix = ''
    if args.context_window is not None and args.context_window > 0:
        suffix = f'ac{args.context_window}'
    giga_subset = f'giga{args.subset}'
    exp_dir = Path(
        f'exp-{model_type}-mmi-att-sa-vgg-normlayer-{giga_subset}-{suffix}')

    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend,
            is_espnet_structure=args.is_espnet_structure)
    elif model_type == "contextnet":
        model = ContextNet(num_features=80,
                           num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)
        ts_path = exp_dir / f'model_ts_epoch{epoch}_avg{avg}.pt'
        logging.info(f'Storing the TorchScripted model in {ts_path}')
        model.save(ts_path)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    if use_lm_rescoring:
        if use_whole_lattice:
            logging.info('Rescoring with the whole lattice')
        else:
            logging.info(f'Rescoring with n-best list, n is {num_paths}')
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        if not os.path.exists(lang_dir / 'G_4_gram.pt'):
            logging.debug('Loading G_4_gram.fst.txt')
            with open(lang_dir / 'G_4_gram.fst.txt') as f:
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION(fangjun): The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.create_fsa_vec([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
        else:
            logging.debug('Loading pre-compiled G_4_gram.pt')
            d = torch.load(lang_dir / 'G_4_gram.pt')
            G = k2.Fsa.from_dict(d).to(device)

        if use_whole_lattice:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)
        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        logging.debug('Decoding without LM rescoring')
        G = None
        if num_paths > 1:
            logging.debug(f'Use n-best list decoding, n is {num_paths}')
        else:
            logging.debug('Use 1-best decoding')

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    gigaspeech = GigaSpeechAsrDataModule(args)
    test_sets = ['DEV', 'TEST']
    for test_set, test_dl in zip(
            test_sets,
        [gigaspeech.valid_dataloaders(),
         gigaspeech.test_dataloaders()]):
        logging.info(f'* DECODING: {test_set}')

        test_set_wers = dict()
        results_dict = decode(dataloader=test_dl,
                              model=model,
                              HLG=HLG,
                              symbols=symbol_table,
                              num_paths=num_paths,
                              G=G,
                              use_whole_lattice=use_whole_lattice,
                              output_beam_size=output_beam_size)

        for key, results in results_dict.items():
            recog_path = exp_dir / f'recogs-{test_set}-{key}.txt'
            store_transcripts(path=recog_path, texts=results)
            logging.info(f'The transcripts are stored in {recog_path}')

            ref_path = exp_dir / f'ref-{test_set}.trn'
            hyp_path = exp_dir / f'hyp-{test_set}.trn'
            store_transcripts_for_sclite(ref_path=ref_path,
                                         hyp_path=hyp_path,
                                         texts=results)
            logging.info(f'The sclite-format transcripts are stored in '
                         f'{ref_path} and {hyp_path}')
            cmd = (f'python3 GigaSpeech/utils/gigaspeech_scoring.py '
                   f'{ref_path} {hyp_path} {exp_dir / "tmp_sclite"}')
            logging.info(cmd)
            try:
                subprocess.run(cmd, check=True, shell=True)
            except subprocess.CalledProcessError:
                logging.error('Skipping sclite scoring as it failed to run: '
                              'Is "sclite" registered in your $PATH?')

            # The following prints out WERs, per-word error statistics and
            # aligned ref/hyp pairs.
            errs_filename = exp_dir / f'errs-{test_set}-{key}.txt'
            with open(errs_filename, 'w') as f:
                wer = write_error_stats(f, f'{test_set}-{key}', results)
                test_set_wers[key] = wer
            logging.info(
                'Wrote detailed error stats to {}'.format(errs_filename))

        test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
        errs_info = exp_dir / f'wer-summary-{test_set}.txt'
        with open(errs_info, 'w') as f:
            print('settings\tWER', file=f)
            for key, val in test_set_wers:
                print('{}\t{}'.format(key, val), file=f)

        s = '\nFor {}, the WERs of different settings are:\n'.format(test_set)
        note = '\tbest for {}'.format(test_set)
        for key, val in test_set_wers:
            s += '{}\t{}{}\n'.format(key, val, note)
            note = ''
        logging.info(s)
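# `find_first_disambig_symbol`, used when compiling HLG and when fixing up
# the 4-gram G above, only needs the smallest id among the #0, #1, ...
# disambiguation symbols in a symbol table. An illustrative version
# (assuming, as in these scripts, that disambiguation symbols are exactly
# the entries starting with '#'):
def find_first_disambig_symbol_sketch(symbol_table: k2.SymbolTable) -> int:
    return min(symbol_table[s] for s in symbol_table.symbols
               if s.startswith('#'))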
def main():
    args = get_parser().parse_args()

    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan-dist')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)
    phone_ids = lexicon.phone_symbols()
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)

    checkpoint = os.path.join(exp_dir, f'epoch-{args.epoch}.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(lexicon.phones)
        first_word_disambig_id = find_first_disambig_symbol(lexicon.words)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    logging.debug("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=40000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    # if not torch.cuda.is_available():
    #     logging.error('No GPU detected!')
    #     sys.exit(-1)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    logging.debug("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     HLG=HLG,
                     symbols=lexicon.words)

    test_set = 'test-clean'
    recog_path = exp_dir / f'recogs-{test_set}.txt'
    store_transcripts(path=recog_path, texts=results)
    logging.info(f'The transcripts are stored in {recog_path}')

    # The following prints out WERs, per-word error statistics and aligned
    # ref/hyp pairs.
    errs_filename = exp_dir / f'errs-{test_set}.txt'
    with open(errs_filename, 'w') as f:
        wer = write_error_stats(f, f'{test_set}', results)
    logging.info(f'The error stats are stored in {errs_filename}')
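# `lexicon.phone_symbols()` above plays the same role as `get_phone_symbols`
# in the other decode scripts: collect the real phone ids from phones.txt,
# skipping <eps> (id 0) and the #N disambiguation symbols. A minimal sketch
# of that filtering (the real helpers also accept a configurable pattern):
def phone_symbols_sketch(phone_symbol_table: k2.SymbolTable):
    ids = [
        phone_symbol_table[s] for s in phone_symbol_table.symbols
        if not s.startswith('#')
    ]
    return sorted(i for i in ids if i != 0)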
def main():
    args = get_parser().parse_args()

    model_type = args.model_type
    epoch = args.epoch
    max_duration = args.max_duration
    avg = args.avg
    att_rate = args.att_rate

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    # load dataset
    feature_dir = Path('exp/data')
    test_sets = ['test-clean', 'test-other']
    for test_set in test_sets:
        logging.info(f'* DECODING: {test_set}')

        logging.debug("About to get test cuts")
        cuts_test = load_manifest(feature_dir / f'cuts_{test_set}.json.gz')
        logging.debug("About to create test dataset")
        from lhotse.dataset.input_strategies import OnTheFlyFeatures
        from lhotse import Fbank, FbankConfig
        test = K2SpeechRecognitionDataset(
            cuts_test,
            input_strategy=OnTheFlyFeatures(
                Fbank(FbankConfig(num_mel_bins=80))))
        sampler = SingleCutSampler(cuts_test, max_duration=max_duration)
        logging.debug("About to create test dataloader")
        test_dl = torch.utils.data.DataLoader(test,
                                              batch_size=None,
                                              sampler=sampler,
                                              num_workers=1)

        logging.debug("About to decode")
        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         HLG=HLG,
                         symbols=symbol_table)

        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')

        # compute WER
        dists = [edit_distance(r, h) for r, h in results]
        errors = {
            key: sum(dist[key] for dist in dists)
            for key in ['sub', 'ins', 'del', 'total']
        }
        total_words = sum(len(ref) for ref, _ in results)
        # Print Kaldi-like message:
        # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
        logging.info(
            f'[{test_set}] %WER {errors["total"] / total_words:.2%} '
            f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, '
            f'{errors["del"]} del, {errors["sub"]} sub ]')
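# The `edit_distance` used above returns per-utterance counts keyed by
# 'sub', 'ins', 'del' and 'total' (snowfall imports it from kaldialign).
# An illustrative pure-Python equivalent via the standard Levenshtein DP:
def edit_distance_sketch(ref, hyp):
    m, n = len(ref), len(hyp)
    # dp[i][j] = (total, sub, ins, del) for aligning ref[:i] with hyp[:j];
    # tuples compare lexicographically, so min() prefers the lowest total.
    dp = [[None] * (n + 1) for _ in range(m + 1)]
    dp[0][0] = (0, 0, 0, 0)
    for i in range(1, m + 1):  # hyp exhausted: delete remaining ref tokens
        t, s, ins, dl = dp[i - 1][0]
        dp[i][0] = (t + 1, s, ins, dl + 1)
    for j in range(1, n + 1):  # ref exhausted: insert remaining hyp tokens
        t, s, ins, dl = dp[0][j - 1]
        dp[0][j] = (t + 1, s, ins + 1, dl)
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            t, s, ins, dl = dp[i - 1][j - 1]
            if ref[i - 1] == hyp[j - 1]:
                match_or_sub = (t, s, ins, dl)  # exact match, no cost
            else:
                match_or_sub = (t + 1, s + 1, ins, dl)  # substitution
            t, s, ins, dl = dp[i][j - 1]
            insertion = (t + 1, s, ins + 1, dl)
            t, s, ins, dl = dp[i - 1][j]
            deletion = (t + 1, s, ins, dl + 1)
            dp[i][j] = min(match_or_sub, insertion, deletion)
    total, sub, ins, dl = dp[m][n]
    return {'sub': sub, 'ins': ins, 'del': dl, 'total': total}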
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    logging.basicConfig(level=logging.DEBUG)
    args = parser.parse_args()

    avg = args.avg
    attention_dim = args.attention_dim
    nhead = args.nhead
    att_rate = args.att_rate
    model_type = args.model_type
    epoch = args.epoch

    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    lang_dir = Path('data/en_token_list/bpe_unigram5000/')
    bpe_model_path = lang_dir / 'bpe.model'
    tokens_file = lang_dir / 'tokens.txt'
    numericalizer = Numericalizer.build_numericalizer(bpe_model_path,
                                                      tokens_file)

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    num_classes = len(numericalizer.tokens_list)
    if model_type == "conformer":
        model = Conformer(num_features=80,
                          nhead=args.nhead,
                          d_model=args.attention_dim,
                          num_classes=num_classes,
                          subsampling_factor=4,
                          num_decoder_layers=num_decoder_layers,
                          vgg_frontend=args.vgg_frontend,
                          is_espnet_structure=args.is_espnet_structure,
                          mmi_loss=False)
        if args.espnet_identical_model:
            assert sum([p.numel() for p in model.parameters()]) == 116146960
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not verified")

    exp_dir = Path(f'exp-bpe-{model_type}-{attention_dim}-{nhead}-noam/')
    if args.decode_with_released_model:
        released_model_path = exp_dir / f'model-epoch-{epoch}-avg-{avg}.pt'
        model.load_state_dict(torch.load(released_model_path))
    else:
        if avg == 1:
            checkpoint = os.path.join(exp_dir,
                                      'epoch-' + str(epoch - 1) + '.pt')
            load_checkpoint(checkpoint, model)
        else:
            checkpoints = [
                os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
                for avg_epoch in range(epoch - avg, epoch)
            ]
            average_checkpoint(checkpoints, model)
        if args.generate_release_model:
            released_model_path = exp_dir / f'model-epoch-{epoch}-avg-{avg}.pt'
            torch.save(model.state_dict(), released_model_path)

    model.to(device)
    model.eval()

    token_ids_with_blank = list(range(num_classes))

    ctc_path = lang_dir / 'ctc_topo.pt'
    if not os.path.exists(ctc_path):
        logging.info("Generating ctc topo...")
        ctc_topo = k2.arc_sort(build_ctc_topo(token_ids_with_blank))
        torch.save(ctc_topo.as_dict(), ctc_path)
    else:
        logging.info("Loading pre-compiled ctc topo fst")
        d_ctc_topo = torch.load(ctc_path)
        ctc_topo = k2.Fsa.from_dict(d_ctc_topo)
    ctc_topo = ctc_topo.to(device)

    feature_dir = Path('exp/data')
    librispeech = LibriSpeechAsrDataModule(args)
    test_sets = ['test-clean', 'test-other']
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         ctc_topo=ctc_topo,
                         numericalizer=numericalizer,
                         num_paths=args.num_paths,
                         output_beam_size=args.output_beam_size)

        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = exp_dir / f'errs-{test_set}.txt'
        with open(errs_filename, 'w') as f:
            write_error_stats(f, test_set, results)
        logging.info('Wrote detailed error stats to {}'.format(errs_filename))
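# Each of the decode scripts excerpted above is invoked directly; their
# entry-point footers are omitted from this excerpt, but typically look
# like the following (the thread pinning mirrors what these snowfall-style
# scripts commonly do; treat the exact settings as an assumption here):
#
#     torch.set_num_threads(1)
#     torch.set_num_interop_threads(1)
#
#     if __name__ == '__main__':
#         main()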