Example #1
def collate_fn(batch, train=True):
    """ list of tensors to a batch tensors """
    premise_batch, _ = pad_batch([row['premise'] for row in batch])
    hypothesis_batch, _ = pad_batch([row['hypothesis'] for row in batch])
    label_batch = [row['label'] for row in batch]

    # PyTorch RNNs take (seq_len, batch) input by default (the layout cuDNN is optimized for),
    # so transpose the stacked batch
    transpose = (lambda b: torch.stack(b).t_().squeeze(0).contiguous())

    return (transpose(premise_batch), transpose(hypothesis_batch),
            transpose(label_batch))
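For context, a minimal sketch of how a collate_fn like this is typically plugged into torch.utils.data.DataLoader; the snli_dataset name and batch size are illustrative assumptions, not part of the source:

from torch.utils.data import DataLoader

# Hypothetical dataset of dicts with 'premise', 'hypothesis' and 'label' tensors;
# DataLoader hands each sampled list of rows to collate_fn (train keeps its default).
loader = DataLoader(snli_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

for premise, hypothesis, label in loader:
    # premise and hypothesis are (seq_len, batch_size) LongTensors after the transpose
    pass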
Example #2
def collate_fn(batch):
    text_batch, _ = pad_batch([row['text'] for row in batch])
    label_batch = [row['label'] for row in batch]

    to_tensor = (lambda b: torch.stack(b).squeeze(-1))

    return text_batch, to_tensor(label_batch)
Example #3
def test_pad_batch():
    batch = [
        torch.LongTensor([1, 2, 3]),
        torch.LongTensor([1, 2]),
        torch.LongTensor([1])
    ]
    padded, lengths = pad_batch(batch, PADDING_INDEX)
    padded = [r.tolist() for r in padded]
    assert padded == [[1, 2, 3], [1, 2, PADDING_INDEX],
                      [1, PADDING_INDEX, PADDING_INDEX]]
    assert lengths == [3, 2, 1]
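For reference, a minimal sketch of a pad_batch helper consistent with the behaviour this test exercises; the real implementation lives in the surrounding library, so the signature and the PADDING_INDEX default below are inferred assumptions, not quoted code:

import torch

PADDING_INDEX = 0  # assumed value; the library defines its own constant

def pad_batch(batch, padding_index=PADDING_INDEX):
    """ Pad a list of 1-D tensors to the length of the longest one.

    Returns the padded rows and the original lengths.
    """
    lengths = [len(row) for row in batch]
    max_len = max(lengths)
    padded = [
        torch.cat([row, row.new_full((max_len - len(row),), padding_index)])
        for row in batch
    ]
    return padded, lengths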
Example #4
    def collate_fn(batch):
        """ list of tensors to a batch variable """
        # Sort rows by decreasing length, as required to pack padded sequences for PyTorch RNNs
        batch = sorted(batch,
                       key=lambda row: len(row[input_key]),
                       reverse=True)
        input_batch, input_lengths = pad_batch(
            [row[input_key] for row in batch])
        output_batch, output_lengths = pad_batch(
            [row[output_key] for row in batch])

        def batch_to_variable(batch):
            # PyTorch RNNs take (seq_len, batch) input by default (the layout cuDNN is
            # optimized for), so transpose the stacked batch
            return Variable(torch.stack(batch).t_().contiguous(),
                            volatile=not train)

        # Async minibatch allocation for speed
        # Reference: http://timdettmers.com/2015/03/09/deep-learning-hardware-guide/
        # `async=True` is a syntax error on Python 3.7+; PyTorch renamed the argument to non_blocking
        cuda = lambda t: t.cuda(non_blocking=True) if torch.cuda.is_available() else t

        return (cuda(batch_to_variable(input_batch)),
                cuda(torch.LongTensor(input_lengths)),
                cuda(batch_to_variable(output_batch)),
                cuda(torch.LongTensor(output_lengths)))
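A hedged sketch of what the sorted batch and length tensors above are usually for: feeding pack_padded_sequence, which by default expects lengths in decreasing order. The embedding and lstm modules here are illustrative assumptions, not part of the source:

import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

embedding = nn.Embedding(num_embeddings=10000, embedding_dim=128)
lstm = nn.LSTM(input_size=128, hidden_size=256)  # batch_first=False matches the (seq_len, batch) layout

def encode(input_batch, input_lengths):
    embedded = embedding(input_batch)                       # (seq_len, batch, 128)
    packed = pack_padded_sequence(embedded, input_lengths)  # lengths as a list or CPU tensor
    output, (hidden, cell) = lstm(packed)
    output, _ = pad_packed_sequence(output)                 # back to (seq_len, batch, 256)
    return output, hidden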
Example #5
def collate_fn(batch):
    """ list of tensors to a batch tensors """
    text_batch, _ = pad_batch([row['text'] for row in batch])
    label_batch = [row['label'] for row in batch]
    return [text_batch, torch.cat(label_batch)]
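Examples #2 and #5 collate the label column in two different ways; a toy check (illustrative only) showing that, for one-element label tensors, both idioms produce the same 1-D batch:

import torch

labels = [torch.LongTensor([0]), torch.LongTensor([2]), torch.LongTensor([1])]

stacked = torch.stack(labels).squeeze(-1)  # (3, 1) -> (3,), as in example #2
catted = torch.cat(labels)                 # (3,), as in example #5
assert torch.equal(stacked, catted)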