def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 5  # src, tgt_in, tgt_out, pos, edits

    # sort all fields by lens for easy RNN operations
    lens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, lens)

    # convert to tensors
    src = batch[0]
    src = get_long_tensor(src, batch_size)
    src_mask = torch.eq(src, constant.PAD_ID)
    tgt_in = get_long_tensor(batch[1], batch_size)
    tgt_out = get_long_tensor(batch[2], batch_size)
    pos = torch.LongTensor(batch[3])
    edits = torch.LongTensor(batch[4])
    assert tgt_in.size(1) == tgt_out.size(1), \
        "Target input and output sequence sizes do not match."
    return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx
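# The two helpers used throughout these loaders, sort_all and
# get_long_tensor, are assumed to behave as sketched below. This is a
# minimal reference reimplementation, not necessarily the project's exact
# code.
import torch

def sort_all(batch, lens):
    """ Sort every field of `batch` by descending sequence length; return the
    sorted fields plus each sorted element's original position. """
    order = sorted(range(len(lens)), key=lambda i: lens[i], reverse=True)
    sorted_batch = [[field[i] for i in order] for field in batch]
    return sorted_batch, order

def get_long_tensor(tokens_list, batch_size, pad_id=0):
    """ Pad a list of int sequences into a (batch_size, max_len) LongTensor. """
    max_len = max(len(x) for x in tokens_list)
    tensor = torch.LongTensor(batch_size, max_len).fill_(pad_id)
    for i, seq in enumerate(tokens_list):
        tensor[i, :len(seq)] = torch.LongTensor(seq)
    return tensor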
def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 3  # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[int]]

    # sort sentences by lens for easy RNN operations
    sentlens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, sentlens)
    sentlens = [len(x) for x in batch[0]]

    # sort chars by lens for easy char-LM operations
    chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens = self.process_chars(batch[1])
    chars_sorted, char_orig_idx = sort_all(
        [chars_forward, chars_backward, charoffsets_forward, charoffsets_backward], charlens)
    chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = chars_sorted
    charlens = [len(sent) for sent in chars_forward]

    # sort words by lens for easy char-RNN operations
    batch_words = [w for sent in batch[1] for w in sent]
    wordlens = [len(x) for x in batch_words]
    batch_words, word_orig_idx = sort_all([batch_words], wordlens)
    batch_words = batch_words[0]
    wordlens = [len(x) for x in batch_words]

    # convert to tensors
    words = get_long_tensor(batch[0], batch_size)
    words_mask = torch.eq(words, PAD_ID)
    wordchars = get_long_tensor(batch_words, len(wordlens))
    wordchars_mask = torch.eq(wordchars, PAD_ID)
    chars_forward = get_long_tensor(chars_forward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
    chars_backward = get_long_tensor(chars_backward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
    chars = torch.cat([chars_forward.unsqueeze(0), chars_backward.unsqueeze(0)])  # padded forward and backward char idx
    charoffsets = [charoffsets_forward, charoffsets_backward]  # idx for forward and backward lm to get word representation
    tags = get_long_tensor(batch[2], batch_size)
    return words, words_mask, wordchars, wordchars_mask, chars, tags, orig_idx, \
        word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets
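# The *_orig_idx values returned above let the model restore the original
# ordering after running RNNs over the length-sorted batch. A minimal sketch
# of such an `unsort` helper — the name and signature are assumptions, not
# necessarily this code base's exact helper:
def unsort(sorted_list, orig_idx):
    """ Place sorted_list[j] back at its original position orig_idx[j],
    undoing the ordering produced by sort_all. """
    result = [None] * len(sorted_list)
    for item, oi in zip(sorted_list, orig_idx):
        result[oi] = item
    return result

# e.g. unsort(list(tag_predictions), orig_idx) recovers input sentence order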
def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 6

    # sort sentences by lens for easy RNN operations
    lens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, lens)

    # sort words by lens for easy char-RNN operations
    batch_words = [w for sent in batch[1] for w in sent]
    word_lens = [len(x) for x in batch_words]
    batch_words, word_orig_idx = sort_all([batch_words], word_lens)
    batch_words = batch_words[0]
    word_lens = [len(x) for x in batch_words]

    # convert to tensors
    words = batch[0]
    words = get_long_tensor(words, batch_size)
    words_mask = torch.eq(words, PAD_ID)
    wordchars = get_long_tensor(batch_words, len(word_lens))
    wordchars_mask = torch.eq(wordchars, PAD_ID)
    upos = get_long_tensor(batch[2], batch_size)
    xpos = get_long_tensor(batch[3], batch_size)
    ufeats = get_long_tensor(batch[4], batch_size)
    pretrained = get_long_tensor(batch[5], batch_size)
    sentlens = [len(x) for x in batch[0]]
    return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, \
        pretrained, orig_idx, word_orig_idx, sentlens, word_lens
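# Example of consuming one tagger batch — a hypothetical driver loop; the
# `loader` name is illustrative, and the shape comments describe the tensors
# built above rather than an official API:
for i in range(len(loader.data)):
    (words, words_mask, wordchars, wordchars_mask,
     upos, xpos, ufeats, pretrained,
     orig_idx, word_orig_idx, sentlens, word_lens) = loader[i]
    # words:      (batch_size, max_sent_len) LongTensor of word ids
    # words_mask: mask that is True wherever words == PAD_ID
    # wordchars:  (num_words, max_word_len) char ids, words sorted by length
    # upos/xpos/ufeats/pretrained: per-token label and feature id tensors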
def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 7  # words, chars, upos, xpos, ufeats, pretrained, lemma

    # sort sentences by lens for easy RNN operations
    lens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, lens)

    # sort words by lens for easy char-RNN operations
    batch_words = [w for sent in batch[1] for w in sent]
    word_lens = [len(x) for x in batch_words]
    batch_words, word_orig_idx = sort_all([batch_words], word_lens)
    batch_words = batch_words[0]
    word_lens = [len(x) for x in batch_words]

    # convert to tensors
    words = batch[0]
    words = get_long_tensor(words, batch_size)
    words_mask = torch.eq(words, PAD_ID)
    wordchars = get_long_tensor(batch_words, len(word_lens))
    wordchars_mask = torch.eq(wordchars, PAD_ID)
    upos = get_long_tensor(batch[2], batch_size)
    xpos = get_long_tensor(batch[3], batch_size)
    ufeats = get_long_tensor(batch[4], batch_size)
    pretrained = get_long_tensor(batch[5], batch_size)
    sentlens = [len(x) for x in batch[0]]
    lemma = get_long_tensor(batch[6], batch_size)

    # next_word / prev_word are the word ids shifted one position left/right,
    # with the vacated boundary position padded, for the forward and backward
    # LM objectives
    next_word = words.clone()
    next_word[:, :-1] = words[:, 1:]
    next_word[:, -1] = PAD_ID
    prev_word = words.clone()
    prev_word[:, 1:] = words[:, :-1]
    prev_word[:, 0] = PAD_ID

    # if self.args['bptt'] is not None:
    #     next_word = torch.tensor(self.f_nw[key])
    #     prev_word = torch.tensor(self.b_nw[key])
    #     b_batch = self.b_data[key]
    #     b_batch_size = len(b_batch)
    #     b_batch = list(zip(*b_batch))
    #     assert len(b_batch) == 7
    #     # sort sentences by lens for easy RNN operations
    #     b_lens = [len(x) for x in b_batch[0]]
    #     b_batch, b_orig_idx = sort_all(b_batch, b_lens)
    #     # sort words by lens for easy char-RNN operations
    #     b_batch_words = [w for sent in b_batch[1] for w in sent]
    #     b_word_lens = [len(x) for x in b_batch_words]
    #     b_batch_words, b_word_orig_idx = sort_all([b_batch_words], b_word_lens)
    #     b_batch_words = b_batch_words[0]
    #     b_word_lens = [len(x) for x in b_batch_words]
    #     # convert to tensors
    #     b_words = b_batch[0]
    #     b_words = get_long_tensor(b_words, b_batch_size)
    #     b_words_mask = torch.eq(b_words, PAD_ID)
    #     b_wordchars = get_long_tensor(b_batch_words, len(b_word_lens))
    #     b_wordchars_mask = torch.eq(b_wordchars, PAD_ID)
    #     b_upos = get_long_tensor(b_batch[2], b_batch_size)
    #     b_xpos = get_long_tensor(b_batch[3], b_batch_size)
    #     b_ufeats = get_long_tensor(b_batch[4], b_batch_size)
    #     b_pretrained = get_long_tensor(b_batch[5], b_batch_size)
    #     b_sentlens = [len(x) for x in b_batch[0]]
    #     b_lemma = get_long_tensor(b_batch[6], b_batch_size)
    #     return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, \
    #         lemma, next_word, None, orig_idx, word_orig_idx, sentlens, word_lens, \
    #         (b_words, b_words_mask, b_wordchars, b_wordchars_mask, b_upos, b_xpos, b_ufeats, b_pretrained,
    #          b_lemma, prev_word, None, b_orig_idx, b_word_orig_idx, b_sentlens, b_word_lens)

    return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, \
        next_word, prev_word, orig_idx, word_orig_idx, sentlens, word_lens
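# A tiny sanity check of the next_word/prev_word shifting above — a
# standalone snippet, assuming PAD_ID = 0 as in the surrounding code base:
import torch

PAD_ID = 0
words = torch.tensor([[5, 6, 7, PAD_ID]])  # one padded sentence
next_word = words.clone()
next_word[:, :-1] = words[:, 1:]
next_word[:, -1] = PAD_ID                  # -> [[6, 7, 0, 0]]
prev_word = words.clone()
prev_word[:, 1:] = words[:, :-1]
prev_word[:, 0] = PAD_ID                   # -> [[0, 5, 6, 7]]
assert next_word.tolist() == [[6, 7, 0, 0]]
assert prev_word.tolist() == [[0, 5, 6, 7]]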