def test_duplicate(self):
    """A word listed twice ('a' at 1 and at 3) must be rejected with ValueError."""
    wordlist = StringIO("""
        <unk> 0
        a 1
        b 2
        a 3
        """)
    self.assertRaises(ValueError, vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>")
def test_non_continuous(self):
    """Gaps between indexes are accepted and each word keeps its file index.

    A word absent from the list resolves to the index of '<unk>'.
    """
    expected = {'<unk>': 0, 'a': 3, 'b': 7}
    wordlist = StringIO("""
        <unk> 0
        a 3
        b 7
        """)
    loaded = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")

    self.assertEqual(len(loaded), len(expected))
    for word, index in expected.items():
        self.assertEqual(loaded[word], index)
    # Out-of-vocabulary lookups fall back to the <unk> index.
    self.assertEqual(loaded['nonexistent'], expected['<unk>'])
def test_nonzero_unk(self):
    """The unknown-word symbol need not sit at index 0.

    Out-of-vocabulary lookups return whatever index '<unk>' was given (here 1).
    """
    expected = {'<unk>': 1, 'a': 0, 'b': 2}
    wordlist = StringIO("""
        a 0
        <unk> 1
        b 2
        """)
    loaded = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")

    self.assertEqual(len(loaded), len(expected))
    for word, index in expected.items():
        self.assertEqual(loaded[word], index)
    # Fallback must follow the actual <unk> index, not assume 0.
    self.assertEqual(loaded['nonexistent'], expected['<unk>'])
type=float, default=0.2, help='dropout applied to layers (0 = no dropout)') parser.add_argument('--tied', action='store_true', help='tie the word embedding and softmax weights') parser.add_argument('--seed', type=int, default=1111, help='random seed') parser.add_argument('--save', type=str, required=True, help='path to save the final model') args = parser.parse_args() # Set the random seed manually for reproducibility. torch.manual_seed(args.seed) print("loading vocabulary...") with open(args.wordlist, 'r') as f: vocabulary = vocab.vocab_from_kaldi_wordlist(f, args.unk) print("building model...") model = ffnn_models.BengioModelIvecInput(len(vocabulary), args.emsize, args.hist_len, args.nhid, args.dropout, args.ivec_dim) decoder = FullSoftmaxDecoder(args.nhid, len(vocabulary)) lm = language_model.LanguageModel(model, decoder, vocabulary) torch.save(lm, args.save)
return transcript if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--unk', default="<UNK>") parser.add_argument('--unk-oi', default="<UNK-OI>") parser.add_argument('--oov-start', required=True) parser.add_argument('--oov-end', required=True) parser.add_argument('--interest-constant', type=int, required=True) parser.add_argument('--decoder-wordlist', required=True) args = parser.parse_args() with open(args.decoder_wordlist) as f: decoder_vocabulary = vocab_from_kaldi_wordlist(f, unk_word=args.unk) oov_start_idx = decoder_vocabulary[args.oov_start] oov_end_idx = decoder_vocabulary[args.oov_end] for line_no, line in enumerate(sys.stdin): fields = line.split() key = fields[0] idxes = [int(idx) for idx in fields[1:]] try: transcript = words_from_idx(idxes) except ValueError: sys.stderr.write( "WARNING: there was a problem with input line {} (counting from 0)\n" .format(line_no))
def test_continuity_test_negative(self):
    """is_continuous() is falsy when an index (2) is skipped in the wordlist."""
    wordlist = StringIO("""
        <unk> 0
        a 1
        b 3
        """)
    self.assertFalse(
        vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>").is_continuous()
    )
def test_malformed_line(self):
    """A line that is not a '<word> <index>' pair makes loading fail with ValueError."""
    wordlist = StringIO("""
        a 0
        junk
        <unk> 1
        b 2
        """)
    self.assertRaises(ValueError, vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>")
def test_unk_not_present(self):
    """Loading fails with ValueError when the requested unk word is absent."""
    wordlist = StringIO("""
        a 0
        b 1
        """)
    self.assertRaises(ValueError, vocab.vocab_from_kaldi_wordlist, wordlist, "<unk>")
def test_missing_indexes_beginning(self):
    """missing_indexes() reports [0] when the wordlist starts numbering at 1."""
    expected_missing = [0]
    wordlist = StringIO("""
        <unk> 1
        a 2
        b 3
        """)
    loaded = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
    self.assertEqual(loaded.missing_indexes(), expected_missing)
def test_missing_indexes_middle(self):
    """missing_indexes() reports [2] when index 2 is skipped between 1 and 3."""
    expected_missing = [2]
    wordlist = StringIO("""
        <unk> 0
        a 1
        b 3
        """)
    loaded = vocab.vocab_from_kaldi_wordlist(wordlist, "<unk>")
    self.assertEqual(loaded.missing_indexes(), expected_missing)
def _write_segment_scores(out_f, segment, result):
    """Write one '<segment>-<hyp_no> <cost>' line per entry of result.scores."""
    for hyp_no, cost in result.scores.items():
        out_f.write(f"{segment}-{hyp_no} {cost}\n")


def _choose_next_h0(carry_over, prev_seg, next_seg, result):
    """Pick the initial hidden state for `next_seg` after `prev_seg` was scored.

    Args:
        carry_over: regime name; one of 'always', 'speaker' or 'never'.
        prev_seg: key of the segment that was just scored.
        next_seg: key of the segment about to be scored.
        result: return value of SegmentScorer.process_segment for prev_seg.

    Returns:
        (h0, carried). h0 is the state to pass on, or None for a fresh start.
        carried is True iff a state from `result` was passed on.

    Raises:
        ValueError: for an unknown regime.
    """
    if carry_over == 'always':
        carry = True
    elif carry_over == 'speaker':
        # Carry only while speaker/session (as defined by spk_sess) stays the same.
        carry = spk_sess(next_seg) == spk_sess(prev_seg)
    elif carry_over == 'never':
        carry = False
    else:
        raise ValueError(f'Unsupported carry over regime {carry_over}')

    if carry:
        return select_hidden_state_to_pass(result.hidden_states), True
    return None, False


def main(args):
    """Score n-best hypotheses segment by segment with a neural LM.

    Steps:
        1. Read the lattice vocabulary from args.latt_vocab, using
           args.latt_unk as the unknown word.
        2. Load the model from args.model_from onto CUDA (args.cuda) or CPU,
           and put it in eval mode.
        3. Read args.in_filename. Each non-blank line is an n-best key
           followed by integer word ids. The key is split into
           (segment, transcript id) by brnolm.kaldi_itf.split_nbest_key.
        4. Lines of one segment are expected to be contiguous: a change of
           segment key closes the current segment. A closed segment is
           scored and its costs are written to args.out_filename as
           '<segment>-<hyp_no> <cost>' lines.

    Carry-over of the LM hidden state between segments follows
    args.carry_over:
        'always'  — always carry the state over;
        'speaker' — carry only when spk_sess() matches;
        'never'   — always reset.

    Raises:
        ValueError: for any other args.carry_over value.
    """
    logging.info(args)
    mode = 'chars' if args.character_lm else 'words'

    logging.info("reading lattice vocab...")
    with open(args.latt_vocab, 'r') as f:
        latt_vocab = vocab.vocab_from_kaldi_wordlist(f, unk_word=args.latt_unk)

    logging.info("reading model...")
    device = torch.device('cuda') if args.cuda else torch.device('cpu')
    lm = torch.load(args.model_from, map_location=device)
    lm.eval()

    curr_seg = ''
    segment_utts: typing.Dict[str, typing.Any] = {}
    custom_h0 = None
    nb_carry_overs = 0
    nb_new_hs = 0

    with open(args.in_filename) as in_f, open(args.out_filename, 'w') as out_f:
        scorer = SegmentScorer(lm, out_f)
        for line in in_f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of failing on fields[0].
                continue
            segment, trans_id = brnolm.kaldi_itf.split_nbest_key(fields[0])
            word_ids = [int(wi) for wi in fields[1:]]
            words = translate_latt_to_model(word_ids, latt_vocab, lm.vocab, mode)

            if not curr_seg:
                curr_seg = segment

            if segment != curr_seg:
                result = scorer.process_segment(curr_seg, segment_utts, custom_h0)
                custom_h0, carried = _choose_next_h0(args.carry_over, curr_seg, segment, result)
                if carried:
                    nb_carry_overs += 1
                else:
                    nb_new_hs += 1
                _write_segment_scores(out_f, curr_seg, result)
                curr_seg = segment
                segment_utts = {}

            segment_utts[trans_id] = words

        # Last segment. It must receive the carried-over state like every
        # other segment. Skip it entirely when the input had no hypotheses.
        if segment_utts:
            result = scorer.process_segment(curr_seg, segment_utts, custom_h0)
            _write_segment_scores(out_f, curr_seg, result)

    logging.info(f'Hidden state was carried over {nb_carry_overs} times and reset {nb_new_hs} times')