def __init__(self,
             L_inv: k2.Fsa,
             phones: k2.SymbolTable,
             words: k2.SymbolTable,
             oov: str = '<UNK>'):
    '''
    Args:
      L_inv:
        Its labels are words, while its aux_labels are phones.
      phones:
        The phone symbol table.
      words:
        The word symbol table.
      oov:
        Out of vocabulary word.
    '''
    # Arc-sort L_inv if it is not already arc-sorted.
    if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
        L_inv = k2.arc_sort(L_inv)

    assert oov in words

    self.L_inv = L_inv
    self.phones = phones
    self.words = words
    self.oov = oov

    phone_ids = get_phone_symbols(phones)
    phone_ids_with_blank = [0] + phone_ids
    self.ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
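# Hedged usage sketch for the compiler above. It assumes the
# data/lang_nosp layout used by the training script later in this
# section and that Linv.pt has already been cached by that script.
import torch
import k2
from pathlib import Path

lang_dir = Path('data/lang_nosp')
phones = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
words = k2.SymbolTable.from_file(lang_dir / 'words.txt')
L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
compiler = CtcTrainingGraphCompiler(L_inv=L_inv, phones=phones, words=words)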
def __init__(self,
             lexicon: Lexicon,
             P: k2.Fsa,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      lexicon:
        The lexicon. Its L_inv has words as labels and
        pronunciations as aux_labels.
      P:
        A phone bigram LM if the pronunciations in the lexicon are in
        phones; a word piece bigram if the pronunciations in the
        lexicon are word pieces.
      device:
        The target device that all FSAs should be moved to.
      oov:
        Out of vocabulary word.
    '''
    self.lexicon = lexicon
    L_inv = self.lexicon.L_inv.to(device)
    P = P.to(device)

    # Arc-sort L_inv if it is not already arc-sorted.
    if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
        L_inv = k2.arc_sort(L_inv)

    assert L_inv.requires_grad is False
    assert oov in self.lexicon.words

    self.L_inv = L_inv
    self.oov_id = self.lexicon.words[oov]
    self.oov = oov
    self.device = device

    phone_symbols = get_phone_symbols(self.lexicon.phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
    assert ctc_topo.requires_grad is False

    # Compose the CTC topology with P: intersect the inverted,
    # arc-sorted topology with P (augmented with epsilon self-loops)
    # and invert the result back.
    ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
    P_with_self_loops = k2.add_epsilon_self_loops(P)
    ctc_topo_P = k2.intersect(ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()
    self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
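# Hedged toy illustration of the composition pattern used above: an
# arc-sorted inverted topology is intersected with an LM to which
# epsilon self-loops were added. The two linear FSAs are stand-ins for
# ctc_topo_inv and P; only the k2 calls mirror the real code. (The real
# code then calls .invert(); the toy acceptors carry no aux_labels, so
# that step is omitted here.)
import k2

topo_inv = k2.arc_sort(k2.linear_fsa([1, 2]))  # stand-in for ctc_topo_inv
P_toy = k2.linear_fsa([1, 2])                  # stand-in for the bigram P
P_with_self_loops = k2.add_epsilon_self_loops(P_toy)
composed = k2.intersect(topo_inv,
                        P_with_self_loops,
                        treat_epsilons_specially=False)
print(composed.num_arcs)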
def __init__(self,
             lexicon: Lexicon,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      lexicon:
        The lexicon. Its L_inv has words as labels and phones as
        aux_labels.
      device:
        The target device that all FSAs should be moved to.
      oov:
        Out of vocabulary word.
    '''
    self.lexicon = lexicon
    L_inv = self.lexicon.L_inv.to(device)

    # Arc-sort L_inv if it is not already arc-sorted.
    if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
        L_inv = k2.arc_sort(L_inv)

    assert L_inv.requires_grad is False
    assert oov in self.lexicon.words

    self.L_inv = L_inv
    self.oov_id = self.lexicon.words[oov]
    self.oov = oov
    self.device = device

    phone_symbols = get_phone_symbols(self.lexicon.phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
    assert ctc_topo.requires_grad is False

    self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
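# Hedged sketch of how a training graph is typically compiled from the
# members set up above (snowfall-style; the actual compile() method is
# not shown in this section, so every step below is an assumption):
def compile_training_graph(compiler, text: str) -> k2.Fsa:
    # Map words to ids, falling back to the OOV id.
    words = compiler.lexicon.words
    word_ids = [words[w] if w in words else compiler.oov_id
                for w in text.split()]
    # Linear word acceptor -> intersect with L_inv -> invert, so that
    # phones end up on the labels and words on the aux_labels.
    word_fsa = k2.linear_fsa(word_ids, compiler.device)
    transcript = k2.connect(k2.intersect(word_fsa, compiler.L_inv))
    transcript = k2.arc_sort(transcript.invert_())
    # Compose the CTC topology (stored inverted above) with the
    # phone-level transcript.
    graph = k2.compose(compiler.ctc_topo_inv.invert(), transcript)
    return k2.connect(graph)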
def main():
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 8

    exp_dir = 'exp-lstm-adam-ctc-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = CtcTrainingGraphCompiler(
        L_inv=L_inv,
        phones=phone_symbol_table,
        words=word_symbol_table
    )
    phone_ids = get_phone_symbols(phone_symbol_table)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir / 'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(
        cuts_train,
        cut_transforms=[
            CutConcatenate(),
            CutMix(
                cuts=cuts_musan,
                prob=0.5,
                snr=(10, 20)
            )
        ]
    )
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=90000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4
    )
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(
        validate,
        sampler=valid_sampler,
        batch_size=None,
        num_workers=1
    )

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.to(device)
    describe(model)

    learning_rate = 1e-3
    optimizer = optim.AdamW(model.parameters(),
                            lr=learning_rate,
                            weight_decay=5e-4)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    if start_epoch > 0:
        model_path = os.path.join(exp_dir,
                                  'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, "
                     f"valid_objf = {best_valid_objf}")

    for epoch in range(start_epoch, num_epochs):
        train_sampler.set_epoch(epoch)
        curr_learning_rate = 1e-3
        # curr_learning_rate = learning_rate * pow(0.4, epoch)
        # for param_group in optimizer.param_groups:
        #     param_group['lr'] = curr_learning_rate

        tb_writer.add_scalar('learning_rate', curr_learning_rate, epoch)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train)

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            optimizer=None,
                            scheduler=None,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=None,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
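# Hedged, self-contained sketch of the objective train_one_epoch is
# assumed to compute per batch: wrap the network's log-probs in a
# DenseFsaVec, intersect it with the compiled supervision graph, and
# take the negated total log-likelihood as the loss. The toy tensors
# and the linear stand-in graph are illustrative only.
import k2
import torch

log_probs = torch.randn(1, 3, 4).log_softmax(dim=-1)  # (N=1, T=3, C=4)
# One row per supervision: (sequence_index, start_frame, num_frames).
supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

# Stand-in for graph_compiler.compile(texts): one linear FSA whose
# length matches the number of frames.
graphs = k2.arc_sort(k2.linear_fsa([[1, 2, 3]]))

lattice = k2.intersect_dense(graphs, dense_fsa_vec, output_beam=10.0)
tot_scores = lattice.get_tot_scores(log_semiring=True,
                                    use_double_scores=True)
loss = -tot_scores.sum()
print(loss)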
def main(): exp_dir = Path("exp-lstm-adam-ctc-musan") setup_logger("{}/log/log-decode".format(exp_dir), log_level="debug") # load L, G, symbol_table lang_dir = Path("data/lang_nosp") symbol_table = k2.SymbolTable.from_file(lang_dir / "words.txt") phone_symbol_table = k2.SymbolTable.from_file(lang_dir / "phones.txt") phone_ids = get_phone_symbols(phone_symbol_table) phone_ids_with_blank = [0] + phone_ids ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank)) if not os.path.exists(lang_dir / "HLG.pt"): print("Loading L_disambig.fst.txt") with open(lang_dir / "L_disambig.fst.txt") as f: L = k2.Fsa.from_openfst(f.read(), acceptor=False) print("Loading G.fst.txt") with open(lang_dir / "G.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) first_phone_disambig_id = find_first_disambig_symbol( phone_symbol_table) first_word_disambig_id = find_first_disambig_symbol(symbol_table) HLG = compile_HLG( L=L, G=G, H=ctc_topo, labels_disambig_id_start=first_phone_disambig_id, aux_labels_disambig_id_start=first_word_disambig_id, ) torch.save(HLG.as_dict(), lang_dir / "HLG.pt") else: print("Loading pre-compiled HLG") d = torch.load(lang_dir / "HLG.pt") HLG = k2.Fsa.from_dict(d) # load dataset feature_dir = Path("exp/data") print("About to get test cuts") cuts_test = CutSet.from_file(feature_dir / "gigaspeech_cuts_TEST.jsonl.gz") print("About to create test dataset") test = K2SpeechRecognitionDataset(cuts_test) sampler = SingleCutSampler(cuts_test, max_frames=100000) print("About to create test dataloader") test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1) # if not torch.cuda.is_available(): # logging.error('No GPU detected!') # sys.exit(-1) print("About to load model") # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N # device = torch.device('cuda', 1) device = torch.device("cuda") model = TdnnLstm1b( num_features=80, num_classes=len(phone_ids) + 1, # +1 for the blank symbol subsampling_factor=4, ) checkpoint = os.path.join(exp_dir, "epoch-7.pt") load_checkpoint(checkpoint, model) model.to(device) model.eval() print("convert HLG to device") HLG = HLG.to(device) HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0) HLG.requires_grad_(False) print("About to decode") results = decode(dataloader=test_dl, model=model, device=device, HLG=HLG, symbols=symbol_table) s = "" for ref, hyp in results: s += f"ref={ref}\n" s += f"hyp={hyp}\n" logging.info(s) # compute WER dists = [edit_distance(r, h) for r, h in results] errors = { key: sum(dist[key] for dist in dists) for key in ["sub", "ins", "del", "total"] } total_words = sum(len(ref) for ref, _ in results) # Print Kaldi-like message: # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ] logging.info( f'%WER {errors["total"] / total_words:.2%} ' f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]' )
def __init__(self,
             L_inv: k2.Fsa,
             L_disambig: k2.Fsa,
             G: k2.Fsa,
             phones: k2.SymbolTable,
             words: k2.SymbolTable,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      L_inv:
        Its labels are words, while its aux_labels are phones.
      L_disambig:
        L with disambig symbols. Its labels are phones and aux_labels
        are words.
      G:
        The language model.
      phones:
        The phone symbol table.
      words:
        The word symbol table.
      device:
        The target device that all FSAs should be moved to.
      oov:
        Out of vocabulary word.
    '''
    L_inv = L_inv.to(device)
    G = G.to(device)

    # Arc-sort L_inv and G if they are not already arc-sorted.
    if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
        L_inv = k2.arc_sort(L_inv)

    if G.properties & k2.fsa_properties.ARC_SORTED == 0:
        G = k2.arc_sort(G)

    assert L_inv.requires_grad is False
    assert G.requires_grad is False
    assert oov in words

    L = L_inv.invert()
    L = k2.arc_sort(L)

    self.L_inv = L_inv
    self.L = L
    self.phones = phones
    self.words = words
    self.device = device
    self.oov_id = self.words[oov]

    phone_symbols = get_phone_symbols(phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = k2.arc_sort(
        build_ctc_topo(phone_symbols_with_blank).to(device))
    assert ctc_topo.requires_grad is False

    self.ctc_topo = ctc_topo
    self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert())

    lang_dir = Path('data/lang_nosp')
    if not (lang_dir / 'HLG_uni.pt').exists():
        logging.info("Composing (ctc_topo, L_disambig, G)")
        first_phone_disambig_id = find_first_disambig_symbol(phones)
        first_word_disambig_id = find_first_disambig_symbol(words)
        # decoding_graph is the result of composing (ctc_topo, L_disambig, G)
        decoding_graph = compile_HLG(
            L=L_disambig.to('cpu'),
            G=G.to('cpu'),
            H=ctc_topo.to('cpu'),
            labels_disambig_id_start=first_phone_disambig_id,
            aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(decoding_graph.as_dict(), lang_dir / 'HLG_uni.pt')
    else:
        logging.info("Loading pre-compiled HLG")
        decoding_graph = k2.Fsa.from_dict(
            torch.load(lang_dir / 'HLG_uni.pt'))
        assert hasattr(decoding_graph, 'phones')

    self.decoding_graph = decoding_graph.to(device)
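# Hedged usage sketch for the compiler above. The enclosing class is
# not shown in this section, so GraphCompiler is a hypothetical name;
# the file loading mirrors the decoding script below.
import torch
import k2
from pathlib import Path

lang_dir = Path('data/lang_nosp')
phones = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
words = k2.SymbolTable.from_file(lang_dir / 'words.txt')
L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
with open(lang_dir / 'L_disambig.fst.txt') as f:
    L_disambig = k2.Fsa.from_openfst(f.read(), acceptor=False)
with open(lang_dir / 'G.fst.txt') as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

compiler = GraphCompiler(L_inv=L_inv,  # hypothetical class name
                         L_disambig=L_disambig,
                         G=G,
                         phones=phones,
                         words=words,
                         device=torch.device('cuda'))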
def main():
    exp_dir = Path('exp-lstm-adam-ctc-musan')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    if not os.path.exists(lang_dir / 'LG.pt'):
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        print("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        LG = compile_LG(L=L,
                        G=G,
                        ctc_topo=ctc_topo,
                        labels_disambig_id_start=first_phone_disambig_id,
                        aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(LG.as_dict(), lang_dir / 'LG.pt')
    else:
        print("Loading pre-compiled LG")
        d = torch.load(lang_dir / 'LG.pt')
        LG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    print("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    print("About to create test dataset")
    test = K2SpeechRecognitionIterableDataset(cuts_test,
                                              max_frames=100000,
                                              shuffle=False,
                                              concat_cuts=False)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test, batch_size=None, num_workers=1)

    # if not torch.cuda.is_available():
    #     logging.error('No GPU detected!')
    #     sys.exit(-1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(num_features=40,
                       num_classes=len(phone_ids_with_blank))

    checkpoint = os.path.join(exp_dir, 'epoch-7.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    print("convert LG to device")
    LG = LG.to(device)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
    LG.requires_grad_(False)
    print("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     LG=LG,
                     symbols=symbol_table)
    s = ''
    for ref, hyp in results:
        s += f'ref={ref}\n'
        s += f'hyp={hyp}\n'
    logging.info(s)

    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, '
        f'{errors["del"]} del, {errors["sub"]} sub ]'
    )
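# Toy sanity check for the WER aggregation above. It assumes the
# edit_distance used by this script behaves like
# kaldialign.edit_distance (an assumption; the import is not shown):
from kaldialign import edit_distance

stats = edit_distance(['a', 'b', 'c'], ['a', 'x', 'c', 'd'])
# Expected: 1 substitution (b -> x), 1 insertion (d), 0 deletions,
# total 2 errors against a 3-word reference.
print(stats)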