예제 #1
0
    def __getitem__(self, key):
        """Return the batch stored at position ``key`` as tensors.

        The stored batch is transposed into five parallel fields, every field
        is reordered by ``sort_all`` keyed on the length of each entry of the
        first field, and the fields are then converted to tensors.

        Args:
            key: Integer batch index in ``[0, len(self.data))``.

        Returns:
            The tuple ``(src, src_mask, tgt_in, tgt_out, pos, edits,
            orig_idx)``. ``src_mask`` flags the positions of ``src`` equal to
            ``constant.PAD_ID``. ``orig_idx`` is the second value returned by
            ``sort_all`` (presumably the indices needed to undo the sort --
            confirm against ``sort_all``).

        Raises:
            TypeError: If ``key`` is not an ``int``.
            IndexError: If ``key`` is negative or not below ``len(self.data)``.
            AssertionError: If the batch does not have exactly five fields, or
                the target input/output tensors differ in size along dim 1.
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError

        examples = self.data[key]
        batch_size = len(examples)
        # Transpose: one tuple per example -> one tuple per field.
        fields = list(zip(*examples))
        assert len(fields) == 5

        # Reorder all fields by the lengths of the first field's entries,
        # which keeps later RNN operations simple.
        src_lens = [len(seq) for seq in fields[0]]
        fields, orig_idx = sort_all(fields, src_lens)

        # Tensor conversion.
        src = get_long_tensor(fields[0], batch_size)
        src_mask = torch.eq(src, constant.PAD_ID)
        tgt_in = get_long_tensor(fields[1], batch_size)
        tgt_out = get_long_tensor(fields[2], batch_size)
        pos = torch.LongTensor(fields[3])
        edits = torch.LongTensor(fields[4])
        assert tgt_in.size(1) == tgt_out.size(1), (
            "Target input and output sequence sizes do not match.")
        return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx
예제 #2
0
    def __getitem__(self, key):
        """Return the batch stored at position ``key`` as tensors and metadata.

        The batch is transposed into three fields (words, chars, tags) and
        goes through three independent length-based sorts via ``sort_all``:
        the sentences, the character sequences produced by
        ``self.process_chars``, and the flattened list of words.

        Args:
            key: Integer batch index in ``[0, len(self.data))``.

        Returns:
            The tuple ``(words, words_mask, wordchars, wordchars_mask, chars,
            tags, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens,
            charlens, charoffsets)``. ``chars`` concatenates the padded
            forward and backward character tensors along a new leading
            dimension; ``charoffsets`` is the two-element list
            ``[charoffsets_forward, charoffsets_backward]``.

        Raises:
            TypeError: If ``key`` is not an ``int``.
            IndexError: If ``key`` is negative or not below ``len(self.data)``.
            AssertionError: If the batch does not have exactly three fields.
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError

        examples = self.data[key]
        batch_size = len(examples)
        fields = list(zip(*examples))
        # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[int]]
        assert len(fields) == 3

        # Sentence-level sort (for easy RNN operations); lengths are then
        # recomputed so they follow the sorted order.
        fields, orig_idx = sort_all(fields, [len(sent) for sent in fields[0]])
        sentlens = [len(sent) for sent in fields[0]]

        # Char-LM inputs: build them from the char field, then sort the four
        # parallel lists by the lengths reported by process_chars.
        (chars_forward, chars_backward, charoffsets_forward,
         charoffsets_backward, charlens) = self.process_chars(fields[1])
        charlm_fields, char_orig_idx = sort_all(
            [
                chars_forward, chars_backward, charoffsets_forward,
                charoffsets_backward
            ],
            charlens,
        )
        (chars_forward, chars_backward, charoffsets_forward,
         charoffsets_backward) = charlm_fields
        charlens = [len(seq) for seq in chars_forward]

        # Char-RNN inputs: flatten all words of the batch and sort by length.
        flat_words = []
        for sent in fields[1]:
            flat_words.extend(sent)
        sorted_words, word_orig_idx = sort_all(
            [flat_words], [len(word) for word in flat_words])
        flat_words = sorted_words[0]
        wordlens = [len(word) for word in flat_words]

        # Tensor conversion.
        words = get_long_tensor(fields[0], batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(flat_words, len(wordlens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        # Character sequences are padded with the id of the space character.
        char_pad_id = self.vocab['char'].unit2id(' ')
        chars_forward = get_long_tensor(
            chars_forward, batch_size, pad_id=char_pad_id)
        chars_backward = get_long_tensor(
            chars_backward, batch_size, pad_id=char_pad_id)
        # padded forward and backward char idx
        chars = torch.cat(
            [chars_forward.unsqueeze(0),
             chars_backward.unsqueeze(0)])
        # idx for forward and backward lm to get word representation
        charoffsets = [charoffsets_forward, charoffsets_backward]
        tags = get_long_tensor(fields[2], batch_size)

        return (words, words_mask, wordchars, wordchars_mask, chars, tags,
                orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens,
                charlens, charoffsets)
예제 #3
0
    def build_char_reps(self,
                        batch_chars,
                        batch_offsets,
                        device,
                        forward=True):
        """Compute character-model representations for a batch.

        Args:
            batch_chars: One character-id sequence per batch entry.
            batch_offsets: One list of offsets per batch entry, passed through
                to ``model.get_representation``.
            device: Device the padded character tensor is moved to.
            forward: Selects the forward (``True``) or backward (``False``)
                character model, vocab and projection stored on ``self``.

        Returns:
            A tensor reshaped to ``[max offsets per entry, number of entries,
            feature size]`` and then transposed so that the first two
            dimensions are swapped.
        """
        if forward:
            model, vocab, projection = (
                self.charmodel_forward,
                self.charmodel_forward_vocab,
                self.charmodel_forward_projection,
            )
        else:
            model, vocab, projection = (
                self.charmodel_backward,
                self.charmodel_backward_vocab,
                self.charmodel_backward_projection,
            )

        # Sort both parallel lists by character-sequence length, then
        # recompute the lengths in sorted order.
        unsorted_lens = [len(seq) for seq in batch_chars]
        (batch_chars, batch_offsets), char_orig_idx = sort_all(
            [batch_chars, batch_offsets], unsorted_lens)
        batch_charlens = [len(seq) for seq in batch_chars]

        # Pad with the id of the space character and move to the device.
        pad_id = vocab.unit2id(' ')
        chars = get_long_tensor(batch_chars, len(batch_chars),
                                pad_id=pad_id).to(device=device)

        # NOTE(review): `.data` suggests get_representation returns a packed
        # sequence -- confirm against the char model.
        char_reps = model.get_representation(chars, batch_offsets,
                                             batch_charlens,
                                             char_orig_idx).data
        if projection is not None:
            char_reps = projection(char_reps)

        max_offsets = max(len(offsets) for offsets in batch_offsets)
        char_reps = torch.reshape(
            char_reps, [max_offsets,
                        len(batch_chars), char_reps.shape[-1]])
        return torch.transpose(char_reps, 0, 1)
예제 #4
0
    def __getitem__(self, key):
        """Return the batch stored at position ``key`` as tensors and metadata.

        The batch is transposed into six fields. All fields are sorted by
        sentence length with ``sort_all``; the words of the second field are
        additionally flattened and sorted by word length.

        Args:
            key: Integer batch index in ``[0, len(self.data))``.

        Returns:
            The tuple ``(words, words_mask, wordchars, wordchars_mask, upos,
            xpos, ufeats, pretrained, orig_idx, word_orig_idx, sentlens,
            word_lens)``. Both masks flag positions equal to ``PAD_ID``.

        Raises:
            TypeError: If ``key`` is not an ``int``.
            IndexError: If ``key`` is negative or not below ``len(self.data)``.
            AssertionError: If the batch does not have exactly six fields.
        """
        if not isinstance(key, int):
            raise TypeError
        if not 0 <= key < len(self.data):
            raise IndexError

        examples = self.data[key]
        batch_size = len(examples)
        # Transpose: one tuple per example -> one tuple per field.
        fields = list(zip(*examples))
        assert len(fields) == 6

        # Sentence-level sort, for easy RNN operations.
        sent_lens = [len(sent) for sent in fields[0]]
        fields, orig_idx = sort_all(fields, sent_lens)

        # Word-level sort, for easy char-RNN operations: flatten every word
        # of the batch, sort by length, then recompute lengths in that order.
        flat_words = []
        for sent in fields[1]:
            flat_words.extend(sent)
        sorted_words, word_orig_idx = sort_all(
            [flat_words], [len(word) for word in flat_words])
        flat_words = sorted_words[0]
        word_lens = [len(word) for word in flat_words]

        # Tensor conversion.
        words = get_long_tensor(fields[0], batch_size)
        words_mask = torch.eq(words, PAD_ID)
        wordchars = get_long_tensor(flat_words, len(word_lens))
        wordchars_mask = torch.eq(wordchars, PAD_ID)

        upos = get_long_tensor(fields[2], batch_size)
        xpos = get_long_tensor(fields[3], batch_size)
        ufeats = get_long_tensor(fields[4], batch_size)
        pretrained = get_long_tensor(fields[5], batch_size)
        # Sentence lengths in the sorted order.
        sentlens = [len(sent) for sent in fields[0]]
        return (words, words_mask, wordchars, wordchars_mask, upos, xpos,
                ufeats, pretrained, orig_idx, word_orig_idx, sentlens,
                word_lens)
예제 #5
0
    def build_char_representation(self, all_word_labels, device, forward):
        """Run a character LM over each word sequence and gather word outputs.

        For every entry of ``all_word_labels`` a single character sequence is
        built: a start marker (``"\\n"``), then each word followed by an end
        marker (``" "``). The position of every end marker is recorded as
        that word's offset. The sequences are sorted by length (longest
        first), padded, run through the selected charlm without gradient
        tracking, and the outputs at the recorded offsets are returned in the
        original input order.

        Args:
            all_word_labels: Iterable of word sequences; each word is an
                iterable of characters.
            device: Device the padded character tensor is moved to.
            forward: If true, use ``self.forward_charlm`` and iterate each
                word sequence in reverse; otherwise use
                ``self.backward_charlm``, reverse the characters of each word
                and reverse the offsets list.

        Returns:
            A list with one entry per input sequence, each being
            ``output[i, offsets]`` for that sequence. An empty input yields
            an empty list.
        """
        CHARLM_START = "\n"
        CHARLM_END = " "

        if forward:
            charlm = self.forward_charlm
            vocab = self.forward_charlm_vocab
        else:
            charlm = self.backward_charlm
            vocab = self.backward_charlm_vocab

        all_data = []
        for idx, word_labels in enumerate(all_word_labels):
            # NOTE(review): the forward pass reverses the word order, which
            # presumably means callers hand the words in reversed order --
            # confirm against the caller.
            if forward:
                word_labels = reversed(word_labels)
            else:
                word_labels = [x[::-1] for x in word_labels]

            chars = [CHARLM_START]
            offsets = []
            for w in word_labels:
                chars.extend(w)
                chars.append(CHARLM_END)
                # Index of the end marker just appended for this word.
                offsets.append(len(chars) - 1)
            if not forward:
                offsets.reverse()
            chars = vocab.map(chars)
            # idx records the original position so the sort can be undone.
            all_data.append((chars, offsets, len(chars), idx))

        # With no sequences there is nothing to unzip or feed to the model;
        # the four-way unpacking below would raise ValueError.
        if not all_data:
            return []

        # Longest sequence first.
        all_data.sort(key=itemgetter(2), reverse=True)
        chars, char_offsets, char_lens, orig_idx = tuple(zip(*all_data))
        chars = get_long_tensor(chars,
                                len(all_data),
                                pad_id=vocab.unit2id(' ')).to(device=device)

        # TODO: surely this should be stuffed in the charlm model itself rather than done here
        with torch.no_grad():
            output, _, _ = charlm.forward(chars, char_lens)
            res = [
                output[i, offsets] for i, offsets in enumerate(char_offsets)
            ]
            # Restore the caller's original ordering.
            res = unsort(res, orig_idx)

        return res