def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 6

    # sort sentences by lens for easy RNN operations
    lens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, lens)

    # sort words by lens for easy char-RNN operations
    batch_words = [w for sent in batch[1] for w in sent]
    word_lens = [len(x) for x in batch_words]
    batch_words, word_orig_idx = sort_all([batch_words], word_lens)
    batch_words = batch_words[0]
    word_lens = [len(x) for x in batch_words]

    # convert to tensors
    words = get_long_tensor(batch[0], batch_size)
    words_mask = torch.eq(words, PAD_ID)
    wordchars = get_long_tensor(batch_words, len(word_lens))
    wordchars_mask = torch.eq(wordchars, PAD_ID)
    upos = get_long_tensor(batch[2], batch_size)
    xpos = get_long_tensor(batch[3], batch_size)
    ufeats = get_long_tensor(batch[4], batch_size)
    pretrained = get_long_tensor(batch[5], batch_size)
    sentlens = [len(x) for x in batch[0]]
    return words, words_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, orig_idx, word_orig_idx, sentlens, word_lens
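# These loaders lean on two helpers, sort_all and get_long_tensor, which are not
# shown here. Below is a minimal sketch of both, inferred from how they are called
# above; the exact signatures and the pad_id default (assumed to equal PAD_ID = 0)
# are assumptions, not necessarily the library's implementation.
import torch

def sort_all(batch, lens):
    """ Sort every field in `batch` by descending `lens`; also return each
    element's original index so the order can be restored later. """
    unsorted_all = [lens] + [range(len(lens))] + list(batch)
    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
    return sorted_all[2:], sorted_all[1]

def get_long_tensor(tokens_list, batch_size, pad_id=0):
    """ Pad a list of int sequences into a (batch_size, max_len) LongTensor. """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.LongTensor(batch_size, token_len).fill_(pad_id)
    for i, s in enumerate(tokens_list):
        tokens[i, :len(s)] = torch.LongTensor(s)
    return tokens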
def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[int]]
    assert len(batch) == 3

    # sort sentences by lens for easy RNN operations
    sentlens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, sentlens)
    sentlens = [len(x) for x in batch[0]]

    # sort chars by lens for easy char-LM operations
    chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens = self.process_chars(batch[1])
    chars_sorted, char_orig_idx = sort_all([chars_forward, chars_backward, charoffsets_forward, charoffsets_backward], charlens)
    chars_forward, chars_backward, charoffsets_forward, charoffsets_backward = chars_sorted
    charlens = [len(sent) for sent in chars_forward]

    # sort words by lens for easy char-RNN operations
    batch_words = [w for sent in batch[1] for w in sent]
    wordlens = [len(x) for x in batch_words]
    batch_words, word_orig_idx = sort_all([batch_words], wordlens)
    batch_words = batch_words[0]
    wordlens = [len(x) for x in batch_words]

    # convert to tensors
    words = get_long_tensor(batch[0], batch_size)
    words_mask = torch.eq(words, PAD_ID)
    wordchars = get_long_tensor(batch_words, len(wordlens))
    wordchars_mask = torch.eq(wordchars, PAD_ID)
    chars_forward = get_long_tensor(chars_forward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
    chars_backward = get_long_tensor(chars_backward, batch_size, pad_id=self.vocab['char'].unit2id(' '))
    # padded forward and backward char idx
    chars = torch.cat([chars_forward.unsqueeze(0), chars_backward.unsqueeze(0)])
    # idx for forward and backward lm to get word representation
    charoffsets = [charoffsets_forward, charoffsets_backward]
    tags = get_long_tensor(batch[2], batch_size)
    return words, words_mask, wordchars, wordchars_mask, chars, tags, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets
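# process_chars, called above, is not shown. Judging from how its five outputs are
# consumed, it plausibly concatenates each sentence's words into one forward
# character stream plus its reverse, recording the offset where each word ends so
# the char-LM hidden states can be gathered back per word. A sketch under those
# assumptions; the '\n' start marker and ' ' separator ids are guesses:
def process_chars(self, sents):
    start_id = self.vocab['char'].unit2id('\n')  # assumed sentence-start marker
    end_id = self.vocab['char'].unit2id(' ')     # assumed word separator
    chars_forward, chars_backward = [], []
    charoffsets_forward, charoffsets_backward = [], []
    for sent in sents:
        fwd, bwd = [start_id], [start_id]
        fwd_off, bwd_off = [], []
        for word in sent:                  # forward LM reads words left to right
            fwd += word
            fwd_off.append(len(fwd))       # offset just past the word's last char
            fwd.append(end_id)
        for word in reversed(sent):        # backward LM reads reversed chars
            bwd += word[::-1]
            bwd_off.insert(0, len(bwd))
            bwd.append(end_id)
        chars_forward.append(fwd)
        chars_backward.append(bwd)
        charoffsets_forward.append(fwd_off)
        charoffsets_backward.append(bwd_off)
    # forward and backward streams have identical per-sentence lengths
    charlens = [len(sent) for sent in chars_forward]
    return chars_forward, chars_backward, charoffsets_forward, charoffsets_backward, charlens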
def chunk_batches(self, data):
    res = []

    if not self.eval:
        # sort sentences (roughly) by length for better memory utilization
        data = sorted(data, key=lambda x: len(x[0]), reverse=random.random() > .5)
    elif self.sort_during_eval:
        (data, ), self.data_orig_idx = sort_all([data], [len(x[0]) for x in data])

    current = []
    currentlen = 0
    for x in data:
        # the currentlen guard avoids emitting an empty batch when a single
        # sentence is longer than batch_size (measured in tokens)
        if len(x[0]) + currentlen > self.batch_size and currentlen > 0:
            res.append(current)
            current = []
            currentlen = 0
        current.append(x)
        currentlen += len(x[0])

    if currentlen > 0:
        res.append(current)

    return res
def build_char_reps(self, batch_chars, batch_offsets, device, forward=True):
    if forward:
        model = self.charmodel_forward
        vocab = self.charmodel_forward_vocab
        projection = self.charmodel_forward_projection
    else:
        model = self.charmodel_backward
        vocab = self.charmodel_backward_vocab
        projection = self.charmodel_backward_projection

    # sort sentences by char-stream length for the char LM
    batch_charlens = [len(x) for x in batch_chars]
    chars_sorted, char_orig_idx = sort_all([batch_chars, batch_offsets], batch_charlens)
    batch_chars, batch_offsets = chars_sorted
    batch_charlens = [len(x) for x in batch_chars]

    chars = get_long_tensor(batch_chars, len(batch_chars), pad_id=vocab.unit2id(' ')).to(device=device)
    char_reps = model.get_representation(chars, batch_offsets, batch_charlens, char_orig_idx)
    char_reps = char_reps.data
    if projection is not None:
        char_reps = projection(char_reps)
    char_reps = torch.reshape(char_reps, [max(len(x) for x in batch_offsets), len(batch_chars), char_reps.shape[-1]])
    char_reps = torch.transpose(char_reps, 0, 1)
    return char_reps
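# A hypothetical call site for build_char_reps, assuming the inputs come from a
# process_chars-style step and that both character LMs produce the same hidden
# size. The names here are illustrative, not from the original:
fwd_reps = tagger.build_char_reps(chars_forward, charoffsets_forward, device, forward=True)
bwd_reps = tagger.build_char_reps(chars_backward, charoffsets_backward, device, forward=False)
# each comes back as (batch, max_words, dim); concatenate per word
word_inputs = torch.cat([fwd_reps, bwd_reps], dim=-1)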
def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0 or key >= len(self.data):
        raise IndexError
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 5

    # sort all fields by lens for easy RNN operations
    lens = [len(x) for x in batch[0]]
    batch, orig_idx = sort_all(batch, lens)

    # convert to tensors
    src = get_long_tensor(batch[0], batch_size)
    src_mask = torch.eq(src, constant.PAD_ID)
    tgt_in = get_long_tensor(batch[1], batch_size)
    tgt_out = get_long_tensor(batch[2], batch_size)
    pos = torch.LongTensor(batch[3])
    edits = torch.LongTensor(batch[4])
    assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match."
    return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx
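# The orig_idx values returned by these loaders are typically used to restore the
# caller's original sentence order after the model runs. A plausible unsort helper
# matching the indices produced by sort_all (a sketch, not necessarily the
# library's exact utility):
def unsort(sorted_list, orig_idx):
    """ Undo a sort_all ordering: put each element back at its original position. """
    assert len(sorted_list) == len(orig_idx)
    _, unsorted = zip(*sorted(zip(orig_idx, sorted_list)))
    return list(unsorted)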
def data_to_batches(data, batch_size, eval_mode, sort_during_eval, min_length_to_batch_separately):
    """ Given a list of lists, where the first element of each sublist is the sentence,
    group the sentences into batches.

    During training (not eval_mode), the sentences are sorted by length with a bit of
    random shuffling. During eval, the sentences are sorted by length if
    sort_during_eval is true.

    Refactored out of the data structure so that other models can use it, and for
    ease of testing.

    Returns (batches, original_order): original_order is None in train mode or when
    the data is left unsorted, and otherwise records the original position of each
    sentence before sorting.
    """
    res = []

    if not eval_mode:
        # sort sentences (roughly) by length for better memory utilization
        data = sorted(data, key=lambda x: len(x[0]), reverse=random.random() > .5)
        data_orig_idx = None
    elif sort_during_eval:
        (data, ), data_orig_idx = sort_all([data], [len(x[0]) for x in data])
    else:
        data_orig_idx = None

    current = []
    currentlen = 0
    for x in data:
        if min_length_to_batch_separately is not None and len(x[0]) > min_length_to_batch_separately:
            # overly long sentences get a batch of their own
            if currentlen > 0:
                res.append(current)
                current = []
                currentlen = 0
            res.append([x])
        else:
            if len(x[0]) + currentlen > batch_size and currentlen > 0:
                res.append(current)
                current = []
                currentlen = 0
            current.append(x)
            currentlen += len(x[0])

    if currentlen > 0:
        res.append(current)

    return res, data_orig_idx
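# A toy run of data_to_batches (batch_size counts tokens, not sentences):
sents = [([1, 2, 3],), ([4, 5],), ([6, 7, 8, 9],)]
batches, orig_idx = data_to_batches(sents, batch_size=5, eval_mode=True,
                                    sort_during_eval=True,
                                    min_length_to_batch_separately=None)
# After length-sorting, the 4-token sentence starts a batch; adding the 3-token
# sentence would exceed 5 tokens, so it and the 2-token sentence form batch two:
# batches  == [[([6, 7, 8, 9],)], [([1, 2, 3],), ([4, 5],)]]
# orig_idx == [2, 0, 1]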