def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key < 0:
        raise IndexError
    key = key % len(self.data)
    batch = self.data[key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 4

    # sort sentences by lens for easy RNN operations
    sentlens = [len(x) for x in batch[2]]
    batch, orig_idx = sort_all(batch, sentlens)
    sentlens = [len(x) for x in batch[2]]  # recompute in sorted order

    words = get_long_tensor(batch[2], batch_size)
    words_mask = torch.eq(words, PAD_ID)

    # convert to tensors
    tokens_phobert = get_long_tensor(batch[0], batch_size, pad_id=1)
    first_subword = get_long_tensor(batch[1], batch_size)
    tags = get_long_tensor(batch[3], batch_size)
    return tokens_phobert, first_subword, words_mask, tags, orig_idx, sentlens
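# sort_all and get_long_tensor are used above but not defined in this section.
# Below is a minimal sketch consistent with how they are called here, modeled on
# Stanza's common data utilities; the actual implementations in the codebase may
# differ, and PAD_ID = 0 is an assumption (PhoBERT inputs pad with 1 instead).
import torch

PAD_ID = 0  # assumed padding id for word/tag tensors


def sort_all(batch, lens):
    """ Sort all fields of `batch` in parallel by descending `lens`;
    also return the original indices so order can be restored later. """
    unsorted_all = [lens] + [range(len(lens))] + list(batch)
    sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))]
    return sorted_all[2:], sorted_all[1]


def get_long_tensor(tokens_list, batch_size, pad_id=PAD_ID):
    """ Pad a list of int sequences to the max length and pack them
    into a (batch_size x max_len) LongTensor. """
    token_len = max(len(x) for x in tokens_list)
    tokens = torch.LongTensor(batch_size, token_len).fill_(pad_id)
    for i, s in enumerate(tokens_list):
        tokens[i, : len(s)] = torch.LongTensor(s)
    return tokens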
def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key > 0 and key % len(self.data_dep) == 0:
        self.reshuffle()
    dep_key = key % len(self.data_dep)
    dep_batch = self.data_dep[dep_key]
    dep_batch_size = len(dep_batch)
    dep_batch = list(zip(*dep_batch))
    assert len(dep_batch) == 6

    # sort sentences by lens for easy RNN operations
    dep_lens = [len(x) for x in dep_batch[2]]
    dep_batch, dep_orig_idx = sort_all(dep_batch, dep_lens)

    # flatten words across the batch and sort them by lens for the char-level encoder
    dep_batch_words = [w for sent in dep_batch[3] for w in sent]
    dep_word_lens = [len(x) for x in dep_batch_words]
    dep_batch_words, dep_word_orig_idx = sort_all([dep_batch_words], dep_word_lens)
    dep_batch_words = dep_batch_words[0]  # [word1, ...], word1 = list of tokens
    dep_word_lens = [len(x) for x in dep_batch_words]
    dep_wordchars = get_long_tensor(dep_batch_words, len(dep_word_lens))
    dep_number_of_words = dep_wordchars.size(0)

    dep_words = get_long_tensor(dep_batch[2], dep_batch_size)
    dep_words_mask = torch.eq(dep_words, PAD_ID)

    # convert to tensors
    dep_tokens_phobert = get_long_tensor(dep_batch[0], dep_batch_size, pad_id=1)
    dep_first_subword = get_long_tensor(dep_batch[1], dep_batch_size)
    dep_sentlens = [len(x) for x in dep_batch[1]]
    dep_head = get_long_tensor(dep_batch[4], dep_batch_size)
    dep_deprel = get_long_tensor(dep_batch[5], dep_batch_size)
    return (dep_tokens_phobert, dep_first_subword, dep_words_mask, dep_head,
            dep_deprel, dep_number_of_words, dep_orig_idx, dep_sentlens)
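# The dependency loader sorts twice: sentences by word count for the
# sentence-level RNN, then the flattened word list by character count for the
# character-level encoder. A self-contained illustration of the
# flatten-then-sort step (plain strings stand in for the lists of character
# ids used above):
sents = [["Hanoi", "is", "big"], ["hello"]]
flat_words = [w for sent in sents for w in sent]                  # flatten
order = sorted(range(len(flat_words)), key=lambda i: -len(flat_words[i]))
sorted_words = [flat_words[i] for i in order]                     # longest first
print(sorted_words)  # ['Hanoi', 'hello', 'big', 'is']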
def chunk_batches(self, data):
    res = []
    if not self.eval:
        # sort sentences (roughly) by length for better memory utilization
        data = sorted(data, key=lambda x: len(x[2]), reverse=random.random() > .5)
    elif self.sort_during_eval:
        (data,), self.data_orig_idx_dep = sort_all([data], [len(x[2]) for x in data])

    current = []
    for x in data:
        if len(current) >= self.batch_size:
            res.append(current)
            current = []
        current.append(x)

    if len(current) > 0:
        res.append(current)

    return res  # list of batches; each batch is a list of sentences, each a list of tokens
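# The chunking loop above, run in isolation on toy data with an assumed
# batch_size of 2. Note that batch_size here counts sentences per batch,
# not tokens:
sentences = [["a"], ["b", "c"], ["d"], ["e", "f", "g"], ["h"]]
batch_size = 2
res, current = [], []
for x in sentences:
    if len(current) >= batch_size:
        res.append(current)
        current = []
    current.append(x)
if current:
    res.append(current)
print(res)  # [[['a'], ['b', 'c']], [['d'], ['e', 'f', 'g']], [['h']]]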
def __getitem__(self, key):
    """ Get a batch with index. """
    if not isinstance(key, int):
        raise TypeError
    if key > 0 and key % len(self.data) == 0:
        self.reshuffle()
    batch_key = key % len(self.data)
    batch = self.data[batch_key]
    batch_size = len(batch)
    batch = list(zip(*batch))
    assert len(batch) == 3

    # sort sentences by lens for easy RNN operations
    lens = [len(x) for x in batch[2]]
    batch, orig_idx = sort_all(batch, lens)

    # convert to tensors
    tokens_phobert = get_long_tensor(batch[0], batch_size, pad_id=1)
    first_subword = get_long_tensor(batch[1], batch_size)
    upos = get_long_tensor(batch[2], batch_size)
    sentlens = [len(x) for x in batch[1]]
    return tokens_phobert, first_subword, upos, orig_idx, sentlens
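# Each __getitem__ returns orig_idx so that predictions made on the
# length-sorted batch can be mapped back to input order. A minimal unsort
# helper, again modeled on Stanza's utilities (the real codebase may ship
# its own version):
def unsort(sorted_list, oidx):
    """ Restore the original order of `sorted_list` using the indices
    returned by sort_all. """
    assert len(sorted_list) == len(oidx)
    _, unsorted = [list(t) for t in zip(*sorted(zip(oidx, sorted_list)))]
    return unsorted

# e.g. after tagging a sorted batch: pred_tags = unsort(pred_tags, orig_idx)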