Example #1
# Expects the hosting framework to provide Adapter, PathField, StringField,
# BoolField, UnsupportedPackage and MachineTranslationPrediction, plus
# SentencePieceBPETokenizer imported from the `tokenizers` package.
class NonAutoregressiveMachineTranslationAdapter(Adapter):
    __provider__ = 'narnmt'

    @classmethod
    def parameters(cls):
        parameters = super().parameters()
        parameters.update({
            'vocabulary_file': PathField(),
            'merges_file': PathField(),
            'output_name': StringField(optional=True, default=None),
            'sos_symbol': StringField(optional=True, default='<s>'),
            'eos_symbol': StringField(optional=True, default='</s>'),
            'pad_symbol': StringField(optional=True, default='<pad>'),
            'remove_extra_symbols': BoolField(optional=True, default=True)
        })
        return parameters

    def configure(self):
        if isinstance(SentencePieceBPETokenizer, UnsupportedPackage):
            SentencePieceBPETokenizer.raise_error(self.__provider__)
        self.tokenizer = SentencePieceBPETokenizer(
            str(self.get_value_from_config('vocabulary_file')),
            str(self.get_value_from_config('merges_file')))
        self.remove_extra_symbols = self.get_value_from_config(
            'remove_extra_symbols')
        self.idx = {}
        for s in ['sos', 'eos', 'pad']:
            self.idx[s] = str(self.get_value_from_config(s + '_symbol'))
        self.output_name = self.get_value_from_config('output_name')
        if self.output_name is None:
            self.output_name = self.output_blob

    def process(self, raw, identifiers, frame_meta):
        raw_outputs = self._extract_predictions(raw, frame_meta)
        translation = raw_outputs[self.output_name]
        results = []
        for identifier, tokens in zip(identifiers, translation):
            sentence = self.tokenizer.decode(tokens)
            if self.remove_extra_symbols:
                for s in self.idx.values():
                    sentence = sentence.replace(s, '')
            results.append(
                MachineTranslationPrediction(identifier,
                                             sentence.lstrip().split(' ')))
        return results
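
For reference, the heart of process() above is just "decode one row of token ids, then strip the special symbols". Below is a minimal standalone sketch of that step outside the adapter framework; the vocab/merges paths and the token ids are placeholders, not outputs of a real model.

# Minimal sketch of the decode-and-strip step; 'vocab.json', 'merges.txt' and the
# ids below are placeholder assumptions.
from tokenizers import SentencePieceBPETokenizer

tokenizer = SentencePieceBPETokenizer('vocab.json', 'merges.txt')
token_ids = [1, 57, 320, 2, 0, 0]            # hypothetical network output row
sentence = tokenizer.decode(token_ids)
for symbol in ('<s>', '</s>', '<pad>'):      # same extra symbols the adapter removes
    sentence = sentence.replace(symbol, '')
print(sentence.lstrip().split(' '))
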
Example #2
class TokenizerWrapper:
    # Unifies several tokenizer back-ends behind one interface. Assumes
    # ByteLevelBPETokenizer, BertWordPieceTokenizer and SentencePieceBPETokenizer
    # come from the `tokenizers` package and that nltk is imported for the plain
    # word tokenizer.
    def __init__(self, tok_type, unk_token, sep_token, cls_token, pad_token,
                 mask_token):
        self.tok_type = tok_type

        if self.tok_type == 'bpe':
            self.tokenizer = ByteLevelBPETokenizer()
        elif self.tok_type == 'wordpiece':
            self.tokenizer = BertWordPieceTokenizer(unk_token=unk_token,
                                                    sep_token=sep_token,
                                                    cls_token=cls_token,
                                                    pad_token=pad_token,
                                                    mask_token=mask_token)
        elif self.tok_type == 'sentencepiece':
            self.tokenizer = SentencePieceBPETokenizer(unk_token=unk_token)

    def train(self, data_file, vocab_size, special_tokens):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            self.tokenizer.train([data_file],
                                 vocab_size=vocab_size,
                                 special_tokens=special_tokens)

    def tokenize(self, text):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            return self.tokenizer.encode(text).tokens
        elif self.tok_type == 'word':
            return nltk.tokenize.word_tokenize(text)
        elif self.tok_type == 'char':
            return [c for c in text]
        else:
            raise ValueError(f'Unknown tokenizer: {self.tok_type}')

    def decode(self, tokens, blank_token):
        if self.tok_type in ['bpe', 'wordpiece', 'sentencepiece']:
            # Map tokens back to ids; unknown tokens fall back to the blank token's id.
            blank_id = self.tokenizer.token_to_id(blank_token)
            ids = [self.tokenizer.token_to_id(t) for t in tokens]
            ids = [i if i is not None else blank_id for i in ids]
            return self.tokenizer.decode(ids, skip_special_tokens=False)
        elif self.tok_type == 'word':
            return ' '.join(tokens)
        elif self.tok_type == 'char':
            return ''.join(tokens)
        else:
            raise ValueError(f'Unknown tokenizer: {self.tok_type}')
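
A hedged usage sketch for TokenizerWrapper: train the word-piece variant on a hypothetical corpus file, tokenize a sentence, then decode it back. The corpus path, vocabulary size and special-token strings are all assumptions.

# Placeholder usage; 'corpus.txt' and the token strings below are assumptions.
wrapper = TokenizerWrapper('wordpiece', unk_token='[UNK]', sep_token='[SEP]',
                           cls_token='[CLS]', pad_token='[PAD]', mask_token='[MASK]')
wrapper.train('corpus.txt', vocab_size=5000,
              special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])
tokens = wrapper.tokenize('Hello tokenizer world')
print(tokens)                                       # word-piece tokens of the sentence
print(wrapper.decode(tokens, blank_token='[PAD]'))
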
Example #3
# Requires: import itertools, import logging as log, import os, and
# from tokenizers import SentencePieceBPETokenizer.
class Tokenizer:
    """ Sentence tokenizer.

    Arguments:
        path (str): path to tokenizer's model folder.
        max_tokens (int): maximum number of tokens per encoded sentence
            (shorter sentences are padded to this length).
    """
    def __init__(self, path, max_tokens):
        self.logger = log.getLogger("Tokenizer")
        self.logger.info("loading tokenizer")
        self.logger.info("path: " + path)
        self.logger.info("max_tokens: " + str(max_tokens))
        self.tokenizer = SentencePieceBPETokenizer(
            os.path.join(path, "vocab.json"),
            os.path.join(path, "merges.txt")
        )
        self.max_tokens = max_tokens
        self.idx = {}
        for s in ['</s>', '<s>', '<pad>']:
            self.idx[s] = self.tokenizer.token_to_id(s)

    def encode(self, sentence):
        """ Encode method for sentence.

        Arguments:
            sentence (str): sentence.

        Returns:
            tokens (list): encoded sentence as token ids, padded to max_tokens.
        """
        tokens = self.tokenizer.encode(sentence).ids
        return self._extend_tokens(tokens)

    def decode(self, tokens, remove_repeats=True):
        """ Decode method for tokens.

        Arguments:
            tokens (np.array): sentence in tokenized format.
            remove_repeats (bool): remove repeated words.

        Returns:
            sentence (str): output sentence.
        """
        sentence = self.tokenizer.decode(tokens)
        for s in self.idx.keys():
            sentence = sentence.replace(s, '')
        if remove_repeats:
            sentence = self._remove_repeats(sentence)
        return sentence.lstrip()

    def _extend_tokens(self, tokens):
        """ Wrap tokens with <s>/</s> markers and pad with <pad> up to max_tokens.

        Arguments:
            tokens (list): sentence in tokenized format.

        Returns:
            tokens (list): extended, padded token ids.
        """
        tokens = [self.idx['<s>']] + tokens + [self.idx['</s>']]
        pad_length = self.max_tokens - len(tokens)
        if pad_length > 0:
            tokens = tokens + [self.idx['<pad>']] * pad_length
        return tokens

    def _remove_repeats(self, sentence):
        """ Remove repeated words.

        Arguments:
            sentence (str): sentence.

        Returns:
            sentence (str): sentence in lowercase without repeated words.
        """
        tokens = sentence.lower().split()
        return " ".join(key for key, _ in itertools.groupby(tokens))
Example #4
# Requires: import pickle, from argparse import Namespace, and
# from tokenizers import SentencePieceBPETokenizer; Vocabulary is the host
# project's base class (providing UNK_TOKEN and the other special tokens).
class BPEVocabulary(Vocabulary):
    """ Represents a SentencePiece vocabulary for c2s.
    """
    def __init__(self, args: Namespace):
        super().__init__()

        self.target_encoder = SentencePieceBPETokenizer(
            args.target_vocab, args.target_merges)
        self.subtoken_encoder = SentencePieceBPETokenizer(
            args.subtoken_vocab, args.subtoken_merges)
        # self.target_encoder.add_special_tokens(
        #     [self.EOS_TOKEN, self.SOS_TOKEN, self.PAD_TOKEN]
        # )
        # self.subtoken_encoder.add_special_tokens([self.EOS_TOKEN, self.PAD_TOKEN])

        with open(args.node_dict, "rb") as f:
            self.node_to_index = pickle.load(f)
            self.index_to_node = {v: k for k, v in self.node_to_index.items()}

    def target_vocab_size(self):
        # +4 presumably reserves room for special tokens added after construction
        return self.target_encoder.get_vocab_size() + 4

    def node_vocab_size(self):
        # +2 presumably reserves room for extra node symbols beyond the pickled dict
        return len(self.node_to_index) + 2

    def terminal_vocab_size(self):
        return self.subtoken_encoder.get_vocab_size() + 4

    def add_special_target_token(self, token: str):
        self.target_encoder.add_special_tokens([token])

    def add_special_terminal_token(self, token: str):
        self.subtoken_encoder.add_special_tokens([token])

    def encode_node(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.node_to_index.get(token_or_tokens,
                                          self.node_to_index[self.UNK_TOKEN])
        else:
            return list(map(self.encode_node, token_or_tokens))

    def decode_node(self, index_or_indices):
        if isinstance(index_or_indices, int):
            return self.index_to_node[index_or_indices]
        else:
            return list(map(self.decode_node, index_or_indices))

    def encode_target(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.target_encoder.token_to_id(token_or_tokens)
        else:
            return self.target_encoder.encode(" ".join(token_or_tokens)).ids

    def decode_target(self, index_or_indices):
        if isinstance(index_or_indices, int):
            return self.target_encoder.id_to_token(index_or_indices)
        else:
            return self.target_encoder.decode(index_or_indices)

    def encode_terminal(self, token_or_tokens):
        if isinstance(token_or_tokens, str):
            return self.subtoken_encoder.token_to_id(token_or_tokens)
        else:
            return self.subtoken_encoder.encode(" ".join(token_or_tokens)).ids

    def decode_terminal(self, index_or_indices):
        if isinstance(index_or_indices, int):
            return self.subtoken_encoder.id_to_token(index_or_indices)
        else:
            return self.subtoken_encoder.decode(index_or_indices)
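
A hedged construction sketch for BPEVocabulary: every path in the Namespace below is a placeholder for artifacts the surrounding training pipeline would produce (two trained BPE vocab/merges pairs and a pickled node-type dictionary).

# Placeholder construction; all paths are assumptions about the pipeline's artifacts.
from argparse import Namespace

args = Namespace(
    target_vocab='target/vocab.json', target_merges='target/merges.txt',
    subtoken_vocab='subtoken/vocab.json', subtoken_merges='subtoken/merges.txt',
    node_dict='node_to_index.pkl',
)
vocab = BPEVocabulary(args)
ids = vocab.encode_target(['get', 'file', 'name'])   # BPE ids for a target name
print(vocab.decode_target(ids))                      # back to a space-joined string
print(vocab.encode_node('MethodDeclaration'))        # node index, or the UNK index if unseen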