Example no. 1
def test_test(tmpdir):
    """Tests that the model can be tested on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features,
                                      num_classes=DummyDataset.num_classes)
    test_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model, test_dl)
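
This test leans on a `DummyDataset` fixture that exposes `num_features` and `num_classes` as class attributes and yields `(features, target)` pairs. A minimal sketch of what such a fixture could look like, purely as an assumption about its shape (the attribute values and sample format of the real fixture may differ):

import torch


class DummyDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in: random features paired with integer class targets."""

    num_features = 16  # assumed value, not taken from the original suite
    num_classes = 3    # assumed value, not taken from the original suite

    def __len__(self):
        return 10

    def __getitem__(self, index):
        # One sample: a feature vector and a scalar class index
        return torch.rand(self.num_features), torch.randint(self.num_classes, (1,)).item()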
Example no. 2
def test_test(tmpdir):
    """Tests that the model can be tested on a pytorch geometric dataset."""
    tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
    model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
    model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
    test_dl = torch.utils.data.DataLoader(tudataset, batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model, test_dl)
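
Because the test fetches the KKI graphs from the TUDataset collection at run time, guarding against a missing optional dependency keeps the module skippable instead of failing at import. A minimal sketch using standard pytest machinery (whether the original suite uses this or a project-specific availability flag is an assumption):

import pytest

# Skip these tests entirely when the optional graph dependency is not installed.
torch_geometric = pytest.importorskip("torch_geometric")
datasets = pytest.importorskip("torch_geometric.datasets")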
Example no. 3
def test_transformations(tmpdir):

    transform = TestInputTransform()
    datamodule = DataModule(
        TestInput(RunningStage.TRAINING, [1]),
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        transform=transform,
        batch_size=2,
        num_workers=0,
    )

    assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
    batch = next(iter(datamodule.train_dataloader()))
    assert torch.equal(batch, torch.tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

    assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1}
    assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2}
    batch = next(iter(datamodule.val_dataloader()))

    datamodule = DataModule(
        TestInput(RunningStage.TRAINING, [1]),
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        transform=TestInputTransform2,
        batch_size=2,
        num_workers=0,
    )
    batch = next(iter(datamodule.val_dataloader()))
    assert torch.equal(batch["a"], torch.tensor([0, 1]))
    assert torch.equal(batch["b"], torch.tensor([1, 2]))

    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=2,
        limit_predict_batches=2,
        num_sanity_val_steps=1,
    )
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    assert datamodule.input_transform.train_per_sample_transform_called
    assert datamodule.input_transform.train_collate_called
    assert datamodule.input_transform.train_per_batch_transform_on_device_called
    assert datamodule.input_transform.val_per_sample_transform_called
    assert datamodule.input_transform.val_collate_called
    assert datamodule.input_transform.val_per_batch_transform_on_device_called
    assert datamodule.input_transform.test_per_sample_transform_called
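
The flag asserts at the end imply that `TestInputTransform` records which hooks ran. A minimal sketch of that recording pattern, assuming the stage-prefixed hook names (`train_per_sample_transform`, `train_collate`, ...) resolved by `InputTransform`; the real fixture also shapes the samples checked earlier in the test, which is omitted here:

from dataclasses import dataclass
from typing import Any, Callable, List

from torch.utils.data.dataloader import default_collate


@dataclass
class FlagRecordingTransform(InputTransform):  # hypothetical name; reuses the suite's InputTransform base
    train_per_sample_transform_called: bool = False
    train_collate_called: bool = False

    def train_per_sample_transform(self) -> Callable:
        def fn(sample: Any) -> Any:
            # Record that the training per-sample hook was invoked
            self.train_per_sample_transform_called = True
            return sample

        return fn

    def train_collate(self) -> Callable:
        def fn(samples: List[Any]) -> Any:
            # Record that the training collate hook was invoked
            self.train_collate_called = True
            return default_collate(samples)

        return fn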
Example no. 4
def test_init_train_enable_ort(tmpdir):
    class TestCallback(Callback):
        def on_train_start(self, trainer: Trainer,
                           pl_module: LightningModule) -> None:
            assert isinstance(pl_module.model, ORTModule)

    model = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
    trainer = Trainer(default_root_dir=tmpdir,
                      fast_dev_run=True,
                      callbacks=TestCallback())
    trainer.fit(
        model,
        train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
        val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
    )
    trainer.test(model,
                 test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))
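
The callback only checks that `enable_ort=True` leaves the backbone wrapped in ONNX Runtime's `ORTModule`. A minimal sketch of that wrapping with the public torch-ort API, using a stand-in linear layer rather than the real `TEST_BACKBONE` (how `TextClassifier` performs the wrapping internally is an assumption):

import torch
from torch_ort import ORTModule

backbone = torch.nn.Linear(8, 2)    # stand-in module, not the real text backbone
ort_backbone = ORTModule(backbone)  # still behaves like an nn.Module, but executes through ONNX Runtime
assert isinstance(ort_backbone, ORTModule)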
Example no. 5
def test_not_trainable(tmpdir):
    """Tests that the model gives an error when training, validating, or testing."""
    tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
    model = GraphEmbedder(
        GraphClassifier(num_features=1, num_classes=1).backbone)
    datamodule = DataModule(
        GraphClassificationDatasetInput(RunningStage.TRAINING, tudataset),
        GraphClassificationDatasetInput(RunningStage.VALIDATING, tudataset),
        GraphClassificationDatasetInput(RunningStage.TESTING, tudataset),
        transform=GraphClassificationInputTransform,
        batch_size=4,
    )
    trainer = Trainer(default_root_dir=tmpdir, num_sanity_val_steps=0)
    with pytest.raises(NotImplementedError,
                       match="Training a `GraphEmbedder` is not supported."):
        trainer.fit(model, datamodule=datamodule)

    with pytest.raises(NotImplementedError,
                       match="Validating a `GraphEmbedder` is not supported."):
        trainer.validate(model, datamodule=datamodule)

    with pytest.raises(NotImplementedError,
                       match="Testing a `GraphEmbedder` is not supported."):
        trainer.test(model, datamodule=datamodule)
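
The regexes above pin down the exact messages, which suggests `GraphEmbedder` overrides its step hooks to fail fast. A minimal sketch of that guard pattern (a hypothetical stand-in, not the actual implementation):

import pytorch_lightning as pl


class EmbedderOnly(pl.LightningModule):  # hypothetical stand-in for GraphEmbedder
    def training_step(self, batch, batch_idx):
        raise NotImplementedError("Training a `GraphEmbedder` is not supported.")

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError("Validating a `GraphEmbedder` is not supported.")

    def test_step(self, batch, batch_idx):
        raise NotImplementedError("Testing a `GraphEmbedder` is not supported.")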
Example no. 6
def test_data_module():
    seed_everything(42)

    def train_fn(data):
        return data - 100

    def val_fn(data):
        return data + 100

    def test_fn(data):
        return data - 1000

    def predict_fn(data):
        return data + 1000

    @dataclass
    class TestTransform(InputTransform):
        def per_sample_transform(self):
            def fn(x):
                return x

            return fn

        def train_per_batch_transform_on_device(self) -> Callable:
            return train_fn

        def val_per_batch_transform_on_device(self) -> Callable:
            return val_fn

        def test_per_batch_transform_on_device(self) -> Callable:
            return test_fn

        def predict_per_batch_transform_on_device(self) -> Callable:
            return predict_fn

    transform = TestTransform()
    assert transform._transform is not None

    train_dataset = Input(RunningStage.TRAINING, np.arange(10,
                                                           dtype=np.float32))
    assert train_dataset.running_stage == RunningStage.TRAINING

    val_dataset = Input(RunningStage.VALIDATING, np.arange(10,
                                                           dtype=np.float32))
    assert val_dataset.running_stage == RunningStage.VALIDATING

    test_dataset = Input(RunningStage.TESTING, np.arange(10, dtype=np.float32))
    assert test_dataset.running_stage == RunningStage.TESTING

    predict_dataset = Input(RunningStage.PREDICTING,
                            np.arange(10, dtype=np.float32))
    assert predict_dataset.running_stage == RunningStage.PREDICTING

    dm = DataModule(
        train_input=train_dataset,
        val_input=val_dataset,
        test_input=test_dataset,
        predict_input=predict_dataset,
        transform=transform,
        batch_size=2,
    )
    assert len(dm.train_dataloader()) == 5
    batch = next(iter(dm.train_dataloader()))
    assert batch.shape == torch.Size([2])
    assert batch.min() >= 0 and batch.max() < 10

    assert len(dm.val_dataloader()) == 5
    batch = next(iter(dm.val_dataloader()))
    assert batch.shape == torch.Size([2])
    assert batch.min() >= 0 and batch.max() < 10

    class TestModel(Task):
        def training_step(self, batch, batch_idx):
            assert sum(batch < 0) == 2

        def validation_step(self, batch, batch_idx):
            assert sum(batch > 0) == 2

        def test_step(self, batch, batch_idx):
            assert sum(batch < 500) == 2

        def predict_step(self, batch, *args, **kwargs):
            assert sum(batch > 500) == 2
            assert torch.equal(batch, torch.tensor([1000.0, 1001.0]))

        def on_train_dataloader(self) -> None:
            pass

        def on_val_dataloader(self) -> None:
            pass

        def on_test_dataloader(self, *_) -> None:
            pass

        def on_predict_dataloader(self) -> None:
            pass

        def on_predict_end(self) -> None:
            pass

        def on_fit_end(self) -> None:
            pass

    model = TestModel(torch.nn.Linear(1, 1))
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=dm)
    trainer.validate(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    trainer.predict(model, datamodule=dm)

    # Test that plain lightning module works with FlashDataModule
    class SampleBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 1)

    model = SampleBoringModel()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=dm)
    trainer.validate(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    trainer.predict(model, datamodule=dm)

    transform = TestTransform()
    input = Input(RunningStage.TRAINING)
    dm = DataModule(train_input=input, batch_size=1, transform=transform)
    assert isinstance(dm.input_transform, TestTransform)

    class RandomDataset(Dataset):
        def __init__(self, size: int, length: int):
            self.len = length
            self.data = torch.ones(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return self.len

    def _add_hundred(x):
        if isinstance(x, Dict):
            x["input"] += 100
        else:
            x += 100
        return x

    class TrainInputTransform(InputTransform):
        def _add_one(self, x):
            if isinstance(x, Dict):
                x["input"] += 1
            else:
                x += 1
            return x

        def per_sample_transform(self) -> Callable:
            return self._add_one

        def val_per_sample_transform(self) -> Callable:
            return _add_hundred

    dm = DataModule(
        train_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32)),
        val_input=DatasetInput(RunningStage.VALIDATING, RandomDataset(64, 32)),
        test_input=DatasetInput(RunningStage.TESTING, RandomDataset(64, 32)),
        batch_size=3,
        transform=TrainInputTransform(),
    )
    batch = next(iter(dm.train_dataloader()))
    assert batch["input"][0][0] == 2
    batch = next(iter(dm.val_dataloader()))
    assert batch["input"][0][0] == 101
    batch = next(iter(dm.test_dataloader()))
    assert batch["input"][0][0] == 2
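
Note how the final asserts follow from the stage-specific hook resolution: the datasets are all ones, the validation loader picks up `val_per_sample_transform` (`_add_hundred`, hence 101), while the train and test loaders fall back to the generic `per_sample_transform` (`_add_one`, hence 2) because `TrainInputTransform` defines no `test_per_sample_transform`.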