Exemple #1
0
 def test_duplicate(self):
     """Parsing must fail when the same word is listed on two lines."""
     wordlist = StringIO(""" <unk> 0
                                 a 1
                                 b 2
                                 a 3 """)
     self.assertRaises(ValueError,
                       vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>")
Exemple #2
0
 def test_non_continuous(self):
     """Gapped indices are kept as given; unknown words map to the <unk> index."""
     wordlist = StringIO(""" <unk> 0
                                 a 3
                                 b 7 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
     self.assertEqual(len(parsed), 3)
     # Checked in the same order as the individual assertions they replace.
     expected = {'<unk>': 0, 'a': 3, 'b': 7, 'nonexistent': 0}
     for word, index in expected.items():
         self.assertEqual(parsed[word], index)
Exemple #3
0
 def test_nonzero_unk(self):
     """<unk> need not sit at index 0; unknown words fall back to its index."""
     wordlist = StringIO(""" a 0
                                 <unk> 1
                                 b 2 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
     self.assertEqual(len(parsed), 3)
     for word, index in [('<unk>', 1), ('a', 0), ('b', 2), ('nonexistent', 1)]:
         self.assertEqual(parsed[word], index)
                        type=float,
                        default=0.2,
                        help='dropout applied to layers (0 = no dropout)')
    parser.add_argument('--tied',
                        action='store_true',
                        help='tie the word embedding and softmax weights')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--save',
                        type=str,
                        required=True,
                        help='path to save the final model')
    # NOTE(review): args.wordlist, args.unk, args.emsize, args.hist_len,
    # args.nhid and args.ivec_dim are read below; their add_argument calls are
    # presumably above this fragment — confirm.
    # NOTE(review): args.tied is parsed but never read in this fragment, so
    # weight tying does not appear to be applied here — verify.
    args = parser.parse_args()

    # Set the random seed manually for reproducibility.
    # Seeding happens before the model is constructed, presumably so that its
    # parameter initialisation is reproducible across runs.
    torch.manual_seed(args.seed)

    print("loading vocabulary...")
    # Build the word -> index vocabulary from the wordlist file; args.unk names
    # the unknown-word token passed to the parser.
    with open(args.wordlist, 'r') as f:
        vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk)

    print("building model...")

    # Model input/output sizes are derived from the vocabulary size; the other
    # hyper-parameters come straight from the command line.
    model = ffnn_models.BengioModelIvecInput(len(vocabulary), args.emsize,
                                             args.hist_len, args.nhid,
                                             args.dropout, args.ivec_dim)

    # Softmax decoder from an args.nhid-sized representation over the whole
    # vocabulary.
    decoder = FullSoftmaxDecoder(args.nhid, len(vocabulary))

    # Bundle model, decoder and vocabulary into one object and serialise it.
    # torch.save pickles the whole object, so loading it later requires these
    # classes to be importable.
    lm = language_model.LanguageModel(model, decoder, vocabulary)
    torch.save(lm, args.save)
Exemple #5
0
    return transcript


if __name__ == '__main__':
    # Reads lines of "<key> <index> <index> ..." from stdin and converts the
    # integer indices with words_from_idx() (presumably back into words of the
    # decoder vocabulary — confirm against its definition).
    parser = argparse.ArgumentParser()
    parser.add_argument('--unk', default="<UNK>")
    parser.add_argument('--unk-oi', default="<UNK-OI>")
    parser.add_argument('--oov-start', required=True)
    parser.add_argument('--oov-end', required=True)
    parser.add_argument('--interest-constant', type=int, required=True)
    parser.add_argument('--decoder-wordlist', required=True)
    args = parser.parse_args()

    with open(args.decoder_wordlist) as f:
        decoder_vocabulary = vocab_from_kaldi_wordlist(f, unk_word=args.unk)

    # NOTE(review): if lookups of words absent from the vocabulary return the
    # unk index (as the vocabulary tests suggest), a misspelled --oov-start /
    # --oov-end silently resolves to the unk index — consider validating.
    oov_start_idx = decoder_vocabulary[args.oov_start]
    oov_end_idx = decoder_vocabulary[args.oov_end]

    for line_no, line in enumerate(sys.stdin):
        fields = line.split()
        # First field is the line key, the rest are integer indices.
        # A blank input line makes fields empty and raises IndexError here.
        key = fields[0]
        idxes = [int(idx) for idx in fields[1:]]

        try:
            transcript = words_from_idx(idxes)
        except ValueError:
            # Report the offending line (0-based) and continue with the next.
            sys.stderr.write(
                "WARNING: there was a problem with input line {} (counting from 0)\n"
                .format(line_no))
        # NOTE(review): key, transcript, oov_start_idx, oov_end_idx, args.unk_oi
        # and args.interest_constant are unused in the visible code; the rest of
        # the loop appears to be cut off. After a ValueError, `transcript` keeps
        # the previous line's value (or is unbound), so later code must skip it.
Exemple #6
0
 def test_continuity_test_negative(self):
     """A gap in the indices (2 is absent) makes the vocabulary non-continuous."""
     wordlist = StringIO(""" <unk> 0
                                 a 1
                                 b 3 """)
     self.assertFalse(
         vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>").is_continuous())
Exemple #7
0
 def test_malformed_line(self):
     """A line carrying an extra trailing field is rejected."""
     wordlist = StringIO(""" a 0 junk
                                 <unk> 1
                                 b 2 """)
     self.assertRaises(ValueError,
                       vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>")
Exemple #8
0
 def test_unk_not_present(self):
     """Parsing must fail when the requested unknown-word token is absent."""
     wordlist = StringIO(""" a 0
                                 b 1 """)
     self.assertRaises(ValueError,
                       vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>")
Exemple #9
0
 def test_missing_indexes_beginning(self):
     """Index 0 is reported missing when numbering starts at 1."""
     wordlist = StringIO(""" <unk> 1
                                 a 2
                                 b 3 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
     expected_missing = [0]
     self.assertEqual(parsed.missing_indexes(), expected_missing)
Exemple #10
0
 def test_missing_indexes_middle(self):
     """A hole inside the numbering (index 2) is reported as missing."""
     wordlist = StringIO(""" <unk> 0
                                 a 1
                                 b 3 """)
     parsed = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
     expected_missing = [2]
     self.assertEqual(parsed.missing_indexes(), expected_missing)
Exemple #11
0
def _write_segment_scores(out_f, segment, scores):
    """Write one ``<segment>-<hyp_no> <cost>`` line per entry of *scores*.

    Args:
        out_f: Writable text stream.
        segment: Segment name used as the key prefix.
        scores: Mapping from hypothesis number to cost, as returned in
            ``SegmentScorer.process_segment(...).scores``.
    """
    for hyp_no, cost in scores.items():
        out_f.write(f"{segment}-{hyp_no} {cost}\n")


def _should_carry_over(regime, prev_segment, next_segment):
    """Decide whether *prev_segment*'s hidden state should seed *next_segment*.

    Args:
        regime: ``'always'``, ``'speaker'`` (only when ``spk_sess`` of both
            segments is equal) or ``'never'``.
        prev_segment: Segment that has just been scored.
        next_segment: Segment about to be scored.

    Returns:
        True to carry the hidden state over, False to start from scratch.

    Raises:
        ValueError: if *regime* is not one of the supported values.
    """
    if regime == 'always':
        return True
    if regime == 'speaker':
        return spk_sess(next_segment) == spk_sess(prev_segment)
    if regime == 'never':
        return False
    raise ValueError(f'Unsupported carry over regime {regime}')


def main(args):
    """Rescore an n-best list segment by segment and write the LM costs.

    Each non-blank line of ``args.in_filename`` is ``<nbest-key> <word-id> ...``;
    the key is split into (segment, hypothesis id) by
    ``brnolm.kaldi_itf.split_nbest_key`` and the lattice word ids are
    translated into the model's vocabulary. Consecutive lines sharing a segment
    are scored together; lines of one segment are assumed to be contiguous
    (a segment that reappears later is scored again as a new group).

    For every segment, one ``<segment>-<hyp_no> <cost>`` line per hypothesis is
    written to ``args.out_filename``. Depending on ``args.carry_over``
    (``'always'``, ``'speaker'`` or ``'never'``), the hidden state selected
    from one segment is passed as the initial state of the next one —
    including the final segment.

    Args:
        args: Namespace providing ``character_lm``, ``latt_vocab``,
            ``latt_unk``, ``cuda``, ``model_from``, ``in_filename``,
            ``out_filename`` and ``carry_over``.

    Raises:
        ValueError: if ``args.carry_over`` is not a supported regime.
    """
    logging.info(args)

    # Fail fast: previously a bad regime was only noticed at the first segment
    # boundary (after partial output), or never when there was one segment.
    if args.carry_over not in ('always', 'speaker', 'never'):
        raise ValueError(f'Unsupported carry over regime {args.carry_over}')

    mode = 'chars' if args.character_lm else 'words'

    logging.info("reading lattice vocab...")

    with open(args.latt_vocab, 'r') as f:
        latt_vocab = vocab.vocab_from_kaldi_wordlist(f, unk_word=args.latt_unk)

    logging.info("reading model...")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    # torch.load unpickles the file: only load trusted model files.
    lm = torch.load(args.model_from, map_location=device)

    lm.eval()

    curr_seg: typing.Optional[str] = None
    segment_utts: typing.Dict[str, typing.Any] = {}

    custom_h0 = None
    nb_carry_overs = 0
    nb_new_hs = 0

    with open(args.in_filename) as in_f, open(args.out_filename, 'w') as out_f:
        scorer = SegmentScorer(lm, out_f)

        for line in in_f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            segment, trans_id = brnolm.kaldi_itf.split_nbest_key(fields[0])

            word_ids = [int(wi) for wi in fields[1:]]
            words = translate_latt_to_model(word_ids, latt_vocab, lm.vocab, mode)

            if curr_seg is None:
                curr_seg = segment

            if segment != curr_seg:
                result = scorer.process_segment(curr_seg, segment_utts, custom_h0)
                # Decide the initial hidden state of the segment starting now.
                if _should_carry_over(args.carry_over, curr_seg, segment):
                    custom_h0 = select_hidden_state_to_pass(result.hidden_states)
                    nb_carry_overs += 1
                else:
                    custom_h0 = None
                    nb_new_hs += 1
                _write_segment_scores(out_f, curr_seg, result.scores)

                curr_seg = segment
                segment_utts = {}

            segment_utts[trans_id] = words

        # Last segment: it must receive the hidden state chosen for it at the
        # previous boundary (it was dropped before). Skip entirely on empty input.
        if curr_seg is not None:
            result = scorer.process_segment(curr_seg, segment_utts, custom_h0)
            _write_segment_scores(out_f, curr_seg, result.scores)

    logging.info(f'Hidden state was carried over {nb_carry_overs} times and reset {nb_new_hs} times')