import torch


def collate_fn(batch, train=True):
    """ Convert a list of rows (dicts of tensors) into batched tensors. """
    premise_batch, _ = pad_batch([row['premise'] for row in batch])
    hypothesis_batch, _ = pad_batch([row['hypothesis'] for row in batch])
    label_batch = [row['label'] for row in batch]

    # PyTorch RNNs default to `batch_first=False`, so transpose each batch to
    # (seq_len, batch) for speed and integration with CUDA.
    transpose = (lambda b: torch.stack(b).t_().squeeze(0).contiguous())

    return (transpose(premise_batch), transpose(hypothesis_batch),
            transpose(label_batch))
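
# For context: PyTorch recurrent layers default to `batch_first=False`, i.e.
# they consume `(seq_len, batch)` input -- the shape the transpose above
# produces. A minimal sketch; the vocabulary, embedding, and hidden sizes
# below are illustrative, not taken from the original code.
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=100, embedding_dim=8)
lstm = nn.LSTM(input_size=8, hidden_size=16)  # batch_first defaults to False

tokens = torch.randint(0, 100, (5, 3))  # (seq_len=5, batch=3), as collated above
output, (h_n, c_n) = lstm(embedding(tokens))
print(output.shape)  # torch.Size([5, 3, 16])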
def collate_fn(batch):
    """ Convert a list of rows (dicts of tensors) into batched tensors. """
    text_batch, _ = pad_batch([row['text'] for row in batch])
    label_batch = [row['label'] for row in batch]
    # Stack the per-example labels and drop the trailing singleton dimension.
    to_tensor = (lambda b: torch.stack(b).squeeze(-1))
    return text_batch, to_tensor(label_batch)
def test_pad_batch():
    batch = [
        torch.LongTensor([1, 2, 3]),
        torch.LongTensor([1, 2]),
        torch.LongTensor([1]),
    ]
    padded, lengths = pad_batch(batch, PADDING_INDEX)
    padded = [r.tolist() for r in padded]
    assert padded == [[1, 2, 3],
                      [1, 2, PADDING_INDEX],
                      [1, PADDING_INDEX, PADDING_INDEX]]
    assert lengths == [3, 2, 1]
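
# `pad_batch` itself is not shown in this excerpt, but the test above pins
# down its contract: given a list of 1-D tensors (and a padding index), it
# right-pads each row to the longest length and also returns the original
# lengths. A minimal sketch consistent with that contract; the default
# padding value of 0 is an assumption.
import torch

PADDING_INDEX = 0  # assumed value; only its consistent reuse matters


def pad_batch(batch, padding_index=PADDING_INDEX):
    """ Right-pad a list of 1-D tensors to the length of the longest one.

    Returns the list of padded tensors and the list of original lengths. """
    lengths = [len(row) for row in batch]
    max_len = max(lengths)
    padded = [
        torch.cat([row, row.new_full((max_len - len(row),), padding_index)])
        for row in batch
    ]
    return padded, lengths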
def collate_fn(batch): """ list of tensors to a batch variable """ # PyTorch RNN requires sorting decreasing size batch = sorted(batch, key=lambda row: len(row[input_key]), reverse=True) input_batch, input_lengths = pad_batch( [row[input_key] for row in batch]) output_batch, output_lengths = pad_batch( [row[output_key] for row in batch]) def batch_to_variable(batch): # PyTorch RNN requires batches to be transposed for speed and integration with CUDA return Variable(torch.stack(batch).t_().contiguous(), volatile=not train) # Async minibatch allocation for speed # Reference: http://timdettmers.com/2015/03/09/deep-learning-hardware-guide/ cuda = lambda t: t.cuda(async=True) if torch.cuda.is_available() else t return (cuda(batch_to_variable(input_batch)), cuda(torch.LongTensor(input_lengths)), cuda(batch_to_variable(output_batch)), cuda(torch.LongTensor(output_lengths)))
def collate_fn(batch): """ list of tensors to a batch tensors """ text_batch, _ = pad_batch([row['text'] for row in batch]) label_batch = [row['label'] for row in batch] return [text_batch, torch.cat(label_batch)]
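
# Usage sketch: collate functions like the one above plug into
# `torch.utils.data.DataLoader` through its `collate_fn` argument. The
# two-example dataset below is a made-up stand-in; `pin_memory=True` is what
# lets the `non_blocking=True` device transfers in the collate function
# further up actually overlap with computation.
import torch
from torch.utils.data import DataLoader

dataset = [{'text': torch.LongTensor([1, 2, 3]), 'label': torch.LongTensor([0])},
           {'text': torch.LongTensor([4, 5]), 'label': torch.LongTensor([1])}]

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn,
                    pin_memory=torch.cuda.is_available())
for text_batch, label_batch in loader:
    print(label_batch)  # tensor([0, 1])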