Example #1
    def __init__(self,
                 L_inv: k2.Fsa,
                 phones: k2.SymbolTable,
                 words: k2.SymbolTable,
                 oov: str = '<UNK>'):
        '''
        Args:
          L_inv:
            Its labels are words, while its aux_labels are phones.
          phones:
            The phone symbol table.
          words:
            The word symbol table.
          oov:
            Out of vocabulary word.
        '''
        # Arc-sort L_inv only if it is not already arc-sorted.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert oov in words

        self.L_inv = L_inv
        self.phones = phones
        self.words = words
        self.oov = oov
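        # The CTC topology is built over all phone ids plus the blank
        # symbol, which is id 0 by convention.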
        phone_ids = get_phone_symbols(phones)
        phone_ids_with_blank = [0] + phone_ids
        self.ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
Example #2
    def __init__(self,
                 lexicon: Lexicon,
                 P: k2.Fsa,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          lexicon:
            The lexicon. Its L_inv has words as labels and phones as
            aux_labels.
          P:
            A phone bigram LM if the pronunciations in the lexicon are in
            phones; a word-piece bigram if the pronunciations are word
            pieces.
          device:
            The target device that all FSAs should be moved to.
          oov:
            Out of vocabulary word.
        '''
        self.lexicon = lexicon
        L_inv = self.lexicon.L_inv.to(device)
        P = P.to(device)

        # Arc-sort L_inv only if it is not already arc-sorted.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert L_inv.requires_grad is False

        assert oov in self.lexicon.words

        self.L_inv = L_inv
        self.oov_id = self.lexicon.words[oov]
        self.oov = oov
        self.device = device

        phone_symbols = get_phone_symbols(self.lexicon.phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
        assert ctc_topo.requires_grad is False

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

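        # Add epsilon self-loops to P so that, with
        # treat_epsilons_specially=False, the epsilon arcs of the
        # inverted CTC topology can still be matched; the final invert()
        # restores the original label/aux_label order.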
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
Example #3
    def __init__(self,
                 lexicon: Lexicon,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          lexicon:
            The lexicon. Its L_inv has words as labels and phones as
            aux_labels.
          device:
            The target device that all FSAs should be moved to.
          oov:
            Out of vocabulary word.
        '''
        self.lexicon = lexicon
        L_inv = self.lexicon.L_inv.to(device)

        # Arc-sort L_inv only if it is not already arc-sorted.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert L_inv.requires_grad is False

        assert oov in self.lexicon.words

        self.L_inv = L_inv
        self.oov_id = self.lexicon.words[oov]
        self.oov = oov
        self.device = device

        phone_symbols = get_phone_symbols(self.lexicon.phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
        assert ctc_topo.requires_grad is False

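        # invert_() swaps labels and aux_labels in place; the result is
        # arc-sorted before being stored.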
        self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
Example #4
def main():
    fix_random_seed(42)

    start_epoch = 0
    num_epochs = 8

    exp_dir = 'exp-lstm-adam-ctc-musan'
    setup_logger('{}/log/log-train'.format(exp_dir))
    tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')

    logging.info("Loading L.fst")
    if (lang_dir / 'Linv.pt').exists():
        L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt'))
    else:
        with open(lang_dir / 'L.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
            L_inv = k2.arc_sort(L.invert_())
            torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

    graph_compiler = CtcTrainingGraphCompiler(
        L_inv=L_inv,
        phones=phone_symbol_table,
        words=word_symbol_table
    )
    phone_ids = get_phone_symbols(phone_symbol_table)

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get train cuts")
    cuts_train = CutSet.from_json(feature_dir /
                                  'cuts_train-clean-100.json.gz')
    logging.info("About to get dev cuts")
    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev-clean.json.gz')
    logging.info("About to get Musan cuts")
    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')

    logging.info("About to create train dataset")
    train = K2SpeechRecognitionDataset(
        cuts_train,
        cut_transforms=[
            CutConcatenate(),
            CutMix(
                cuts=cuts_musan,
                prob=0.5,
                snr=(10, 20)
            )
        ]
    )
    train_sampler = SingleCutSampler(
        cuts_train,
        max_frames=90000,
        shuffle=True,
    )
    logging.info("About to create train dataloader")
    train_dl = torch.utils.data.DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=4
    )
    logging.info("About to create dev dataset")
    validate = K2SpeechRecognitionDataset(cuts_dev)
    valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000)
    logging.info("About to create dev dataloader")
    valid_dl = torch.utils.data.DataLoader(
        validate,
        sampler=valid_sampler,
        batch_size=None,
        num_workers=1
    )

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    logging.info("About to create model")
    device_id = 0
    device = torch.device('cuda', device_id)
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    
    model.to(device)
    describe(model)

    learning_rate = 1e-3
    optimizer = optim.AdamW(model.parameters(),
                            lr=learning_rate,
                            weight_decay=5e-4)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

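    # Resume training: restore model/optimizer state and the best
    # objective values from the checkpoint of the epoch before
    # start_epoch.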
    if start_epoch > 0:
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1))
        ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}")

    for epoch in range(start_epoch, num_epochs):
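        # set_epoch() re-seeds the sampler's shuffling so that each
        # epoch iterates over the cuts in a different order.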
        train_sampler.set_epoch(epoch)
        curr_learning_rate = 1e-3
        # curr_learning_rate = learning_rate * pow(0.4, epoch)
        # for param_group in optimizer.param_groups:
        #     param_group['lr'] = curr_learning_rate

        tb_writer.add_scalar('learning_rate', curr_learning_rate, epoch)

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf, valid_objf, global_batch_idx_train = train_one_epoch(dataloader=train_dl,
                                                                   valid_dataloader=valid_dl,
                                                                   model=model,
                                                                   device=device,
                                                                   graph_compiler=graph_compiler,
                                                                   optimizer=optimizer,
                                                                   current_epoch=epoch,
                                                                   tb_writer=tb_writer,
                                                                   num_epochs=num_epochs,
                                                                   global_batch_idx_train=global_batch_idx_train)
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            optimizer=None,
                            scheduler=None,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        optimizer=optimizer,
                        scheduler=None,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train)
        epoch_info_filename = os.path.join(exp_dir,
                                           'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch)

    logging.info('Done')
Example #5
def main():
    exp_dir = Path("exp-lstm-adam-ctc-musan")
    setup_logger("{}/log/log-decode".format(exp_dir), log_level="debug")

    # load L, G, symbol_table
    lang_dir = Path("data/lang_nosp")
    symbol_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / "phones.txt")
    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

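    # Compile the HLG decoding graph (CTC topology, lexicon and LM) once
    # and cache it; disambiguation symbols are removed starting from the
    # given first-disambig ids.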
    if not os.path.exists(lang_dir / "HLG.pt"):
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / "L_disambig.fst.txt") as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        print("Loading G.fst.txt")
        with open(lang_dir / "G.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(
            L=L,
            G=G,
            H=ctc_topo,
            labels_disambig_id_start=first_phone_disambig_id,
            aux_labels_disambig_id_start=first_word_disambig_id,
        )
        torch.save(HLG.as_dict(), lang_dir / "HLG.pt")
    else:
        print("Loading pre-compiled HLG")
        d = torch.load(lang_dir / "HLG.pt")
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path("exp/data")
    print("About to get test cuts")
    cuts_test = CutSet.from_file(feature_dir / "gigaspeech_cuts_TEST.jsonl.gz")

    print("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=100000)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    # if not torch.cuda.is_available():
    #     logging.error('No GPU detected!')
    #     sys.exit(-1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to set the device id to N
    # device = torch.device('cuda', 1)
    device = torch.device("cuda")
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=4,
    )

    checkpoint = os.path.join(exp_dir, "epoch-7.pt")
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    print("convert HLG to device")
    HLG = HLG.to(device)
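    # Remove epsilon (id 0) entries from the ragged aux_labels so the
    # decoded word sequences contain only real word ids.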
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)
    print("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     HLG=HLG,
                     symbols=symbol_table)
    s = ""
    for ref, hyp in results:
        s += f"ref={ref}\n"
        s += f"hyp={hyp}\n"
    logging.info(s)
    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ["sub", "ins", "del", "total"]
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {100.0 * errors["total"] / total_words:.2f} '
        f'[ {errors["total"]} / {total_words}, {errors["ins"]} ins, '
        f'{errors["del"]} del, {errors["sub"]} sub ]'
    )
Example #6
    def __init__(self,
                 L_inv: k2.Fsa,
                 L_disambig: k2.Fsa,
                 G: k2.Fsa,
                 phones: k2.SymbolTable,
                 words: k2.SymbolTable,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          L_inv:
            Its labels are words, while its aux_labels are phones.
          L_disambig:
            L with disambig symbols. Its labels are phones and aux_labels
            are words.
          G:
            The language model.
          phones:
            The phone symbol table.
          words:
            The word symbol table.
          device:
            The target device that all FSAs should be moved to.
          oov:
            Out of vocabulary word.
        '''

        L_inv = L_inv.to(device)
        G = G.to(device)

        # Arc-sort L_inv and G only if they are not already arc-sorted.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        if G.properties & k2.fsa_properties.ARC_SORTED == 0:
            G = k2.arc_sort(G)

        assert L_inv.requires_grad is False
        assert G.requires_grad is False

        assert oov in words

        L = L_inv.invert()
        L = k2.arc_sort(L)

        self.L_inv = L_inv
        self.L = L
        self.phones = phones
        self.words = words
        self.device = device
        self.oov_id = self.words[oov]

        phone_symbols = get_phone_symbols(phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = k2.arc_sort(
            build_ctc_topo(phone_symbols_with_blank).to(device))
        assert ctc_topo.requires_grad is False

        self.ctc_topo = ctc_topo
        self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert())

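        # Compiling HLG is expensive, so cache the result on disk and
        # reload it on subsequent runs.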
        lang_dir = Path('data/lang_nosp')
        if not (lang_dir / 'HLG_uni.pt').exists():
            logging.info("Composing (ctc_topo, L_disambig, G)")
            first_phone_disambig_id = find_first_disambig_symbol(phones)
            first_word_disambig_id = find_first_disambig_symbol(words)
            # decoding_graph is the result of composing (ctc_topo, L_disambig, G)
            decoding_graph = compile_HLG(
                L=L_disambig.to('cpu'),
                G=G.to('cpu'),
                H=ctc_topo.to('cpu'),
                labels_disambig_id_start=first_phone_disambig_id,
                aux_labels_disambig_id_start=first_word_disambig_id)
            torch.save(decoding_graph.as_dict(), lang_dir / 'HLG_uni.pt')
        else:
            logging.info("Loading pre-compiled HLG")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lang_dir / 'HLG_uni.pt'))

        assert hasattr(decoding_graph, 'phones')

        self.decoding_graph = decoding_graph.to(device)
Example #7
def main():
    exp_dir = Path('exp-lstm-adam-ctc-musan')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

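    # Compile the decoding graph from the CTC topology, the lexicon L
    # and the LM G once, and cache it as LG.pt for later runs.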
    if not os.path.exists(lang_dir / 'LG.pt'):
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        print("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        LG = compile_LG(L=L,
                        G=G,
                        ctc_topo=ctc_topo,
                        labels_disambig_id_start=first_phone_disambig_id,
                        aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(LG.as_dict(), lang_dir / 'LG.pt')
    else:
        print("Loading pre-compiled LG")
        d = torch.load(lang_dir / 'LG.pt')
        LG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    print("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    print("About to create test dataset")
    test = K2SpeechRecognitionIterableDataset(cuts_test,
                                              max_frames=100000,
                                              shuffle=False,
                                              concat_cuts=False)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test, batch_size=None, num_workers=1)

    # if not torch.cuda.is_available():
    #     logging.error('No GPU detected!')
    #     sys.exit(-1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to set the device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(num_features=40, num_classes=len(phone_ids_with_blank))
    checkpoint = os.path.join(exp_dir, 'epoch-7.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    print("convert LG to device")
    LG = LG.to(device)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
    LG.requires_grad_(False)
    print("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     LG=LG,
                     symbols=symbol_table)
    s = ''
    for ref, hyp in results:
        s += f'ref={ref}\n'
        s += f'hyp={hyp}\n'
    logging.info(s)
    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {100.0 * errors["total"] / total_words:.2f} '
        f'[ {errors["total"]} / {total_words}, {errors["ins"]} ins, '
        f'{errors["del"]} del, {errors["sub"]} sub ]'
    )