Ejemplo n.º 1
0
    def __getitem__(self, key):
        """Return the tensorized batch stored at position ``key``.

        ``key`` must be an ``int`` in ``[0, len(self.data))``; a non-int raises
        ``TypeError`` and an out-of-range value raises ``IndexError``.

        The stored batch is transposed into six per-field sequences. Field
        roles are inferred from the returned names: words, per-word char ids,
        upos, xpos, ufeats, pretrained.

        Two reorderings are applied, each via ``sort_all``:
          * sentences, keyed on sentence length; the permutation is returned
            as ``orig_idx``;
          * the flattened words of all sentences, keyed on word length; the
            permutation is returned as ``word_orig_idx``.

        Returns:
            (words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
            pretrained, orig_idx, word_orig_idx, sentlens, word_lens)
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        raw = self.data[key]
        n_sents = len(raw)
        fields = list(zip(*raw))
        assert len(fields) == 6

        # Reorder every field by sentence length so RNN packing is simple.
        fields, orig_idx = sort_all(fields, [len(sent) for sent in fields[0]])

        # Flatten all sentences' words and reorder them by word length for
        # the character-level RNN.
        flat_words = [word for sent in fields[1] for word in sent]
        sorted_words, word_orig_idx = sort_all(
            [flat_words], [len(word) for word in flat_words])
        flat_words = sorted_words[0]
        word_lens = [len(word) for word in flat_words]

        # Tensorize; masks mark PAD_ID positions.
        words = get_long_tensor(fields[0], n_sents)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(flat_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos, xpos, ufeats, pretrained = (
            get_long_tensor(fields[idx], n_sents) for idx in range(2, 6))
        sentlens = [len(sent) for sent in fields[0]]
        return (words, words_mask, wordchars, wordchars_mask, upos, xpos,
                ufeats, pretrained, orig_idx, word_orig_idx, sentlens,
                word_lens)
Ejemplo n.º 2
0
    def __getitem__(self, key):
        """Return the tensorized batch stored at position ``key``.

        ``key`` must be an ``int`` in ``[0, len(self.data))``; a non-int raises
        ``TypeError`` and an out-of-range value raises ``IndexError``.

        Each stored example has three fields:
          * words: List[List[int]]
          * chars: List[List[List[int]]]
          * tags: List[List[int]]

        Three reorderings are applied, each via ``sort_all``:
          * sentences, keyed on sentence length (``orig_idx``);
          * the four char-LM streams produced by ``self.process_chars``,
            keyed on the char lengths it returns (``char_orig_idx``);
          * the flattened words, keyed on word length (``word_orig_idx``).

        The forward/backward char streams are padded with the id of ``' '``
        from ``self.vocab['char']`` and stacked along a new leading dim.

        Returns:
            (words, words_mask, wordchars, wordchars_mask, chars, tags,
            orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens,
            charlens, charoffsets)
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        raw = self.data[key]
        n_sents = len(raw)
        fields = list(zip(*raw))
        # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[int]]
        assert len(fields) == 3

        # Sentence-level reordering by length for RNN packing.
        fields, orig_idx = sort_all(fields, [len(sent) for sent in fields[0]])
        sentlens = [len(sent) for sent in fields[0]]

        # Char-LM streams, reordered together by their char lengths.
        (fwd_chars, bwd_chars, fwd_offsets, bwd_offsets,
         charlens) = self.process_chars(fields[1])
        streams, char_orig_idx = sort_all(
            [fwd_chars, bwd_chars, fwd_offsets, bwd_offsets], charlens)
        fwd_chars, bwd_chars, fwd_offsets, bwd_offsets = streams
        charlens = [len(sent) for sent in fwd_chars]

        # Flattened words, reordered by length for the char-RNN.
        flat_words = [word for sent in fields[1] for word in sent]
        sorted_words, word_orig_idx = sort_all(
            [flat_words], [len(word) for word in flat_words])
        flat_words = sorted_words[0]
        wordlens = [len(word) for word in flat_words]

        # Tensorize; masks mark PAD_ID positions.
        words = get_long_tensor(fields[0], n_sents)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(flat_words, len(wordlens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        # Char streams are padded with the id of ' ' rather than PAD_ID.
        fwd_tensor, bwd_tensor = (
            get_long_tensor(stream, n_sents,
                            pad_id=self.vocab['char'].unit2id(' '))
            for stream in (fwd_chars, bwd_chars))
        # Leading dim 0 = forward stream, 1 = backward stream.
        chars = torch.stack([fwd_tensor, bwd_tensor])
        # Offsets the forward/backward LMs use to pick word representations.
        charoffsets = [fwd_offsets, bwd_offsets]
        tags = get_long_tensor(fields[2], n_sents)

        return (words, words_mask, wordchars, wordchars_mask, chars, tags,
                orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens,
                charlens, charoffsets)
Ejemplo n.º 3
0
    def __getitem__(self, key):
        """Return the tensorized batch stored at position ``key``.

        ``key`` must be an ``int`` in ``[0, len(self.data))``; a non-int raises
        ``TypeError`` and an out-of-range value raises ``IndexError``.

        The stored batch is transposed into five fields. Roles are inferred
        from the returned names: src, tgt_in, tgt_out, pos, edits. All five
        fields are reordered together by source length via ``sort_all``.

        Tensor handling:
          * src, tgt_in and tgt_out go through ``get_long_tensor``;
          * pos and edits are wrapped directly with ``torch.LongTensor``
            (presumably one value per example — confirm against the caller).

        Returns:
            (src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx)
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError
        raw = self.data[key]
        n_items = len(raw)
        fields = list(zip(*raw))
        assert len(fields) == 5

        # Reorder every field by source length (fields[0]).
        fields, orig_idx = sort_all(fields, [len(seq) for seq in fields[0]])

        src = get_long_tensor(fields[0], n_items)
        src_mask = torch.eq(src, constant.PAD_ID)
        tgt_in, tgt_out = (get_long_tensor(fields[idx], n_items) for idx in (1, 2))
        pos, edits = (torch.LongTensor(fields[idx]) for idx in (3, 4))
        assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match."
        return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx
Ejemplo n.º 4
0
    def chunk_batches(self, data):
        """Group ``data`` into batches whose summed length stays within budget.

        Each element ``x`` is measured by ``len(x[0])``. Elements are added to
        the current batch until the next one would push the running length
        past ``self.batch_size``; then the batch is closed and a new one is
        started. An element longer than ``self.batch_size`` on its own still
        forms a batch (alone or with following elements); it never produces an
        empty batch.

        Ordering applied before chunking:
          * training (``not self.eval``): sorted by length, ascending or
            descending chosen at random on each call;
          * eval with ``self.sort_during_eval``: sorted via ``sort_all``, with
            the permutation stored in ``self.data_orig_idx``;
          * otherwise: input order is kept.

        NOTE(review): a trailing run of zero-length elements that starts a new
        batch is dropped, because the final flush requires ``currentlen > 0``.
        Confirm this is intended.

        Returns:
            list of non-empty batches (lists of elements of ``data``).
        """
        res = []

        if not self.eval:
            # sort sentences (roughly) by length for better memory utilization;
            # the direction is randomized per call
            data = sorted(data,
                          key=lambda x: len(x[0]),
                          reverse=random.random() > .5)
        elif self.sort_during_eval:
            (data, ), self.data_orig_idx = sort_all([data],
                                                    [len(x[0]) for x in data])

        current = []
        currentlen = 0
        for x in data:
            # Only close the current batch if it already holds something;
            # otherwise an over-long first element would emit an empty batch.
            if currentlen > 0 and len(x[0]) + currentlen > self.batch_size:
                res.append(current)
                current = []
                currentlen = 0
            current.append(x)
            currentlen += len(x[0])

        if currentlen > 0:
            res.append(current)

        return res
Ejemplo n.º 5
0
    def __getitem__(self, key):
        """Return the tensorized batch stored at position ``key``.

        ``key`` must be an ``int`` in ``[0, len(self.data))``; a non-int raises
        ``TypeError`` and an out-of-range value raises ``IndexError``.

        The stored batch is transposed into seven fields. Roles are inferred
        from the returned names: words, per-word char ids, upos, xpos, ufeats,
        pretrained, lemma.

        Two reorderings are applied via ``sort_all``:
          * sentences, keyed on sentence length (``orig_idx``);
          * the flattened words, keyed on word length (``word_orig_idx``).

        Two extra tensors are derived from ``words``:
          * ``next_word``: word ids shifted one position left along dim 1;
          * ``prev_word``: word ids shifted one position right along dim 1;
        the vacated edge column of each is filled with ``PAD_ID``.

        Returns:
            (words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats,
            pretrained, lemma, next_word, prev_word, orig_idx, word_orig_idx,
            sentlens, word_lens)
        """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        assert len(batch) == 7

        # sort sentences by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # sort words by lens for easy char-RNN operations
        batch_words = [w for sent in batch[1] for w in sent]
        word_lens = [len(x) for x in batch_words]
        batch_words, word_orig_idx = sort_all([batch_words], word_lens)
        batch_words = batch_words[0]
        word_lens = [len(x) for x in batch_words]

        # convert to tensors; masks mark PAD_ID positions
        words = get_long_tensor(batch[0], batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(batch_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos = get_long_tensor(batch[2], batch_size)
        xpos = get_long_tensor(batch[3], batch_size)
        ufeats = get_long_tensor(batch[4], batch_size)
        pretrained = get_long_tensor(batch[5], batch_size)
        sentlens = [len(x) for x in batch[0]]
        lemma = get_long_tensor(batch[6], batch_size)

        # Shift word ids along dim 1 (assumed to be the token axis of a
        # (batch, seq_len) tensor from get_long_tensor — confirm) to get each
        # position's following / preceding word.
        next_word = words.clone()
        next_word[:, :-1] = words[:, 1:]
        next_word[:, -1] = PAD_ID

        prev_word = words.clone()
        prev_word[:, 1:] = words[:, :-1]
        prev_word[:, 0] = PAD_ID

        return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, \
            next_word, prev_word, orig_idx, word_orig_idx, sentlens, word_lens