def test_test(tmpdir):
    """Tests that the model can be tested on our ``DummyDataset``."""
    # Build a classifier sized to match the dummy data.
    classifier = TemplateSKLearnClassifier(
        num_features=DummyDataset.num_features,
        num_classes=DummyDataset.num_classes,
    )
    loader = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    # fast_dev_run keeps the test to a single batch.
    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    runner.test(classifier, loader)
def test_test(tmpdir):
    """Tests that the model can be tested on a pytorch geometric dataset."""
    # Small graph-classification dataset downloaded into the tmp dir.
    graphs = datasets.TUDataset(root=tmpdir, name="KKI")
    classifier = GraphClassifier(
        num_features=graphs.num_features,
        num_classes=graphs.num_classes,
    )
    # Attach a pipeline so the model knows how to preprocess graphs.
    classifier.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
    loader = torch.utils.data.DataLoader(graphs, batch_size=4)
    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    runner.test(classifier, loader)
def test_transformations(tmpdir):
    """End-to-end check that InputTransform hooks fire and produce the expected batches.

    Builds a ``DataModule`` from dummy inputs, verifies per-stage dataset contents and
    collated batches, then runs fit/test and asserts each stage's transform hooks were
    actually called.
    """
    transform = TestInputTransform()
    datamodule = DataModule(
        TestInput(RunningStage.TRAINING, [1]),
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        transform=transform,
        batch_size=2,
        num_workers=0,
    )

    # Raw train sample before collation, then the collated (and transformed) batch.
    assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
    batch = next(iter(datamodule.train_dataloader()))
    assert torch.equal(batch, torch.tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

    # Validation samples are dicts; just pull a batch to exercise the val path.
    assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1}
    assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2}
    batch = next(iter(datamodule.val_dataloader()))

    # A transform CLASS (not instance) must also be accepted.
    datamodule = DataModule(
        TestInput(RunningStage.TRAINING, [1]),
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        transform=TestInputTransform2,
        batch_size=2,
        num_workers=0,
    )
    batch = next(iter(datamodule.val_dataloader()))
    assert torch.equal(batch["a"], torch.tensor([0, 1]))
    assert torch.equal(batch["b"], torch.tensor([1, 2]))

    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=2,
        limit_predict_batches=2,
        num_sanity_val_steps=1,
    )
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    # Each stage's hooks must have fired during fit/test.
    assert datamodule.input_transform.train_per_sample_transform_called
    assert datamodule.input_transform.train_collate_called
    assert datamodule.input_transform.train_per_batch_transform_on_device_called
    # Fixed: this previously re-asserted `train_per_sample_transform_called`,
    # leaving the validation per-sample hook unchecked.
    assert datamodule.input_transform.val_per_sample_transform_called
    assert datamodule.input_transform.val_collate_called
    assert datamodule.input_transform.val_per_batch_transform_on_device_called
    assert datamodule.input_transform.test_per_sample_transform_called
def test_init_train_enable_ort(tmpdir):
    """Verify that ``enable_ort=True`` wraps the backbone in an ``ORTModule``."""

    class TestCallback(Callback):
        # Inspect the model once training actually starts, after setup hooks ran.
        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
            assert isinstance(pl_module.model, ORTModule)

    classifier = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback())
    runner.fit(
        classifier,
        train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
        val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
    )
    runner.test(classifier, test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))
def test_not_trainable(tmpdir):
    """Tests that the model gives an error when training, validating, or testing."""
    graphs = datasets.TUDataset(root=tmpdir, name="KKI")
    # The embedder reuses a classifier backbone but supports no training loop.
    embedder = GraphEmbedder(GraphClassifier(num_features=1, num_classes=1).backbone)
    datamodule = DataModule(
        GraphClassificationDatasetInput(RunningStage.TRAINING, graphs),
        GraphClassificationDatasetInput(RunningStage.VALIDATING, graphs),
        GraphClassificationDatasetInput(RunningStage.TESTING, graphs),
        transform=GraphClassificationInputTransform,
        batch_size=4,
    )
    runner = Trainer(default_root_dir=tmpdir, num_sanity_val_steps=0)

    # Every stage must raise with its stage-specific message.
    with pytest.raises(NotImplementedError, match="Training a `GraphEmbedder` is not supported."):
        runner.fit(embedder, datamodule=datamodule)
    with pytest.raises(NotImplementedError, match="Validating a `GraphEmbedder` is not supported."):
        runner.validate(embedder, datamodule=datamodule)
    with pytest.raises(NotImplementedError, match="Testing a `GraphEmbedder` is not supported."):
        runner.test(embedder, datamodule=datamodule)
def test_data_module():
    """Exercise ``DataModule`` end to end: per-stage device transforms, dataloader
    wiring with both Flash ``Task`` and plain LightningModule, and per-sample
    transforms over dict-style dataset batches."""
    seed_everything(42)

    # Stage-specific device transforms shift the data into disjoint value
    # ranges so each step can assert which transform ran.
    def train_fn(data):
        return data - 100

    def val_fn(data):
        return data + 100

    def test_fn(data):
        return data - 1000

    def predict_fn(data):
        return data + 1000

    @dataclass
    class TestTransform(InputTransform):
        def per_sample_transform(self):
            def fn(x):
                return x

            return fn

        def train_per_batch_transform_on_device(self) -> Callable:
            return train_fn

        def val_per_batch_transform_on_device(self) -> Callable:
            return val_fn

        def test_per_batch_transform_on_device(self) -> Callable:
            return test_fn

        def predict_per_batch_transform_on_device(self) -> Callable:
            return predict_fn

    transform = TestTransform()
    assert transform._transform is not None

    # One Input per stage, each carrying values 0..9.
    train_dataset = Input(RunningStage.TRAINING, np.arange(10, dtype=np.float32))
    assert train_dataset.running_stage == RunningStage.TRAINING

    val_dataset = Input(RunningStage.VALIDATING, np.arange(10, dtype=np.float32))
    assert val_dataset.running_stage == RunningStage.VALIDATING

    test_dataset = Input(RunningStage.TESTING, np.arange(10, dtype=np.float32))
    assert test_dataset.running_stage == RunningStage.TESTING

    predict_dataset = Input(RunningStage.PREDICTING, np.arange(10, dtype=np.float32))
    assert predict_dataset.running_stage == RunningStage.PREDICTING

    datamodule = DataModule(
        train_input=train_dataset,
        val_input=val_dataset,
        test_input=test_dataset,
        predict_input=predict_dataset,
        transform=transform,
        batch_size=2,
    )

    # On the host side batches are still untransformed (device transforms
    # only run inside the trainer loop).
    assert len(datamodule.train_dataloader()) == 5
    batch = next(iter(datamodule.train_dataloader()))
    assert batch.shape == torch.Size([2])
    assert batch.min() >= 0 and batch.max() < 10

    assert len(datamodule.val_dataloader()) == 5
    batch = next(iter(datamodule.val_dataloader()))
    assert batch.shape == torch.Size([2])
    assert batch.min() >= 0 and batch.max() < 10

    class TestModel(Task):
        # Each step asserts the stage-specific device transform was applied.
        def training_step(self, batch, batch_idx):
            assert sum(batch < 0) == 2

        def validation_step(self, batch, batch_idx):
            assert sum(batch > 0) == 2

        def test_step(self, batch, batch_idx):
            assert sum(batch < 500) == 2

        def predict_step(self, batch, *args, **kwargs):
            assert sum(batch > 500) == 2
            assert torch.equal(batch, torch.tensor([1000.0, 1001.0]))

        def on_train_dataloader(self) -> None:
            pass

        def on_val_dataloader(self) -> None:
            pass

        def on_test_dataloader(self, *_) -> None:
            pass

        def on_predict_dataloader(self) -> None:
            pass

        def on_predict_end(self) -> None:
            pass

        def on_fit_end(self) -> None:
            pass

    flash_model = TestModel(torch.nn.Linear(1, 1))
    runner = Trainer(fast_dev_run=True)
    runner.fit(flash_model, datamodule=datamodule)
    runner.validate(flash_model, datamodule=datamodule)
    runner.test(flash_model, datamodule=datamodule)
    runner.predict(flash_model, datamodule=datamodule)

    # Test that plain lightning module works with FlashDataModule
    class SampleBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 1)

    boring_model = SampleBoringModel()
    runner = Trainer(fast_dev_run=True)
    runner.fit(boring_model, datamodule=datamodule)
    runner.validate(boring_model, datamodule=datamodule)
    runner.test(boring_model, datamodule=datamodule)
    runner.predict(boring_model, datamodule=datamodule)

    # The transform instance passed in must be exposed unchanged.
    transform = TestTransform()
    input = Input(RunningStage.TRAINING)
    datamodule = DataModule(train_input=input, batch_size=1, transform=transform)
    assert isinstance(datamodule.input_transform, TestTransform)

    class RandomDataset(Dataset):
        def __init__(self, size: int, length: int):
            self.len = length
            self.data = torch.ones(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return self.len

    def _add_hundred(x):
        if isinstance(x, Dict):
            x["input"] += 100
        else:
            x += 100
        return x

    class TrainInputTransform(InputTransform):
        def _add_one(self, x):
            if isinstance(x, Dict):
                x["input"] += 1
            else:
                x += 1
            return x

        # Default per-sample transform adds one; validation overrides it.
        def per_sample_transform(self) -> Callable:
            return self._add_one

        def val_per_sample_transform(self) -> Callable:
            return _add_hundred

    datamodule = DataModule(
        train_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32)),
        val_input=DatasetInput(RunningStage.VALIDATING, RandomDataset(64, 32)),
        test_input=DatasetInput(RunningStage.TESTING, RandomDataset(64, 32)),
        batch_size=3,
        transform=TrainInputTransform(),
    )
    # Train/test use the default (+1) transform; validation uses (+100).
    batch = next(iter(datamodule.train_dataloader()))
    assert batch["input"][0][0] == 2
    batch = next(iter(datamodule.val_dataloader()))
    assert batch["input"][0][0] == 101
    batch = next(iter(datamodule.test_dataloader()))
    assert batch["input"][0][0] == 2