Example #1
from bpemb import BPEmb
# pad_sequences as provided by Keras; adjust the import to match your TF/Keras setup.
from tensorflow.keras.preprocessing.sequence import pad_sequences


def test_encoding():
    text = ["This is Stratford", "<pad>"]

    bpemb_en = BPEmb(lang="en", add_pad_emb=True)

    # BPEmb can automatically add and encode start/end tokens, but the encoder cannot
    # handle the <pad> token directly. Padding has to be applied outside, using the
    # corresponding pad index (the last index, i.e. vocab_size, when add_pad_emb=True).
    print(bpemb_en.encode(text))
    print(bpemb_en.encode_with_eos(text))
    print(bpemb_en.encode_with_bos_eos(text))
    print(bpemb_en.encode_ids(text))
    print(bpemb_en.encode_ids_with_eos(text))
    print(bpemb_en.encode_ids_with_bos_eos(text))
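
A minimal sketch of the padding note in the comments above, assuming add_pad_emb=True so that BPEmb appends an extra <pad> row to its embedding matrix; pad_ids and the length 12 are illustrative, the rest reuses the BPEmb API (and import) already shown.

def pad_ids(ids, max_len, pad_id):
    # Hypothetical helper: right-pad (or truncate) a single id sequence with pad_id.
    return (ids + [pad_id] * max_len)[:max_len]

bpemb_en = BPEmb(lang="en", add_pad_emb=True)
pad_id = bpemb_en.vocab_size                           # pad sits right after the regular vocab
ids = bpemb_en.encode_ids_with_bos_eos("This is Stratford")
print(pad_ids(ids, 12, pad_id))                        # ids framed by BOS/EOS, then pad_id repeated
print(bpemb_en.vectors.shape)                          # (vocab_size + 1, dim): last row is the extra <pad> row
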
# DecodingCompatibleProcessorABC is a project-specific base class assumed to be importable here.
class SubWordProcessor(DecodingCompatibleProcessorABC):
    def __init__(self, bpe_info, padding_info):
        super().__init__()
        self._bpe_info = bpe_info
        self._padding_info = padding_info

        self._shared_bpe = None
        self._encoder_bpe = None
        self._decoder_bpe = None
        if "shared_bpe" in self._bpe_info:
            self._shared_bpe = BPEmb(**self._bpe_info["shared_bpe"])
            self._encoder_bpe = self._shared_bpe
            self._decoder_bpe = self._shared_bpe
        else:
            self._encoder_bpe = BPEmb(**self._bpe_info["encoder_bpe"])
            self._decoder_bpe = BPEmb(**self._bpe_info["decoder_bpe"])

    def process(self, data, **kwargs):
        # data: (encoder_input, decoder_input), each a list of strings (possibly a list of lists).
        # Text-level preprocessing is assumed to have been done beforehand.
        # Start/end tokens are added to both encoder_input and decoder_input.
        # decoder_input and target must have the same length after preprocessing; since target
        # only gets the end token, it ends up with one more pad element.

        encoder_input = self._encoder_bpe.encode_ids_with_bos_eos(data[0])
        decoder_input = self._decoder_bpe.encode_ids_with_bos_eos(data[1])
        target = self._decoder_bpe.encode_ids_with_eos(data[1])

        # The BPE vocab_size does not account for the pad token, so the embedding weight matrix
        # has vocab_size + 1 rows. Since indices start from 0, the pad index is vocab_size.
        # Note that if the BPE model is shared, the pad token has the same index for both the
        # encoder and the decoder.
        padded_enc_input = pad_sequences(
            encoder_input,
            maxlen=self._padding_info["enc_max_seq_len"],
            value=self._encoder_bpe.vocab_size,
            padding="post")
        padded_dec_input = pad_sequences(
            decoder_input,
            maxlen=self._padding_info["dec_max_seq_len"],
            value=self._decoder_bpe.vocab_size,
            padding="post")
        padded_target = pad_sequences(
            target,
            maxlen=self._padding_info["dec_max_seq_len"],
            value=self._decoder_bpe.vocab_size,
            padding="post")

        return [padded_enc_input, padded_dec_input], padded_target
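    # Illustration (not part of the original code): the matching embedding layer needs
    # vocab_size + 1 rows so that the pad index stays in range. Assuming add_pad_emb=True
    # was passed in bpe_info, BPEmb's .vectors already includes that extra pad row,
    # e.g. with tf.keras:
    #
    #   embedding = tf.keras.layers.Embedding(
    #       input_dim=self._encoder_bpe.vocab_size + 1,    # regular vocab + <pad>
    #       output_dim=self._encoder_bpe.dim,
    #       embeddings_initializer=tf.keras.initializers.Constant(
    #           self._encoder_bpe.vectors))                # shape (vocab_size + 1, dim)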

    def encode(self, data, usage="encoder", **kwargs):
        # data is a list of strings (possibly a list of lists)
        cur_bpe = self._encoder_bpe
        max_seq_len = self._padding_info["enc_max_seq_len"]
        pad_value = self._encoder_bpe.vocab_size
        if usage != "encoder":
            cur_bpe = self._decoder_bpe
            max_seq_len = self._padding_info["dec_max_seq_len"]
            pad_value = self._decoder_bpe.vocab_size

        encoded = cur_bpe.encode_ids_with_bos_eos(data)
        padded = pad_sequences(encoded,
                               maxlen=max_seq_len,
                               value=pad_value,
                               padding="post")

        return padded

    def decode(self, data, usage="decoder", **kwargs):
        # data is a list of ids (possibly a list of lists).
        # Designed for mapping decoder id lists back to sentences, but usable for the encoder as well.
        cur_bpe = self._decoder_bpe
        if usage != "decoder":
            cur_bpe = self._encoder_bpe

        # The BPE model cannot decode the pad id, so padding has to be removed first.
        pad_id = cur_bpe.vocab_size
        if any(isinstance(el, list) for el in data):
            pad_removed = []
            for elem in data:
                pad_removed.append(self.remove_padding(elem, pad_id))
            return cur_bpe.decode_ids(pad_removed)
        else:
            return cur_bpe.decode_ids(self.remove_padding(data, pad_id))
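    # Example (illustrative names): given a padded batch produced by process(),
    #
    #   texts = processor.decode(padded_target.tolist(), usage="decoder")
    #
    # Note the .tolist(): the list-of-lists check above expects Python lists, not a NumPy
    # array. remove_padding() drops the pad id (it lies outside the BPE vocabulary), and
    # decode_ids() then detokenizes each remaining id sequence back to text.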

    def get_tag_ids(self, usage="decoder", **kwargs):
        # Returns the start, end and pad tag ids (decoder by default).
        # TODO: reconsider how to handle the unknown token.
        cur_bpe = self._decoder_bpe
        if usage != "decoder":
            cur_bpe = self._encoder_bpe

        # The pad token takes the last index (vocab_size), as in process() above.
        tag_ids = {
            "start": cur_bpe.BOS,
            "end": cur_bpe.EOS,
            "pad": cur_bpe.vocab_size
        }
        return tag_ids
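    # For reference: BPEmb defines BOS = 1 and EOS = 2, so with e.g. a 10000-token
    # vocabulary this returns {"start": 1, "end": 2, "pad": 10000}.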

    def get_max_seq_length(self, usage="decoder", **kwargs):
        if usage == "decoder":
            return self._padding_info["dec_max_seq_len"]
        else:
            return self._padding_info["enc_max_seq_len"]

    @staticmethod
    def remove_padding(list_of_ids, pad_value):
        return [int(i) for i in list_of_ids if i != pad_value]
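
Below is a hedged end-to-end sketch of how SubWordProcessor might be used, assuming the class and imports above. The vocabulary size, embedding dimension, max lengths and example sentences are illustrative values; only the dict keys ("shared_bpe", "enc_max_seq_len", "dec_max_seq_len") come from the class itself.

bpe_info = {"shared_bpe": {"lang": "en", "vs": 10000, "dim": 100, "add_pad_emb": True}}
padding_info = {"enc_max_seq_len": 32, "dec_max_seq_len": 32}
processor = SubWordProcessor(bpe_info, padding_info)

source = ["the quick brown fox jumps over the lazy dog"]
summary = ["a fox jumps over a dog"]
(enc_in, dec_in), target = processor.process((source, summary))

print(processor.get_tag_ids())            # {'start': 1, 'end': 2, 'pad': 10000}
print(processor.encode(source).shape)     # (1, 32)
print(processor.decode(target.tolist()))  # pad ids stripped, then detokenized back to text
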