コード例 #1
0
 def test_nonzero_unk(self):
     """The unknown-word symbol may sit at a non-zero index.

     Words absent from the wordlist must resolve to the index of ``<unk>``.
     """
     wordlist = StringIO(""" a 0
                                 <unk> 1
                                 b 2 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")

     self.assertEqual(len(parsed), 3)

     # Checked in this order: the unk symbol, regular words, then an OOV word.
     expected_indexes = {'<unk>': 1, 'a': 0, 'b': 2, 'nonexistent': 1}
     for word, index in expected_indexes.items():
         self.assertEqual(parsed[word], index)
コード例 #2
0
 def test_non_continuous(self):
     """Indexes with gaps are kept exactly as given in the wordlist."""
     wordlist = StringIO(""" <unk> 0
                                 a 3
                                 b 7 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")

     # Only three entries exist even though the largest index is 7.
     self.assertEqual(len(parsed), 3)

     expectations = [('<unk>', 0), ('a', 3), ('b', 7), ('nonexistent', 0)]
     for word, index in expectations:
         self.assertEqual(parsed[word], index)
コード例 #3
0
                        type=float,
                        default=0.2,
                        help='dropout applied to layers (0 = no dropout)')
    parser.add_argument('--tied',
                        action='store_true',
                        help='tie the word embedding and softmax weights')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--save',
                        type=str,
                        required=True,
                        help='path to save the final model')
    args = parser.parse_args()

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    print("loading vocabulary...")
    with open(args.wordlist, 'r') as f:
        vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk)

    print("building model...")

    model = ffnn_models.BengioModelIvecInput(len(vocabulary), args.emsize,
                                             args.hist_len, args.nhid,
                                             args.dropout, args.ivec_dim)

    decoder = FullSoftmaxDecoder(args.nhid, len(vocabulary))

    lm = language_model.LanguageModel(model, decoder, vocabulary)
    torch.save(lm, args.save)
コード例 #4
0
 def test_continuity_test_negative(self):
     """A wordlist that skips index 2 must not be reported as continuous."""
     wordlist = StringIO(""" <unk> 0
                                 a 1
                                 b 3 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
     continuous = parsed.is_continuous()
     self.assertFalse(continuous)
コード例 #5
0
 def test_malformed_line(self):
     """A line carrying an extra field is rejected with ValueError."""
     wordlist = StringIO(""" a 0 junk
                                 <unk> 1
                                 b 2 """)
     self.assertRaises(
         ValueError,
         vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>",
     )
コード例 #6
0
 def test_unk_not_present(self):
     """Loading fails with ValueError when the unk word is not in the list."""
     wordlist = StringIO(""" a 0
                                 b 1 """)
     self.assertRaises(
         ValueError,
         vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>",
     )
コード例 #7
0
 def test_missing_indexes_beginning(self):
     """Index 0 is reported as missing when the wordlist starts at 1."""
     wordlist = StringIO(""" <unk> 1
                                 a 2
                                 b 3 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
     gaps = parsed.missing_indexes()
     self.assertEqual(gaps, [0])
コード例 #8
0
 def test_missing_indexes_middle(self):
     """A gap between used indexes (here 2) is reported as missing."""
     wordlist = StringIO(""" <unk> 0
                                 a 1
                                 b 3 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
     gaps = parsed.missing_indexes()
     self.assertEqual(gaps, [2])
コード例 #9
0
    return transcript


# Script entry point: parse the command line, load the decoder wordlist and
# process index sequences read from standard input, one line at a time.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--unk', default="<UNK>")
    parser.add_argument('--unk-oi', default="<UNK-OI>")
    parser.add_argument('--oov-start', required=True)
    parser.add_argument('--oov-end', required=True)
    parser.add_argument('--interest-constant', type=int, required=True)
    parser.add_argument('--decoder-wordlist', required=True)
    args = parser.parse_args()

    # The wordlist is in Kaldi format; args.unk is passed as its unknown word.
    with open(args.decoder_wordlist) as f:
        decoder_vocabulary = vocab_from_kaldi_wordlist(f, unk_word=args.unk)

    # Indexes of the OOV start/end marker words in the decoder vocabulary.
    # NOTE(review): these are not used in the code visible below; presumably
    # helpers such as words_from_idx read them as module globals -- confirm.
    oov_start_idx = decoder_vocabulary[args.oov_start]
    oov_end_idx = decoder_vocabulary[args.oov_end]

    # Each input line: a key followed by whitespace-separated integer indexes.
    for line_no, line in enumerate(sys.stdin):
        fields = line.split()
        # NOTE(review): a blank line yields an empty list, so fields[0] raises
        # IndexError here.
        key = fields[0]
        # This conversion is outside the try below, so a non-integer field
        # raises a ValueError that is NOT turned into a warning.
        idxes = [int(idx) for idx in fields[1:]]

        try:
            transcript = words_from_idx(idxes)
        except ValueError:
            # Report the offending line on stderr and move on to the next one.
            sys.stderr.write("WARNING: there was a problem with input line {} (counting from 0)\n".format(line_no))
            continue
コード例 #10
0
ファイル: rescore-kaldi-latt.py プロジェクト: ibenes/pyth-lm
def main():
    """Score n-best hypotheses with a stored language model.

    Reads ``in_filename`` line by line; each line holds an n-best key followed
    by integer word ids from the lattice vocabulary. The ids are translated to
    the model's vocabulary and hypotheses are grouped by segment. Whenever the
    segment changes (and once more at the end of input) all hypotheses of the
    finished segment are scored together, and one line
    ``<segment>-<trans_id> <negated log-probability>`` is written per
    hypothesis to ``out_filename``.

    Blank input lines are skipped, and nothing is scored when the input
    contains no hypotheses at all.
    """
    parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model')
    parser.add_argument('--latt-vocab', type=str, required=True,
                        help='word -> int map; Kaldi style "words.txt"')
    parser.add_argument('--latt-unk', type=str, default='<unk>',
                        help='unk symbol used in the lattice')
    parser.add_argument('--cuda', action='store_true',
                        help='use CUDA')
    parser.add_argument('--character-lm', action='store_true',
                        help='Process strings by characters')
    parser.add_argument('--model-from', type=str, required=True,
                        help='where to load the model from')
    parser.add_argument('in_filename', help='second output of nbest-to-linear, textual')
    parser.add_argument('out_filename', help='where to put the LM scores')
    args = parser.parse_args()

    print(args)

    mode = 'chars' if args.character_lm else 'words'

    print("reading lattice vocab...")
    with open(args.latt_vocab, 'r') as f:
        latt_vocab = vocab.vocab_from_kaldi_wordlist(f, unk_word=args.latt_unk)

    print("reading model...")
    # SECURITY: torch.load unpickles a complete Python object here, which can
    # execute arbitrary code -- only load model files from trusted sources.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects such pickled objects; confirm against the torch version
    # in use before changing this call.
    lm = torch.load(args.model_from, map_location='cpu')
    if args.cuda:
        lm.model.cuda()
    lm.model.eval()

    def write_segment_scores(out_f, segment, utts):
        """Score every hypothesis collected for `segment` and write the results."""
        if not utts:
            # Nothing was collected (e.g. empty input file): nothing to score.
            return

        X, rev_map = dict_to_list(utts)  # reform the word sequences
        y = seqs_logprob(X, lm)  # score

        for i, log_p in enumerate(y):
            out_f.write(f"{segment}-{rev_map[i]} {-log_p.item()}\n")

    print("scoring...")
    curr_seg = ''
    segment_utts: typing.Dict[str, typing.Any] = {}

    with open(args.in_filename) as in_f, open(args.out_filename, 'w') as out_f:
        for line in in_f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of failing on fields[0].
                continue

            segment, trans_id = balls.kaldi_itf.split_nbest_key(fields[0])

            word_ids = [int(wi) for wi in fields[1:]]
            ids = translate_latt_to_model(word_ids, latt_vocab, lm.vocab, mode)

            if not curr_seg:
                curr_seg = segment

            if segment != curr_seg:
                # A new segment starts: flush the finished one first.
                write_segment_scores(out_f, curr_seg, segment_utts)
                curr_seg = segment
                segment_utts = {}

            segment_utts[trans_id] = ids

        # The last segment has no following line to trigger the flush above.
        write_segment_scores(out_f, curr_seg, segment_utts)