def __init__(self,
                 word_lm: FairseqLanguageModel,
                 subword_dict: TokenDictionary,
                 oov_penalty: float = 1e-4,
                 open_vocab: bool = True):
        """Prepare a word-level LM so it can be queried through subword units.

        Args:
            word_lm: language model whose ``decoder`` is wrapped; the decoder's
                dictionary is the word vocabulary, and the decoder must expose
                a callable ``masked_copy_incremental_state``.
            subword_dict: dictionary of the subword units used to spell words.
            oov_penalty: kept on the instance as ``self.oov_penalty``.
            open_vocab: kept on the instance as ``self.open_vocab``.

        The word/subword special-symbol indices are cached, and a
        ``TensorizedPrefixTree`` is built over the word vocabulary. The tree's
        maximum out-degree is asserted not to exceed the subword vocabulary
        size.
        """
        super().__init__(word_lm.decoder.dictionary)

        self.lm_decoder: FairseqIncrementalDecoder = word_lm.decoder
        decoder = self.lm_decoder
        assert callable(getattr(decoder, 'masked_copy_incremental_state', None)), \
            'The wrapped decoder should implement masked_copy_incremental_state()'

        # Scoring settings.
        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.zero = 1e-10  # tiny positive value so that log(0) is never taken

        # Special-symbol indices of the word vocabulary.
        word_dict: TokenDictionary = decoder.dictionary
        self.word_pad_idx = word_dict.pad()
        self.word_eos_idx = word_dict.eos()
        self.word_unk_idx = word_dict.unk()

        # Special-symbol indices and size of the subword vocabulary.
        self.subword_space_idx = subword_dict.space()
        self.subword_pad_idx = subword_dict.pad()
        self.subword_eos_idx = subword_dict.eos()
        self.subword_vocab_size = len(subword_dict)

        def tokenizer(word):
            # Spell a word as a list of subword strings. non_lang_syms is read
            # from subword_dict at call time, not at construction time.
            return tokenize(word, non_lang_syms=subword_dict.non_lang_syms).split(' ')

        self.tree = TensorizedPrefixTree.build(word_dict, subword_dict, tokenizer)

        assert self.tree.max_out_degree() <= self.subword_vocab_size
示例#2
0
    def __init__(self,
                 wordlm,
                 subwordlm,
                 subwordlm_weight=0.8,
                 oov_penalty=1.0,
                 open_vocab=True):
        """Combine a word-level LM and a subword-level LM.

        Args:
            wordlm: ``FairseqLanguageModel`` over words; its decoder must expose
                a callable ``masked_copy_incremental_state``.
            subwordlm: ``FairseqLanguageModel`` over subword units.
            subwordlm_weight: kept on the instance as ``self.subwordlm_weight``.
            oov_penalty: stored in log space as ``self.log_oov_penalty``
                (via ``math.log``, so it must be positive).
            open_vocab: kept on the instance as ``self.open_vocab``.

        The special-symbol indices of both dictionaries are cached, and a
        lexical prefix tree over the word vocabulary is built as
        ``self.lexroot``.
        """
        super().__init__(wordlm.decoder.dictionary)

        assert isinstance(wordlm, FairseqLanguageModel)
        self.wordlm_decoder = wordlm.decoder
        word_decoder = self.wordlm_decoder
        assert callable(getattr(word_decoder, 'masked_copy_incremental_state', None)), \
            'The wrapped decoder should implement masked_copy_incremental_state()'
        assert isinstance(subwordlm, FairseqLanguageModel)
        self.subwordlm_decoder = subwordlm.decoder

        # Scoring settings; the OOV penalty is kept as a log value.
        self.subwordlm_weight = subwordlm_weight
        self.log_oov_penalty = math.log(oov_penalty)
        self.open_vocab = open_vocab
        self.logzero = -10.0

        # Word vocabulary special symbols.
        word_dict = word_decoder.dictionary
        assert isinstance(word_dict, TokenDictionary)
        self.word_eos_idx = word_dict.eos()
        self.word_unk_idx = word_dict.unk()

        # Subword vocabulary special symbols and size.
        subword_dict = self.subwordlm_decoder.dictionary
        assert isinstance(subword_dict, TokenDictionary)
        self.subword_space_idx = subword_dict.space()
        self.subword_eos_idx = subword_dict.eos()
        self.subword_vocab_size = len(subword_dict)

        def tokenizer(word):
            # Spell a word as a list of subword strings; non_lang_syms is read
            # from subword_dict at call time.
            return tokenize(word, non_lang_syms=subword_dict.non_lang_syms).split(' ')

        self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)
示例#3
0
    def test_speech_tokenizer(self):
        """Round-trip every sentence of ``self.text`` through the dictionary.

        For each sentence, the tokens produced by ``utils.tokenize`` must:
          1. survive ``encode_line`` followed by ``string``, except that any
             token the dictionary indexes as ``unk()`` comes back as
             ``unk_word``;
          2. be turned back into the sentence by ``tokens_to_sentence``. In
             words that are not in ``self.non_lang_syms``, each character
             listed in ``self.oovs`` is replaced by ``unk_word``.
        """
        def expected_word(word):
            # Non-language symbols are kept verbatim; elsewhere every OOV
            # character is expected to come back as the unknown-word symbol.
            if word in self.non_lang_syms:
                return word
            return ''.join(
                self.dict.unk_word if char in self.oovs else char
                for char in word
            )

        for idx, sentence in enumerate(self.text):
            print('test sentence {}:'.format(idx))
            print(sentence)
            tokens = utils.tokenize(
                sentence,
                space=self.dict.space_word,
                non_lang_syms=self.non_lang_syms,
            )

            # Check 1: utils.tokenize combined with TokenDictionary.encode_line.
            encoded = self.dict.encode_line(
                tokens, add_if_not_exist=False, append_eos=True)
            decoded = self.dict.string(encoded)
            wanted_tokens = ' '.join(
                self.dict.unk_word if self.dict.index(tok) == self.dict.unk() else tok
                for tok in tokens.split(' ')
            )
            self.assertEqual(decoded, wanted_tokens)

            # Check 2: utils.tokenize combined with
            # TokenDictionary.tokens_to_sentence.
            rebuilt = self.dict.tokens_to_sentence(tokens)
            wanted_sentence = ' '.join(
                expected_word(word) for word in sentence.split(' ')
            )
            self.assertEqual(rebuilt, wanted_sentence)
示例#4
0
    def __init__(self,
                 wordlm,
                 subword_dict,
                 oov_penalty=1e-4,
                 open_vocab=True):
        """Wrap a word-level LM so it can be driven by subword tokens.

        Args:
            wordlm: ``FairseqLanguageModel`` whose decoder supplies the word
                dictionary. The decoder must expose a callable
                ``masked_copy_incremental_state``.
            subword_dict: ``TokenDictionary`` of the subword units used to
                spell words.
            oov_penalty: kept on the instance as ``self.oov_penalty``.
            open_vocab: kept on the instance as ``self.open_vocab``.

        The special-symbol indices of both dictionaries are cached, and a
        lexical prefix tree over the word vocabulary is built as
        ``self.lexroot``. ``self.max_num_children`` is set to the largest
        number of children of any node in that tree, and is asserted not to
        exceed the subword vocabulary size.
        """
        super().__init__(wordlm.decoder.dictionary)

        assert isinstance(wordlm, FairseqLanguageModel)
        self.lm_decoder = wordlm.decoder
        assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \
            callable(self.lm_decoder.masked_copy_incremental_state), \
            'The wrapped decoder should implement masked_copy_incremental_state()'
        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.zero = 1e-10  # a sufficiently small value to avoid the log(0) issue

        word_dict = self.lm_decoder.dictionary
        assert isinstance(word_dict, TokenDictionary)
        self.word_pad_idx = word_dict.pad()
        self.word_eos_idx = word_dict.eos()
        self.word_unk_idx = word_dict.unk()

        assert isinstance(subword_dict, TokenDictionary)
        self.subword_space_idx = subword_dict.space()
        self.subword_pad_idx = subword_dict.pad()
        self.subword_eos_idx = subword_dict.eos()
        self.subword_vocab_size = len(subword_dict)

        def tokenizer(word):
            # Spell a word as a list of subword strings. The subword
            # dictionary's non-language symbols are passed to tokenize().
            return tokenize(word, non_lang_syms=subword_dict.non_lang_syms).split(' ')

        self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer)

        def max_out_degree(root):
            # Largest number of children of any node in the tree rooted at
            # `root` (0 for a lone leaf). An explicit stack is used instead of
            # recursion, so very deep trees cannot hit Python's recursion
            # limit. Each node is bound to its own name rather than shadowing
            # the parameter.
            best = 0
            pending = [root]
            while pending:
                current = pending.pop()
                best = max(best, len(current.children))
                pending.extend(current.children.values())
            return best

        self.max_num_children = max_out_degree(self.lexroot)
        assert self.max_num_children <= self.subword_vocab_size