Example #1
0
    def __getitem__(self, key):
        """Return the tensors for the batch stored at position ``key``.

        The examples of the batch are transposed into six per-field
        sequences, reordered with ``sort_all`` using the lengths of the first
        field (the sentences) and converted with ``get_long_tensor``.  The
        words of the second field are flattened across sentences and
        reordered on their own, using their individual lengths.

        Returns:
            A 12-tuple ``(words, words_mask, wordchars, wordchars_mask, upos,
            xpos, ufeats, pretrained, orig_idx, word_orig_idx, sentlens,
            word_lens)``.  The two masks mark the positions equal to
            ``PAD_ID``.  ``orig_idx`` and ``word_orig_idx`` are the second
            value returned by ``sort_all`` (presumably the indices needed to
            undo each reordering -- confirm against ``sort_all``).

        Raises:
            TypeError: if ``key`` is not an ``int``.
            IndexError: if ``key`` is outside ``[0, len(self.data))``.
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError

        examples = self.data[key]
        batch_size = len(examples)
        # transpose: one sequence per field instead of one tuple per example
        fields = list(zip(*examples))
        assert len(fields) == 6

        # reorder every field by sentence length for easy RNN operations
        sent_sizes = [len(sent) for sent in fields[0]]
        fields, orig_idx = sort_all(fields, sent_sizes)

        # flatten the words of all sentences, then reorder them by length
        # for easy char-RNN operations
        flat_words = []
        for sent in fields[1]:
            flat_words.extend(sent)
        sorted_words, word_orig_idx = sort_all(
            [flat_words], [len(word) for word in flat_words])
        flat_words = sorted_words[0]
        word_lens = [len(word) for word in flat_words]

        # convert to tensors
        words = get_long_tensor(fields[0], batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(flat_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos, xpos, ufeats, pretrained = [
            get_long_tensor(fields[i], batch_size) for i in (2, 3, 4, 5)
        ]
        sentlens = [len(sent) for sent in fields[0]]

        return (words, words_mask, wordchars, wordchars_mask, upos, xpos,
                ufeats, pretrained, orig_idx, word_orig_idx, sentlens,
                word_lens)
Example #2
0
    def __getitem__(self, key):
        """Return the tensors for the batch stored at position ``key``.

        The examples of the batch are transposed into three per-field
        sequences (words, chars, tags) and reordered with ``sort_all`` by
        sentence length.  Two further views of the chars field are built:

        * the output of ``self.process_chars`` (forward/backward character
          sequences with their offsets), reordered by ``charlens``;
        * the words flattened across sentences, reordered by word length.

        Returns:
            A 13-tuple ``(words, words_mask, wordchars, wordchars_mask, chars,
            tags, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens,
            charlens, charoffsets)``.  ``chars`` concatenates the forward and
            backward tensors along a new leading dimension and
            ``charoffsets`` is ``[forward_offsets, backward_offsets]``.  The
            ``*_orig_idx`` values are the second value returned by
            ``sort_all`` (presumably the indices needed to undo each
            reordering -- confirm against ``sort_all``).

        Raises:
            TypeError: if ``key`` is not an ``int``.
            IndexError: if ``key`` is outside ``[0, len(self.data))``.
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError

        examples = self.data[key]
        batch_size = len(examples)
        fields = list(zip(*examples))
        # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[int]]
        assert len(fields) == 3

        # reorder sentences by length for easy RNN operations
        fields, orig_idx = sort_all(fields,
                                    [len(sent) for sent in fields[0]])
        sentlens = [len(sent) for sent in fields[0]]

        # reorder the char sequences by length for easy char-LM operations
        (fwd_chars, bwd_chars, fwd_offsets, bwd_offsets,
         charlens) = self.process_chars(fields[1])
        sorted_char_fields, char_orig_idx = sort_all(
            [fwd_chars, bwd_chars, fwd_offsets, bwd_offsets], charlens)
        fwd_chars, bwd_chars, fwd_offsets, bwd_offsets = sorted_char_fields
        charlens = [len(seq) for seq in fwd_chars]

        # flatten the words of all sentences, then reorder them by length
        # for easy char-RNN operations
        flat_words = []
        for sent in fields[1]:
            flat_words.extend(sent)
        sorted_words, word_orig_idx = sort_all(
            [flat_words], [len(word) for word in flat_words])
        flat_words = sorted_words[0]
        wordlens = [len(word) for word in flat_words]

        # convert to tensors
        words = get_long_tensor(fields[0], batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(flat_words, len(wordlens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)
        # the char-LM inputs are padded with the id of ' ' in the char vocab
        chars_forward = get_long_tensor(
            fwd_chars, batch_size, pad_id=self.vocab['char'].unit2id(' '))
        chars_backward = get_long_tensor(
            bwd_chars, batch_size, pad_id=self.vocab['char'].unit2id(' '))
        # padded forward and backward char idx
        chars = torch.cat(
            [chars_forward.unsqueeze(0),
             chars_backward.unsqueeze(0)])
        # idx for forward and backward lm to get word representation
        charoffsets = [fwd_offsets, bwd_offsets]
        tags = get_long_tensor(fields[2], batch_size)

        return (words, words_mask, wordchars, wordchars_mask, chars, tags,
                orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens,
                charlens, charoffsets)
Example #3
0
    def chunk_batches(self, data):
        """Group ``data`` into batches bounded by ``self.batch_size``.

        The size of a batch is the sum of ``len(x[0])`` over its items, so
        ``self.batch_size`` is a budget on that sum rather than on the number
        of items.  A single item longer than the budget still gets a batch of
        its own.

        Ordering before chunking:

        * training (``self.eval`` is false): items are sorted by
          ``len(x[0])``, ascending or descending at random;
        * evaluation with ``self.sort_during_eval``: items are reordered with
          ``sort_all`` and its second return value is stored in
          ``self.data_orig_idx``;
        * otherwise the input order is kept.

        Args:
            data: sequence of items whose first element is the sentence.

        Returns:
            A list of non-empty batches, each a list of items from ``data``.
        """
        if not self.eval:
            # sort sentences (roughly) by length for better memory utilization
            data = sorted(data,
                          key=lambda x: len(x[0]),
                          reverse=random.random() > .5)
        elif self.sort_during_eval:
            (data, ), self.data_orig_idx = sort_all([data],
                                                    [len(x[0]) for x in data])

        res = []
        current = []
        currentlen = 0
        for x in data:
            # Only flush when something is pending: a sentence that exceeds
            # the budget by itself must not push an empty batch into res.
            if currentlen > 0 and len(x[0]) + currentlen > self.batch_size:
                res.append(current)
                current = []
                currentlen = 0
            current.append(x)
            currentlen += len(x[0])

        # Test the batch itself rather than its length sum so that trailing
        # zero-length sentences are not silently dropped.
        if current:
            res.append(current)

        return res
Example #4
0
    def build_char_reps(self,
                        batch_chars,
                        batch_offsets,
                        device,
                        forward=True):
        """Compute character-model representations for a batch.

        Picks the forward or backward character model (with its vocab and
        optional projection), reorders ``batch_chars`` and ``batch_offsets``
        by character-sequence length with ``sort_all``, pads the characters
        with the vocab id of ``' '`` and feeds them to the model's
        ``get_representation``.

        Args:
            batch_chars: per-example character id sequences.
            batch_offsets: per-example offsets passed through to
                ``get_representation``.
            device: device the padded character tensor is moved to.
            forward: use the forward model when true, the backward one
                otherwise.

        Returns:
            The representations reshaped to ``[max offsets per example,
            number of examples, feature size]`` and then transposed on the
            first two dimensions.
        """
        if forward:
            model, vocab, projection = (
                self.charmodel_forward,
                self.charmodel_forward_vocab,
                self.charmodel_forward_projection,
            )
        else:
            model, vocab, projection = (
                self.charmodel_backward,
                self.charmodel_backward_vocab,
                self.charmodel_backward_projection,
            )

        # reorder both inputs by the length of each character sequence
        sorted_fields, char_orig_idx = sort_all(
            [batch_chars, batch_offsets], [len(seq) for seq in batch_chars])
        batch_chars, batch_offsets = sorted_fields
        batch_charlens = [len(seq) for seq in batch_chars]

        num_examples = len(batch_chars)
        padded = get_long_tensor(batch_chars,
                                 num_examples,
                                 pad_id=vocab.unit2id(' '))
        chars = padded.to(device=device)

        model_output = model.get_representation(chars, batch_offsets,
                                                batch_charlens, char_orig_idx)
        # NOTE(review): `.data` suggests the model returns a wrapper around a
        # tensor (e.g. a packed sequence) -- confirm against get_representation
        char_reps = model_output.data
        if projection is not None:
            char_reps = projection(char_reps)

        max_offsets = max(len(offsets) for offsets in batch_offsets)
        target_shape = [max_offsets, len(batch_chars), char_reps.shape[-1]]
        char_reps = torch.reshape(char_reps, target_shape)

        return torch.transpose(char_reps, 0, 1)
Example #5
0
    def __getitem__(self, key):
        """Return the tensors for the batch stored at position ``key``.

        The examples of the batch are transposed into five per-field
        sequences and all of them are reordered with ``sort_all`` using the
        lengths of the first field (the source sequences).

        Returns:
            A 7-tuple ``(src, src_mask, tgt_in, tgt_out, pos, edits,
            orig_idx)``.  ``src_mask`` marks the positions of ``src`` equal to
            ``constant.PAD_ID``; ``orig_idx`` is the second value returned by
            ``sort_all`` (presumably the indices needed to undo the
            reordering -- confirm against ``sort_all``).

        Raises:
            TypeError: if ``key`` is not an ``int``.
            IndexError: if ``key`` is outside ``[0, len(self.data))``.
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError

        examples = self.data[key]
        batch_size = len(examples)
        # transpose: one sequence per field instead of one tuple per example
        fields = list(zip(*examples))
        assert len(fields) == 5

        # reorder all fields by source length for easy RNN operations
        fields, orig_idx = sort_all(fields, [len(seq) for seq in fields[0]])

        # convert to tensors
        src = get_long_tensor(fields[0], batch_size)
        src_mask = torch.eq(src, constant.PAD_ID)
        tgt_in = get_long_tensor(fields[1], batch_size)
        tgt_out = get_long_tensor(fields[2], batch_size)
        pos = torch.LongTensor(fields[3])
        edits = torch.LongTensor(fields[4])

        same_width = tgt_in.size(1) == tgt_out.size(1)
        assert same_width, "Target input and output sequence sizes do not match."

        return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx
Example #6
0
def data_to_batches(data, batch_size, eval_mode, sort_during_eval,
                    min_length_to_batch_separately):
    """
    Given a list of lists, where the first element of each sublist
    represents the sentence, group the sentences into batches.

    The size of a batch is the sum of the lengths of its sentences, so
    batch_size is a budget on that sum, not on the number of sentences.
    A sentence longer than the budget still gets a batch of its own.

    During training mode (not eval_mode) the sentences are sorted by
    length, in ascending or descending order chosen at random.  During
    eval mode, the sentences are sorted by length (via sort_all) if
    sort_during_eval is true; otherwise the input order is kept.

    If min_length_to_batch_separately is not None, every sentence
    strictly longer than it is emitted as a batch containing only that
    sentence.

    Refactored from the data structure in case other models could use
    it and for ease of testing.

    Returns (batches, original_order), where original_order is None
    when in train mode or when unsorted and otherwise is the second
    value returned by sort_all for the sorted sentences.
    """
    if not eval_mode:
        # sort sentences (roughly) by length for better memory utilization
        data = sorted(data,
                      key=lambda x: len(x[0]),
                      reverse=random.random() > .5)
        data_orig_idx = None
    elif sort_during_eval:
        (data, ), data_orig_idx = sort_all([data], [len(x[0]) for x in data])
    else:
        data_orig_idx = None

    res = []
    current = []
    currentlen = 0
    for x in data:
        sentlen = len(x[0])
        if (min_length_to_batch_separately is not None
                and sentlen > min_length_to_batch_separately):
            # Flush whatever is pending first, even if it only holds
            # zero-length sentences, so res keeps the order of data.
            if current:
                res.append(current)
                current = []
                currentlen = 0
            res.append([x])
        else:
            # currentlen > 0 prevents an empty batch from being emitted when
            # a single sentence exceeds the budget by itself
            if sentlen + currentlen > batch_size and currentlen > 0:
                res.append(current)
                current = []
                currentlen = 0
            current.append(x)
            currentlen += sentlen

    # Test the batch itself rather than its length sum so that trailing
    # zero-length sentences are not silently dropped.
    if current:
        res.append(current)

    return res, data_orig_idx