def main():
    exp_dir = Path('exp-lstm-adam-mmi-mbr-musan')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(
        num_features=40,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)
    model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

    checkpoint = os.path.join(exp_dir, 'epoch-9.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    assert P.requires_grad is False
    P.scores = model.P_scores.cpu()
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='model_P_scores.txt')

    P.set_scores_stochastic_(model.P_scores)
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='P_scores.txt')

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    logging.debug("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=100000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    #  if not torch.cuda.is_available():
    #      logging.error('No GPU detected!')
    #      sys.exit(-1)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)
    logging.debug("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     HLG=HLG,
                     symbols=symbol_table)
    s = ''
    for ref, hyp in results:
        s += f'ref={ref}\n'
        s += f'hyp={hyp}\n'
    logging.info(s)
    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
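
The WER block above folds per-utterance counts from `edit_distance` into one corpus-level figure. A minimal self-contained sketch of that aggregation, with hand-made dicts standing in for `edit_distance` outputs (the code above implies they carry 'ins', 'del', 'sub', and 'total' keys):

# Minimal sketch: aggregate per-utterance error counts into a corpus-level WER.
dists = [
    {'ins': 1, 'del': 0, 'sub': 2, 'total': 3},   # utterance 1
    {'ins': 0, 'del': 1, 'sub': 1, 'total': 2},   # utterance 2
]
ref_lengths = [10, 8]  # number of reference words per utterance

errors = {key: sum(d[key] for d in dists) for key in ['sub', 'ins', 'del', 'total']}
total_words = sum(ref_lengths)
print(f'%WER {errors["total"] / total_words:.2%} '
      f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, '
      f'{errors["del"]} del, {errors["sub"]} sub ]')
# -> %WER 27.78% [5 / 18, 1 ins, 1 del, 3 sub ]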
Example 2
    def __init__(self,
                 L_inv: k2.Fsa,
                 L_disambig: k2.Fsa,
                 G: k2.Fsa,
                 phones: k2.SymbolTable,
                 words: k2.SymbolTable,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          L_inv:
            Its labels are words, while its aux_labels are phones.
          L_disambig:
            L with disambig symbols. Its labels are phones and aux_labels
            are words.
          G:
            The language model.
          phones:
            The phone symbol table.
          words:
            The word symbol table.
          device:
            The target device that all FSAs should be moved to.
          oov:
            The out-of-vocabulary word.
        '''

        L_inv = L_inv.to(device)
        G = G.to(device)

        # Arc-sort only when the ARC_SORTED property bit is unset,
        # i.e. when the FSA is not already arc-sorted.
        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        if G.properties & k2.fsa_properties.ARC_SORTED == 0:
            G = k2.arc_sort(G)

        assert L_inv.requires_grad is False
        assert G.requires_grad is False

        assert oov in words

        L = L_inv.invert()
        L = k2.arc_sort(L)

        self.L_inv = L_inv
        self.L = L
        self.phones = phones
        self.words = words
        self.device = device
        self.oov_id = self.words[oov]

        phone_symbols = get_phone_symbols(phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = k2.arc_sort(
            build_ctc_topo(phone_symbols_with_blank).to(device))
        assert ctc_topo.requires_grad is False

        self.ctc_topo = ctc_topo
        self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert())

        lang_dir = Path('data/lang_nosp')
        if not (lang_dir / 'HLG_uni.pt').exists():
            logging.info("Composing (ctc_topo, L_disambig, G)")
            first_phone_disambig_id = find_first_disambig_symbol(phones)
            first_word_disambig_id = find_first_disambig_symbol(words)
            # decoding_graph is the result of composing (ctc_topo, L_disambig, G)
            decoding_graph = compile_HLG(
                L=L_disambig.to('cpu'),
                G=G.to('cpu'),
                H=ctc_topo.to('cpu'),
                labels_disambig_id_start=first_phone_disambig_id,
                aux_labels_disambig_id_start=first_word_disambig_id)
            torch.save(decoding_graph.as_dict(), lang_dir / 'HLG_uni.pt')
        else:
            logging.info("Loading pre-compiled HLG")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lang_dir / 'HLG_uni.pt'))

        assert hasattr(decoding_graph, 'phones')

        self.decoding_graph = decoding_graph.to(device)
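
The constructor above arc-sorts an FSA only when its ARC_SORTED property bit is unset. A small self-contained illustration of that bit test, assuming a recent k2 build; the tiny acceptor string is made up for the demo:

import k2

# Build a tiny acceptor whose arcs leaving state 0 are out of order by label.
s = '''
0 1 3 0.1
0 1 1 0.2
1 2 -1 0.3
2
'''
fsa = k2.Fsa.from_str(s)

# fsa.properties is a bitmask; the ARC_SORTED bit is unset for this FSA.
if fsa.properties & k2.fsa_properties.ARC_SORTED == 0:
    fsa = k2.arc_sort(fsa)

assert fsa.properties & k2.fsa_properties.ARC_SORTED != 0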
Example 3
def main():
    exp_dir = Path("exp-lstm-adam-ctc-musan")
    setup_logger("{}/log/log-decode".format(exp_dir), log_level="debug")

    # load L, G, symbol_table
    lang_dir = Path("data/lang_nosp")
    symbol_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / "phones.txt")
    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    if not os.path.exists(lang_dir / "HLG.pt"):
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / "L_disambig.fst.txt") as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        print("Loading G.fst.txt")
        with open(lang_dir / "G.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(
            L=L,
            G=G,
            H=ctc_topo,
            labels_disambig_id_start=first_phone_disambig_id,
            aux_labels_disambig_id_start=first_word_disambig_id,
        )
        torch.save(HLG.as_dict(), lang_dir / "HLG.pt")
    else:
        print("Loading pre-compiled HLG")
        d = torch.load(lang_dir / "HLG.pt")
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path("exp/data")
    print("About to get test cuts")
    cuts_test = CutSet.from_file(feature_dir / "gigaspeech_cuts_TEST.jsonl.gz")

    print("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=100000)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    #  if not torch.cuda.is_available():
    #      logging.error('No GPU detected!')
    #      sys.exit(-1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device("cuda")
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=4,
    )

    checkpoint = os.path.join(exp_dir, "epoch-7.pt")
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    print("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)
    print("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     HLG=HLG,
                     symbols=symbol_table)
    s = ""
    for ref, hyp in results:
        s += f"ref={ref}\n"
        s += f"hyp={hyp}\n"
    logging.info(s)
    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ["sub", "ins", "del", "total"]
    }
    total_words = sum(len(ref) for ref, _ in results)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
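
The `batch_size=None` dataloader above delegates batching to the sampler, so each item it yields is already a full batch. A sketch of consuming it, assuming the batch layout K2SpeechRecognitionDataset produces in this lhotse version ('inputs' features plus a 'supervisions' dict):

# Sketch: iterate over the K2SpeechRecognitionDataset dataloader built above.
# 'inputs' is a (batch, time, feature) tensor; 'supervisions' holds the
# reference 'text' plus frame offsets for each supervised segment.
for batch in test_dl:
    features = batch['inputs']             # torch.Tensor, shape (B, T, F)
    texts = batch['supervisions']['text']  # list of reference transcripts
    print(features.shape, texts[0])
    break  # inspect just the first batch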
Example 4
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use an n-best list for rescoring
        # when n is less than 1.
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    P = create_bigram_phone_lm(phone_ids)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=80,
            num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    assert P.requires_grad is False
    P.scores = model.P_scores.cpu()
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='model_P_scores.txt')

    P.set_scores_stochastic_(model.P_scores)
    print_transition_probabilities(P,
                                   phone_symbol_table,
                                   phone_ids,
                                   filename='P_scores.txt')

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    if use_lm_rescoring:
        if use_whole_lattice:
            logging.info('Rescoring with the whole lattice')
        else:
            logging.info(f'Rescoring with n-best list, n is {num_paths}')
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        if not os.path.exists(lang_dir / 'G_4_gram.pt'):
            logging.debug('Loading G_4_gram.fst.txt')
            with open(lang_dir / 'G_4_gram.fst.txt') as f:
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION(fangjun): The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.create_fsa_vec([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
        else:
            logging.debug('Loading pre-compiled G_4_gram.pt')
            d = torch.load(lang_dir / 'G_4_gram.pt')
            G = k2.Fsa.from_dict(d).to(device)

        if use_whole_lattice:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)
    else:
        logging.debug('Decoding without LM rescoring')
        G = None

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    librispeech = LibriSpeechAsrDataModule(args)
    test_sets = ['test-clean', 'test-other']
    #  test_sets = ['test-other']
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        logging.info(f'* DECODING: {test_set}')

        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         HLG=HLG,
                         symbols=symbol_table,
                         num_paths=num_paths,
                         G=G,
                         use_whole_lattice=use_whole_lattice,
                         output_beam_size=output_beam_size)

        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = exp_dir / f'errs-{test_set}.txt'
        with open(errs_filename, 'w') as f:
            write_error_stats(f, test_set, results)
        logging.info('Wrote detailed error stats to {}'.format(errs_filename))
Example 5
def main():
    parser = get_parser()
    GigaSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use an n-best list for rescoring
        # when n is less than 1.
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    suffix = ''
    if args.context_window is not None and args.context_window > 0:
        suffix = f'ac{args.context_window}'
    giga_subset = f'giga{args.subset}'
    exp_dir = Path(
        f'exp-{model_type}-mmi-att-sa-vgg-normlayer-{giga_subset}-{suffix}')

    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend,
            is_espnet_structure=args.is_espnet_structure)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=80,
            num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)
        ts_path = exp_dir / f'model_ts_epoch{epoch}_avg{avg}.pt'
        logging.info(f'Storing the TorchScripted model in {ts_path}')
        model.save(ts_path)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    if use_lm_rescoring:
        if use_whole_lattice:
            logging.info('Rescoring with the whole lattice')
        else:
            logging.info(f'Rescoring with n-best list, n is {num_paths}')
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        if not os.path.exists(lang_dir / 'G_4_gram.pt'):
            logging.debug('Loading G_4_gram.fst.txt')
            with open(lang_dir / 'G_4_gram.fst.txt') as f:
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION(fangjun): The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.create_fsa_vec([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
        else:
            logging.debug('Loading pre-compiled G_4_gram.pt')
            d = torch.load(lang_dir / 'G_4_gram.pt')
            G = k2.Fsa.from_dict(d).to(device)

        if use_whole_lattice:
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)
        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        logging.debug('Decoding without LM rescoring')
        G = None
        if num_paths > 1:
            logging.debug(f'Use n-best list decoding, n is {num_paths}')
        else:
            logging.debug('Use 1-best decoding')

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    gigaspeech = GigaSpeechAsrDataModule(args)
    test_sets = ['DEV', 'TEST']
    for test_set, test_dl in zip(
            test_sets,
            [gigaspeech.valid_dataloaders(), gigaspeech.test_dataloaders()]):
        logging.info(f'* DECODING: {test_set}')

        test_set_wers = dict()
        results_dict = decode(dataloader=test_dl,
                              model=model,
                              HLG=HLG,
                              symbols=symbol_table,
                              num_paths=num_paths,
                              G=G,
                              use_whole_lattice=use_whole_lattice,
                              output_beam_size=output_beam_size)

        for key, results in results_dict.items():
            recog_path = exp_dir / f'recogs-{test_set}-{key}.txt'
            store_transcripts(path=recog_path, texts=results)
            logging.info(f'The transcripts are stored in {recog_path}')

            ref_path = exp_dir / f'ref-{test_set}.trn'
            hyp_path = exp_dir / f'hyp-{test_set}.trn'
            store_transcripts_for_sclite(ref_path=ref_path,
                                         hyp_path=hyp_path,
                                         texts=results)
            logging.info(
                f'The sclite-format transcripts are stored in {ref_path} and {hyp_path}'
            )
            cmd = f'python3 GigaSpeech/utils/gigaspeech_scoring.py {ref_path} {hyp_path} {exp_dir / "tmp_sclite"}'
            logging.info(cmd)
            try:
                subprocess.run(cmd, check=True, shell=True)
            except subprocess.CalledProcessError:
                logging.error(
                    'Skipping sclite scoring as it failed to run: '
                    'Is "sclite" registered in your $PATH?')

            # The following prints out WERs, per-word error statistics and aligned
            # ref/hyp pairs.
            errs_filename = exp_dir / f'errs-{test_set}-{key}.txt'
            with open(errs_filename, 'w') as f:
                wer = write_error_stats(f, f'{test_set}-{key}', results)
                test_set_wers[key] = wer

            logging.info(
                'Wrote detailed error stats to {}'.format(errs_filename))

        test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
        errs_info = exp_dir / f'wer-summary-{test_set}.txt'
        with open(errs_info, 'w') as f:
            print('settings\tWER', file=f)
            for key, val in test_set_wers:
                print('{}\t{}'.format(key, val), file=f)

        s = '\nFor {}, WER of different settings are:\n'.format(test_set)
        note = '\tbest for {}'.format(test_set)
        for key, val in test_set_wers:
            s += '{}\t{}{}\n'.format(key, val, note)
            note = ''
        logging.info(s)
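
The examples that pass --avg > 1 average the last few checkpoints before decoding. A minimal sketch of what such averaging could look like, assuming snowfall-style checkpoint files that store the weights under a 'state_dict' key (a hypothetical layout; verify the key your checkpoints use):

import torch

def average_state_dicts(paths):
    # Element-wise mean of the weights stored in several checkpoint files.
    # Integer buffers (e.g. BatchNorm's num_batches_tracked) are averaged
    # as floats here; cast them back if your model needs exact dtypes.
    avg = None
    for path in paths:
        sd = torch.load(path, map_location='cpu')['state_dict']
        if avg is None:
            avg = {k: v.clone().float() for k, v in sd.items()}
        else:
            for k in avg:
                avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(paths)
    return avg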
Example 6
def main():
    args = get_parser().parse_args()

    model_type = args.model_type
    epoch = args.epoch
    max_duration = args.max_duration
    avg = args.avg
    att_rate = args.att_rate

    exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    # load dataset
    feature_dir = Path('exp/data')
    test_sets = ['test-clean', 'test-other']
    for test_set in test_sets:
        logging.info(f'* DECODING: {test_set}')

        logging.debug("About to get test cuts")
        cuts_test = load_manifest(feature_dir / f'cuts_{test_set}.json.gz')
        logging.debug("About to create test dataset")
        from lhotse.dataset.input_strategies import OnTheFlyFeatures
        from lhotse import Fbank, FbankConfig
        test = K2SpeechRecognitionDataset(
            cuts_test,
            input_strategy=OnTheFlyFeatures(Fbank(
                FbankConfig(num_mel_bins=80))))
        sampler = SingleCutSampler(cuts_test, max_duration=max_duration)
        logging.debug("About to create test dataloader")
        test_dl = torch.utils.data.DataLoader(test,
                                              batch_size=None,
                                              sampler=sampler,
                                              num_workers=1)

        logging.debug("About to decode")
        results = decode(dataloader=test_dl,
                         model=model,
                         device=device,
                         HLG=HLG,
                         symbols=symbol_table)

        recog_path = exp_dir / f'recogs-{test_set}.txt'
        store_transcripts(path=recog_path, texts=results)
        logging.info(f'The transcripts are stored in {recog_path}')
        # compute WER
        dists = [edit_distance(r, h) for r, h in results]
        errors = {
            key: sum(dist[key] for dist in dists)
            for key in ['sub', 'ins', 'del', 'total']
        }
        total_words = sum(len(ref) for ref, _ in results)
        # Print Kaldi-like message:
        # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
        logging.info(
            f'[{test_set}] %WER {errors["total"] / total_words:.2%} '
            f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
        )
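
The example above is the only one that extracts features on the fly instead of reading precomputed ones. The two input strategies are interchangeable at dataset-construction time; a sketch, assuming a lhotse version with the same constructor the example uses:

from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures

# Precomputed: reads feature matrices stored alongside the cuts (the default).
test_precomputed = K2SpeechRecognitionDataset(
    cuts_test, input_strategy=PrecomputedFeatures())

# On the fly: computes 80-dim fbank from audio at iteration time; slower per
# batch, but needs no feature-extraction pass beforehand.
test_otf = K2SpeechRecognitionDataset(
    cuts_test,
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))))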
Example 7
def main():
    parser = get_parser()
    args = parser.parse_args()

    model_type = args.model_type
    epoch = args.epoch
    avg = args.avg
    att_rate = args.att_rate
    num_paths = args.num_paths
    use_lm_rescoring = args.use_lm_rescoring
    use_whole_lattice = False
    if use_lm_rescoring and num_paths < 1:
        # It doesn't make sense to use an n-best list for rescoring
        # when n is less than 1.
        use_whole_lattice = True

    output_beam_size = args.output_beam_size

    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    logging.info(f'output_beam_size: {output_beam_size}')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

    phone_ids = get_phone_symbols(phone_symbol_table)

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')

    if att_rate != 0.0:
        num_decoder_layers = 6
    else:
        num_decoder_layers = 0

    if model_type == "transformer":
        model = Transformer(
            num_features=40,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend)
    elif model_type == "conformer":
        model = Conformer(
            num_features=40,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=args.vgg_frontend,
            is_espnet_structure=args.is_espnet_structure)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=40,
            num_classes=len(phone_ids) + 1)  # +1 for the blank symbol
    else:
        raise NotImplementedError("Model of type " + str(model_type) +
                                  " is not implemented")

    if avg == 1:
        checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
        load_checkpoint(checkpoint, model)
    else:
        checkpoints = [
            os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt')
            for avg_epoch in range(epoch - avg, epoch)
        ]
        average_checkpoint(checkpoints, model)

    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(
            phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    logging.debug('Decoding without LM rescoring')
    G = None
    if num_paths > 1:
        logging.debug(f'Use n-best list decoding, n is {num_paths}')
    else:
        logging.debug('Use 1-best decoding')

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    if not hasattr(HLG, 'lm_scores'):
        HLG.lm_scores = HLG.scores.clone()

    # load dataset
    feature_dir = Path('exp/data')
    logging.info("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test.json.gz')
    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    test_sampler = SingleCutSampler(cuts_test, max_frames=12000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          sampler=test_sampler,
                                          batch_size=None,
                                          num_workers=1)

    logging.info("About to decode")

    results = decode(dataloader=test_dl,
                     model=model,
                     HLG=HLG,
                     symbols=symbol_table,
                     num_paths=num_paths,
                     G=G,
                     use_whole_lattice=use_whole_lattice,
                     output_beam_size=output_beam_size)

    s = ''
    results2 = []
    for ref, hyp in results:
        s += f'ref={ref}\n'
        s += f'hyp={hyp}\n'
        results2.append((list(''.join(ref)), list(''.join(hyp))))
    logging.info(s)
    # compute WER
    dists = [edit_distance(r, h) for r, h in results]
    dists2 = [edit_distance(r, h) for r, h in results2]
    errors = {
        key: sum(dist[key] for dist in dists)
        for key in ['sub', 'ins', 'del', 'total']
    }
    errors2 = {
        key: sum(dist[key] for dist in dists2)
        for key in ['sub', 'ins', 'del', 'total']
    }
    total_words = sum(len(ref) for ref, _ in results)
    total_chars = sum(len(ref) for ref, _ in results2)
    # Print Kaldi-like message:
    # %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
    logging.info(
        f'%WER {errors["total"] / total_words:.2%} '
        f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
    )
    logging.info(
        f'%CER {errors2["total"] / total_chars:.2%} '
        f'[{errors2["total"]} / {total_chars}, {errors2["ins"]} ins, {errors2["del"]} del, {errors2["sub"]} sub ]'
    )
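
The example above derives CER from the same `edit_distance` helper by flattening each word list into a character list, dropping the spaces. In isolation:

ref = ['HELLO', 'WORLD']
chars = list(''.join(ref))
print(chars)  # ['H', 'E', 'L', 'L', 'O', 'W', 'O', 'R', 'L', 'D']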
Example 8
def main():
    args = get_parser().parse_args()
    exp_dir = Path('exp-lstm-adam-mmi-bigram-musan-dist')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    phone_ids = lexicon.phone_symbols()

    phone_ids_with_blank = [0] + phone_ids
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    logging.debug("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
    # device = torch.device('cuda', 1)
    device = torch.device('cuda')
    model = TdnnLstm1b(
        num_features=80,
        num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
        subsampling_factor=3)

    checkpoint = os.path.join(exp_dir, f'epoch-{args.epoch}.pt')
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

    if not os.path.exists(lang_dir / 'HLG.pt'):
        logging.debug("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        logging.debug("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(lexicon.phones)
        first_word_disambig_id = find_first_disambig_symbol(lexicon.words)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        logging.debug("Loading pre-compiled LG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)

    # load dataset
    feature_dir = Path('exp/data')
    logging.debug("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

    logging.info("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=40000)
    logging.info("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test,
                                          batch_size=None,
                                          sampler=sampler,
                                          num_workers=1)

    #  if not torch.cuda.is_available():
    #      logging.error('No GPU detected!')
    #      sys.exit(-1)

    logging.debug("convert HLG to device")
    HLG = HLG.to(device)
    HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
    HLG.requires_grad_(False)

    logging.debug("About to decode")
    results = decode(dataloader=test_dl,
                     model=model,
                     device=device,
                     HLG=HLG,
                     symbols=lexicon.words)

    test_set = 'test-clean'
    recog_path = exp_dir / f'recogs-{test_set}.txt'
    store_transcripts(path=recog_path, texts=results)
    logging.info(f'The transcripts are stored in {recog_path}')

    # The following prints out WERs, per-word error statistics and aligned
    # ref/hyp pairs.
    errs_filename = exp_dir / f'errs-{test_set}.txt'
    with open(errs_filename, 'w') as f:
        write_error_stats(f, test_set, results)
    logging.info(f'The error stats are stored in {errs_filename}')
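
Every example above calls `compile_HLG` without defining it. For orientation, a condensed sketch of the usual H∘L∘G recipe using snowfall-era k2 APIs; the real helper also handles self-loops and attribute bookkeeping, so treat this as a reading aid rather than a drop-in:

import torch
import k2

def compile_HLG_sketch(L, G, H, labels_disambig_id_start,
                       aux_labels_disambig_id_start):
    # Compose the lexicon with the grammar, then determinize so that every
    # phone sequence maps to a single path.
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    LG = k2.connect(k2.compose(L, G))
    LG = k2.connect(k2.determinize(LG))
    # The disambiguation symbols (#0, #1, ...) have served their purpose once
    # LG is deterministic; map them to epsilon on both label streams.
    # (aux_labels may be a ragged tensor, whose values need the same masking.)
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    LG = k2.arc_sort(k2.connect(LG))
    # Compose the CTC topology with LG and arc-sort the result for fast
    # intersection with the network outputs at decode time.
    HLG = k2.connect(k2.compose(H, LG))
    return k2.arc_sort(HLG)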