def test_dm_transfer_batch_to_device(tmpdir):
    """Check that a datamodule's ``transfer_batch_to_device`` override is honoured.

    The override is copied onto the model when ``is_overridden`` reports it,
    then a custom (non-collection) batch is sent through the GPU accelerator
    to ``cuda:0``. The test asserts that the hook ran and that both tensors
    in the batch ended up on CUDA device 0.
    """

    class CustomBatch:
        """Plain container unpacking ``samples`` and ``targets`` from a 2-item sequence."""

        def __init__(self, data):
            self.samples, self.targets = data[0], data[1]

    class CurrentTestDM(LightningDataModule):
        # Class-level default; the hook sets it to True on the instance.
        hook_called = False

        def transfer_batch_to_device(self, data, device):
            self.hook_called = True
            if not isinstance(data, CustomBatch):
                # Any other batch type falls back to the inherited implementation.
                return super().transfer_batch_to_device(data, device)
            data.samples = data.samples.to(device)
            data.targets = data.targets.to(device)
            return data

    model = EvalModelTemplate()
    dm = CurrentTestDM()
    inputs = torch.zeros(5, 28)
    labels = torch.ones(5, 1, dtype=torch.long)
    batch = CustomBatch((inputs, labels))

    trainer = Trainer(gpus=1)
    # Calling .fit() would need custom data loaders, so the model lookup is mocked instead.
    trainer.get_model = MagicMock(return_value=model)
    if is_overridden('transfer_batch_to_device', dm):
        model.transfer_batch_to_device = dm.transfer_batch_to_device

    trainer.accelerator_backend = GPUAccelerator(trainer)
    moved = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))

    target_device = torch.device('cuda', 0)
    assert dm.hook_called
    assert moved.samples.device == moved.targets.device == target_device
# Example #2 (0)
def test_non_blocking():
    """Check that ``non_blocking=True`` is forwarded to ``torch.Tensor.to`` only.

    A tensor's ``.to`` must be called with ``non_blocking=True``. An arbitrary
    object that merely defines its own ``.to`` must be called with the device
    alone.
    """

    class BatchObject(object):
        # Minimal non-tensor exposing a ``.to`` method (which returns None).
        def to(self, *args, **kwargs):
            pass

    trainer = Trainer()
    trainer.accelerator_backend = GPUAccelerator(trainer)

    def spy_on_to(obj):
        """Replace ``obj.to`` with a mock that still calls through to the real method."""
        return patch.object(obj, 'to', wraps=obj.to)

    tensor_batch = torch.zeros(2, 3)
    with spy_on_to(tensor_batch) as tensor_to:
        trainer.accelerator_backend.batch_to_device(tensor_batch, torch.device('cuda:0'))
        tensor_to.assert_called_with(torch.device('cuda', 0), non_blocking=True)

    object_batch = BatchObject()
    with spy_on_to(object_batch) as object_to:
        trainer.accelerator_backend.batch_to_device(object_batch, torch.device('cuda:0'))
        object_to.assert_called_with(torch.device('cuda', 0))
# Example #3 (0)
def test_single_gpu_batch_parse():
    """Check that ``batch_to_device`` moves every tensor in a batch to ``cuda:0``.

    Cases covered:

    - primitive, non-transferable values (returned unchanged);
    - bare tensors;
    - nested lists, tuples and dicts of tensors;
    - namedtuples;
    - objects that define their own ``.to()``;
    - ``torchtext.data.Batch``.

    Every tensor in each collection is checked, not just the first one, so a
    partial transfer is caught.
    """
    trainer = Trainer(gpus=1)
    trainer.accelerator_backend = GPUAccelerator(trainer)
    cuda0 = torch.device('cuda:0')

    def to_gpu(batch):
        """Send ``batch`` through the accelerator under test to ``cuda:0``."""
        return trainer.accelerator_backend.batch_to_device(batch, cuda0)

    def assert_on_gpu(tensor, type_name='torch.cuda.FloatTensor'):
        """Assert ``tensor`` is on CUDA device 0 and reports the expected type string."""
        assert tensor.device.index == 0
        assert tensor.type() == type_name

    # non-transferrable types must come back unchanged
    primitive_objects = [
        None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None},
    ]
    for batch in primitive_objects:
        assert to_gpu(batch) == batch

    # batch is just a tensor
    assert_on_gpu(to_gpu(torch.rand(2, 3)))

    # tensor list
    batch = to_gpu([torch.rand(2, 3), torch.rand(2, 3)])
    assert len(batch) == 2
    for tensor in batch:
        assert_on_gpu(tensor)

    # tensor list of lists
    batch = to_gpu([[torch.rand(2, 3), torch.rand(2, 3)]])
    assert len(batch) == 1 and len(batch[0]) == 2
    for tensor in batch[0]:
        assert_on_gpu(tensor)

    # tensor dict
    batch = to_gpu([{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)}])
    assert len(batch) == 1
    assert_on_gpu(batch[0]['a'])
    assert_on_gpu(batch[0]['b'])

    # tuple of tensor list and list of tensor dict
    batch = to_gpu((
        [torch.rand(2, 3) for _ in range(2)],
        [{'a': torch.rand(2, 3), 'b': torch.rand(2, 3)} for _ in range(2)],
    ))
    tensor_list, dict_list = batch
    assert len(tensor_list) == 2 and len(dict_list) == 2
    for tensor in tensor_list:
        assert_on_gpu(tensor)
    for item in dict_list:
        assert_on_gpu(item['a'])
        assert_on_gpu(item['b'])

    # namedtuple of tensor
    BatchType = namedtuple('BatchType', ['a', 'b'])
    batch = to_gpu([BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)])
    assert len(batch) == 2
    for item in batch:
        assert_on_gpu(item.a)
        assert_on_gpu(item.b)

    # non-Tensor that has `.to()` defined
    class CustomBatchType:
        def __init__(self):
            self.a = torch.rand(2, 2)

        def to(self, *args, **kwargs):
            self.a = self.a.to(*args, **kwargs)
            return self

    batch = to_gpu(CustomBatchType())
    assert_on_gpu(batch.a)

    # torchtext.data.Batch
    samples = [
        {'text': 'PyTorch Lightning is awesome!', 'label': 0},
        {'text': 'Please make it work with torchtext', 'label': 1},
    ]

    text_field = Field()
    label_field = LabelField()
    fields = {'text': ('text', text_field), 'label': ('label', label_field)}

    examples = [Example.fromdict(sample, fields) for sample in samples]
    dataset = Dataset(examples=examples, fields=fields.values())

    # Batch runs field.process(), which numericalizes tokens; that needs a built vocabulary first.
    text_field.build_vocab(dataset)
    label_field.build_vocab(dataset)

    batch = to_gpu(Batch(data=examples, dataset=dataset))
    assert_on_gpu(batch.text, 'torch.cuda.LongTensor')
    assert_on_gpu(batch.label, 'torch.cuda.LongTensor')