def __init__(self, word_lm: FairseqLanguageModel, subword_dict: TokenDictionary,
             oov_penalty: float = 1e-4, open_vocab: bool = True):
    super().__init__(word_lm.decoder.dictionary)
    self.lm_decoder: FairseqIncrementalDecoder = word_lm.decoder
    assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \
        callable(self.lm_decoder.masked_copy_incremental_state), \
        'The wrapped decoder should implement masked_copy_incremental_state()'

    self.oov_penalty = oov_penalty
    self.open_vocab = open_vocab
    self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue

    # special symbol indices in the word-level dictionary
    word_dict: TokenDictionary = self.lm_decoder.dictionary
    self.word_pad_idx = word_dict.pad()
    self.word_eos_idx = word_dict.eos()
    self.word_unk_idx = word_dict.unk()

    # special symbol indices in the subword-level dictionary
    self.subword_space_idx = subword_dict.space()
    self.subword_pad_idx = subword_dict.pad()
    self.subword_eos_idx = subword_dict.eos()
    self.subword_vocab_size = len(subword_dict)

    # tensorized prefix tree over the subword tokenizations of the word vocabulary
    tokenizer: Callable[[str], List[str]] = \
        lambda x: tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ')
    self.tree = TensorizedPrefixTree.build(word_dict, subword_dict, tokenizer)

    assert self.tree.max_out_degree() <= self.subword_vocab_size
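# Illustrative sketch (hypothetical helper, not part of the class): the
# `tokenizer` lambda above is expected to split a word from the word-level
# dictionary into space-separated subword units, which TensorizedPrefixTree.build
# then indexes into a prefix tree. A toy character-level stand-in for espresso's
# `tokenize` (which additionally handles non_lang_syms) could look like this:
def toy_char_tokenize(word: str) -> str:
    """Split a word into space-separated characters, e.g. 'cat' -> 'c a t'."""
    return ' '.join(word)


assert toy_char_tokenize('cat').split(' ') == ['c', 'a', 't']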
def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0,
             open_vocab=True):
    super().__init__(wordlm.decoder.dictionary)
    assert isinstance(wordlm, FairseqLanguageModel)
    self.wordlm_decoder = wordlm.decoder
    assert hasattr(self.wordlm_decoder, 'masked_copy_incremental_state') and \
        callable(self.wordlm_decoder.masked_copy_incremental_state), \
        'The wrapped decoder should implement masked_copy_incremental_state()'
    assert isinstance(subwordlm, FairseqLanguageModel)
    self.subwordlm_decoder = subwordlm.decoder

    self.subwordlm_weight = subwordlm_weight
    self.log_oov_penalty = math.log(oov_penalty)
    self.open_vocab = open_vocab
    self.logzero = -10.0  # a sufficiently small log-probability standing in for log(0)

    # special symbol indices in the word-level dictionary
    word_dict = self.wordlm_decoder.dictionary
    assert isinstance(word_dict, TokenDictionary)
    self.word_eos_idx = word_dict.eos()
    self.word_unk_idx = word_dict.unk()

    # special symbol indices in the subword-level dictionary
    subword_dict = self.subwordlm_decoder.dictionary
    assert isinstance(subword_dict, TokenDictionary)
    self.subword_space_idx = subword_dict.space()
    self.subword_eos_idx = subword_dict.eos()
    self.subword_vocab_size = len(subword_dict)

    # lexical prefix tree over the subword tokenizations of the word vocabulary
    tokenizer = lambda x: tokenize(
        x, non_lang_syms=subword_dict.non_lang_syms).split(' ')
    self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)
def test_speech_tokenizer(self):
    for i, sent in enumerate(self.text):
        print('test sentence {}:'.format(i))
        print(sent)
        tokens = utils.tokenize(
            sent, space=self.dict.space_word, non_lang_syms=self.non_lang_syms)

        # test :func:`~speech_tools.utils.tokenize` with
        # :func:`~TokenDictionary.encode_line`
        tensor = self.dict.encode_line(tokens, add_if_not_exist=False, append_eos=True)
        reconstructed_tokens = self.dict.string(tensor)
        expected_tokens = ' '.join(
            [token if self.dict.index(token) != self.dict.unk() else
             self.dict.unk_word for token in tokens.split(' ')]
        )
        self.assertEqual(reconstructed_tokens, expected_tokens)

        # test :func:`~speech_tools.utils.tokenize` with
        # :func:`~TokenDictionary.tokens_to_sentence`
        reconstructed_sent = self.dict.tokens_to_sentence(tokens)
        expected_sent = []
        words = sent.split(' ')
        for w in words:
            if w not in self.non_lang_syms:
                new_word = ''.join(
                    [self.dict.unk_word if c in self.oovs else c for c in w])
                expected_sent.append(new_word)
            else:
                expected_sent.append(w)
        expected_sent = ' '.join(expected_sent)
        self.assertEqual(reconstructed_sent, expected_sent)
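# Illustrative, self-contained sketch of the tokenization the test above relies
# on (an assumption about the expected behavior: character-level tokens with a
# space word inserted between words; the actual logic lives in
# speech_tools.utils.tokenize and TokenDictionary):
def toy_tokenize(sentence, space='<space>'):
    """'he is' -> 'h e <space> i s': characters split apart, words joined by a space symbol."""
    words = sentence.strip().split()
    return (' ' + space + ' ').join(' '.join(w) for w in words)


assert toy_tokenize('he is') == 'h e <space> i s'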
def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True):
    super().__init__(wordlm.decoder.dictionary)
    assert isinstance(wordlm, FairseqLanguageModel)
    self.lm_decoder = wordlm.decoder
    assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \
        callable(self.lm_decoder.masked_copy_incremental_state), \
        'The wrapped decoder should implement masked_copy_incremental_state()'

    self.oov_penalty = oov_penalty
    self.open_vocab = open_vocab
    self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue

    # special symbol indices in the word-level dictionary
    word_dict = self.lm_decoder.dictionary
    assert isinstance(word_dict, TokenDictionary)
    self.word_pad_idx = word_dict.pad()
    self.word_eos_idx = word_dict.eos()
    self.word_unk_idx = word_dict.unk()

    # special symbol indices in the subword-level dictionary
    assert isinstance(subword_dict, TokenDictionary)
    self.subword_space_idx = subword_dict.space()
    self.subword_pad_idx = subword_dict.pad()
    self.subword_eos_idx = subword_dict.eos()
    self.subword_vocab_size = len(subword_dict)

    # lexical prefix tree over the subword tokenizations of the word vocabulary
    tokenizer = lambda x: tokenize(
        x, non_lang_syms=subword_dict.non_lang_syms).split(' ')
    self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)

    def max_out_degree(node):
        # maximum number of children over all nodes in the subtree rooted at `node`
        if len(node.children) == 0:
            return 0
        cur_max = len(node.children)
        for child in node.children.values():
            cur_max = max(cur_max, max_out_degree(child))
        return cur_max

    self.max_num_children = max_out_degree(self.lexroot)
    assert self.max_num_children <= self.subword_vocab_size
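# Illustrative sketch (toy stand-in, not espresso's lexical_prefix_tree): the
# lexical prefix tree indexes each word's subword tokenization, and the
# out-degree check above guarantees that no node branches into more children
# than there are subword units. A minimal tree with the same `.children` shape:
class ToyNode:
    def __init__(self):
        self.children = {}


def build_toy_tree(tokenized_words):
    root = ToyNode()
    for units in tokenized_words:
        node = root
        for unit in units:
            node = node.children.setdefault(unit, ToyNode())
    return root


# 'a', 'an' and 'at' share the prefix 'a': the root has out-degree 1,
# while the node for 'a' has out-degree 2.
root = build_toy_tree([['a'], ['a', 'n'], ['a', 't']])
assert len(root.children) == 1 and len(root.children['a'].children) == 2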