def _new_trainer():
    """Build a CGO-enabled ``pl.Lightning`` evaluator over MNIST.

    The evaluator does the following:
    - wraps a ``_MultiModelSupervisedLearningModule`` that uses a
      cross-entropy criterion and an ``'acc'`` metric;
    - trains for one epoch on a quarter of the training batches, with the
      progress bar disabled;
    - validates on the MNIST test split.

    Returns:
        The configured ``pl.Lightning`` instance.
    """
    # Same normalization constants as the other MNIST fixtures in this file.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    train_set = serialize(MNIST, root='data/mnist', train=True, download=True, transform=mnist_transform)
    val_set = serialize(MNIST, root='data/mnist', train=False, download=True, transform=mnist_transform)

    module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits})
    trainer = cgo_trainer.Trainer(
        use_cgo=True,
        max_epochs=1,
        limit_train_batches=0.25,
        enable_progress_bar=False,
    )
    train_loader = pl.DataLoader(train_set, batch_size=100)
    val_loader = pl.DataLoader(val_set, batch_size=100)
    return pl.Lightning(module, trainer, train_dataloader=train_loader, val_dataloaders=val_loader)
def test_multi_model_trainer_gpu(self):
    """Jointly train two MNIST models with CGO on GPUs and check their final results.

    The test is skipped unless CUDA is available with at least two visible
    devices. It asserts that exactly ``n_models`` final results are reported
    and that each one exceeds 0.8.
    """
    _reset()
    # The guard needs two devices, so the skip message says so explicitly.
    if not (torch.cuda.is_available() and torch.cuda.device_count() >= 2):
        pytest.skip('test requires at least 2 GPUs and torch+cuda')

    n_models = 2

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
    test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)

    multi_module = _MultiModelSupervisedLearningModule(
        nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits}, n_models=n_models)
    lightning = pl.Lightning(
        multi_module,
        # enable_progress_bar=False matches _new_trainer and keeps test logs quiet.
        cgo_trainer.Trainer(use_cgo=True, max_epochs=1, limit_train_batches=0.25,
                            enable_progress_bar=False),
        train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
        val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
    )
    lightning._execute(_model_gpu)

    result = _get_final_result()
    assert len(result) == n_models
    for value in result:
        assert value > 0.8, f'final result {value!r} is not above 0.8'