def test_adjoint_autograd():
    """Compare ODE Adjoint vs Autograd gradients, s := [0, 1], adaptive-step.

    Runs one full-batch forward/backward pass on the noisy 'moons' toy dataset
    with adjoint sensitivity, then repeats it with autograd (backprop through
    the solver), and asserts that every parameter-gradient entry agrees to
    within 1e-3.
    """
    d = ToyDataset()
    X, yn = d.generate(n_samples=512, dataset_type='moons', noise=.4)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    X_train = torch.Tensor(X).to(device)
    y_train = torch.LongTensor(yn.long()).to(device)
    train = data.TensorDataset(X_train, y_train)
    # batch_size == number of samples -> the loader yields a single batch.
    trainloader = data.DataLoader(train, batch_size=len(X), shuffle=False)

    f = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2))
    model = NeuralODE(f, solver='dopri5', sensitivity='adjoint', atol=1e-5, rtol=1e-8).to(device)
    # Both gradient computations below use exactly this same batch.
    x, y = next(iter(trainloader))

    # adjoint gradients
    y_hat = model(x)
    loss = nn.CrossEntropyLoss()(y_hat, y)
    loss.backward()
    # torch.cat copies, so adj_grad is unaffected by the zero_grad() below.
    adj_grad = torch.cat([p.grad.flatten() for p in model.parameters()])

    # autograd gradients: clear the adjoint grads first so they do not accumulate.
    model.zero_grad()
    # NOTE(review): assumes NeuralODE reads `sensitivity` at call time; if it is
    # only consulted in __init__, this would compare adjoint with itself — confirm.
    model.sensitivity = 'autograd'
    y_hat = model(x)
    loss = nn.CrossEntropyLoss()(y_hat, y)
    loss.backward()
    bp_grad = torch.cat([p.grad.flatten() for p in model.parameters()])

    # The check is per element, so report the worst element alongside the total.
    err = torch.abs(bp_grad - adj_grad)
    assert (err <= 1e-3).all(), f'Gradient error: max {err.max()}, total {err.sum()}'
def moons_dataloader():
    """Build the noisy 'moons' toy dataset and wrap it in a single-batch loader.

    Returns:
        A pair ``(X_train, loader)``: the feature tensor (on ``cuda:0`` when
        CUDA is available, otherwise on CPU) and an unshuffled DataLoader whose
        only batch contains every sample.
    """
    features, labels = ToyDataset().generate(n_samples=512, dataset_type='moons', noise=.4)
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    inputs = torch.Tensor(features).to(target)
    targets = torch.LongTensor(labels.long()).to(target)
    loader = data.DataLoader(
        data.TensorDataset(inputs, targets),
        batch_size=len(features),  # one batch holding the whole dataset
        shuffle=False,
    )
    return inputs, loader
def moons_trainloader():
    """Return an unshuffled, single-batch DataLoader over the noisy 'moons' toy dataset.

    Unlike ``moons_dataloader`` the tensors are not moved to any device.
    """
    features, labels = ToyDataset().generate(n_samples=512, dataset_type='moons', noise=.4)
    dataset = TensorDataset(torch.Tensor(features), torch.LongTensor(labels.long()))
    # batch_size == number of samples -> exactly one batch per epoch.
    return DataLoader(dataset, batch_size=len(features), shuffle=False)
def test_stable_neural_de(testlearner):
    """Stable: basic functionality.

    Smoke test: a NeuralODE whose ``f`` is a ``Stable``-wrapped scalar-output
    MLP (2 -> 64 -> 1) is fitted for 10-30 epochs on the noisy 'moons' dataset.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    X, yn = ToyDataset().generate(n_samples=512, dataset_type='moons', noise=.4)
    dataset = data.TensorDataset(
        torch.Tensor(X).to(device),
        torch.LongTensor(yn.long()).to(device),
    )
    trainloader = data.DataLoader(dataset, batch_size=len(X), shuffle=False)

    scalar_net = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 1))
    model = NeuralODE(Stable(scalar_net)).to(device)

    # NOTE(review): this passes (model, trainloader=...) to `testlearner`, while the
    # GNF variant of this test in the same file passes (t_span, model, ...) to the same
    # fixture — confirm which signature the current fixture expects.
    learner = testlearner(model, trainloader=trainloader)
    pl.Trainer(min_epochs=10, max_epochs=30).fit(learner)
def test_gnf_neural_de(testlearner):
    """GNF: basic functionality.

    Smoke test: a NeuralODE whose ``f`` is a ``GNF``-wrapped scalar-output MLP
    (2 -> 64 -> 1) is fitted for exactly 5 epochs on the noisy 'moons' dataset.

    Previously this was also named ``test_stable_neural_de``, which redefined
    (and so silently hid) the ``Stable`` test of that name in this module.
    """
    d = ToyDataset()
    X, yn = d.generate(n_samples=512, dataset_type='moons', noise=.4)
    X_train = torch.Tensor(X)
    y_train = torch.LongTensor(yn.long())
    train = data.TensorDataset(X_train, y_train)
    # batch_size == number of samples -> a single batch per epoch.
    trainloader = data.DataLoader(train, batch_size=len(X), shuffle=False)
    f = GNF(nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 1)))
    model = NeuralODE(f)
    # 30 evenly spaced points on [0, 1], handed to the learner.
    t_span = torch.linspace(0, 1, 30)
    learn = testlearner(t_span, model, trainloader=trainloader)
    trainer = pl.Trainer(min_epochs=5, max_epochs=5)
    trainer.fit(learn)
def test_2nd_order():
    """2nd order (MLP) Galerkin Neural ODE.

    One-epoch smoke test on the noisy 'spirals' dataset using a second-order
    NeuralDE (``order=2``) solved with dopri5 after input augmentation.
    """
    X, yn = ToyDataset().generate(n_samples=512, dataset_type='spirals', noise=.4)
    # NOTE(review): `device` is not defined in this function; presumably a
    # module-level global — confirm.
    train_set = data.TensorDataset(
        torch.Tensor(X).to(device),
        torch.LongTensor(yn.long()).to(device),
    )
    trainloader = data.DataLoader(train_set, batch_size=len(X), shuffle=False)

    # Widths suggest DepthCat(1) appends one channel: 4 + 1 = 5 in, 64 + 1 = 65 hidden.
    vector_field = nn.Sequential(
        DepthCat(1), nn.Linear(5, 64), nn.Tanh(), DepthCat(1), nn.Linear(65, 2)
    )
    augmenter = Augmenter(augment_idx=1, augment_func=nn.Linear(2, 2))
    model = nn.Sequential(augmenter, NeuralDE(vector_field, solver='dopri5', order=2)).to(device)

    learner = TestLearner(model, trainloader=trainloader)
    pl.Trainer(min_epochs=1, max_epochs=1).fit(learner)
def test_vanilla_galerkin():
    """Vanilla Galerkin (MLP) Neural ODE.

    One-epoch smoke test on the noisy 'spirals' dataset. The vector field is
    built from two GalLinear layers (Fourier(5) and Polynomial(2) bases); the
    6-wide state presumably comes from the 2 input dims plus the 4 produced by
    the Augmenter's Linear(2, 4).
    """
    X, yn = ToyDataset().generate(n_samples=512, dataset_type='spirals', noise=.4)
    # NOTE(review): `device` is not defined in this function; presumably a
    # module-level global — confirm.
    inputs = torch.Tensor(X).to(device)
    targets = torch.LongTensor(yn.long()).to(device)
    trainloader = data.DataLoader(
        data.TensorDataset(inputs, targets), batch_size=len(X), shuffle=False
    )

    layers = [
        DepthCat(1),
        GalLinear(6, 64, basisfunc=Fourier(5)),
        nn.Tanh(),
        DepthCat(1),
        GalLinear(64, 6, basisfunc=Polynomial(2)),
    ]
    vector_field = nn.Sequential(*layers)
    augmenter = Augmenter(augment_idx=1, augment_func=nn.Linear(2, 4))
    model = nn.Sequential(augmenter, NeuralDE(vector_field, solver='dopri5')).to(device)

    learner = TestLearner(model, trainloader=trainloader)
    pl.Trainer(min_epochs=1, max_epochs=1).fit(learner)
def test_augmented_data_control():
    """Data-controlled NeuralDE with IL-Augmentation.

    One-epoch smoke test on the noisy 'spirals' dataset. The vector field
    starts with DataControl(); its first Linear takes 12 features, presumably
    the 6-wide augmented state concatenated with 6 control inputs — confirm.
    """
    X, yn = ToyDataset().generate(n_samples=512, dataset_type='spirals', noise=.4)
    # NOTE(review): `device` is not defined in this function; presumably a
    # module-level global — confirm.
    inputs = torch.Tensor(X).to(device)
    targets = torch.LongTensor(yn.long()).to(device)
    trainloader = data.DataLoader(
        data.TensorDataset(inputs, targets), batch_size=len(X), shuffle=False
    )

    vector_field = nn.Sequential(
        DataControl(), nn.Linear(12, 64), nn.Tanh(), nn.Linear(64, 6)
    )
    augmenter = Augmenter(augment_idx=1, augment_func=nn.Linear(2, 4))
    model = nn.Sequential(augmenter, NeuralDE(vector_field, solver='dopri5')).to(device)

    learner = TestLearner(model, trainloader=trainloader)
    pl.Trainer(min_epochs=1, max_epochs=1).fit(learner)
def test_augmenter_func_is_trained():
    """Test if augment function is trained without explicit definition.

    The Augmenter is only the first stage of the ``nn.Sequential`` handed to
    TestLearner; this checks that one epoch of fitting changes at least one of
    its parameters.
    """
    d = ToyDataset()
    X, yn = d.generate(n_samples=512, dataset_type='spirals', noise=.4)
    # NOTE(review): `device` is not defined in this function; presumably a
    # module-level global — confirm.
    X_train = torch.Tensor(X).to(device)
    y_train = torch.LongTensor(yn.long()).to(device)
    train = data.TensorDataset(X_train, y_train)
    trainloader = data.DataLoader(train, batch_size=len(X), shuffle=False)
    f = nn.Sequential(DataControl(), nn.Linear(12, 64), nn.Tanh(), nn.Linear(64, 6))
    model = nn.Sequential(
        Augmenter(augment_idx=1, augment_func=nn.Linear(2, 4)),
        NeuralDE(f, solver='dopri5')).to(device)
    learn = TestLearner(model, trainloader=trainloader)
    trainer = pl.Trainer(min_epochs=1, max_epochs=1)

    augmenter = model[0]

    def _snapshot():
        # Detached CPU copy of all augmenter parameters. detach() keeps it out of
        # autograd; moving to CPU keeps the comparison valid even if Lightning
        # relocates the module during/after fit (some versions move it back to
        # CPU on teardown).
        return torch.cat([w.detach().flatten().cpu() for w in augmenter.parameters()])

    params_before = _snapshot()
    trainer.fit(learn)
    params_after = _snapshot()
    assert (params_before != params_after).any(), 'Augmenter parameters were not updated by training'
def test_static_datasets_generation():
    """Test generation of (vanilla) version of all static_datasets.

    Smoke test: each listed dataset type must generate (512 samples, noise
    0.2) without raising and return a 2-tuple. Previously this reused the name
    ``test_adjoint_autograd``, which redefined (and so silently hid) the
    adjoint-vs-autograd gradient test of that name in this module.
    """
    d = ToyDataset()
    for dataset_type in ['moons', 'spirals', 'spheres', 'gaussians', 'gaussians_spiral', 'diffeqml']:
        X, yn = d.generate(n_samples=512, noise=0.2, dataset_type=dataset_type)