def test_ivp_legacy_signature():
    """Check IVP's handling of its legacy keyword names.

    Passing ``x_0`` and/or ``x_0_prime`` is expected to emit a FutureWarning.
    Combining ``x_0`` with ``u_0``, or ``x_0_prime`` with ``u_0_prime``, is
    expected to raise KeyError.
    """
    # (positional args, keyword args) for each call that should warn
    warning_calls = [
        ((0,), dict(x_0=1)),
        ((0, 1), dict(x_0_prime=2)),
        ((0,), dict(x_0=1, x_0_prime=2)),
    ]
    for call_args, call_kwargs in warning_calls:
        with warns(FutureWarning):
            IVP(*call_args, **call_kwargs)

    # keyword combinations that should be rejected outright
    conflicting_kwargs = [
        dict(x_0=1, u_0=2),
        dict(x_0_prime=1, u_0_prime=2),
    ]
    for call_kwargs in conflicting_kwargs:
        with raises(KeyError):
            IVP(0, **call_kwargs)
def test_ensemble_condition(x0, x1, y0, y1, ones, net12):
    """Check that EnsembleCondition constrains each network output column separately.

    With two members, column 0 of ``net12``'s output is expected to satisfy
    ``IVP(x0, y0)`` and column 1 to satisfy ``IVP(x1, y0, y1)``. A single-member
    ensemble applied to a one-output network is expected to behave like the
    lone IVP.
    """
    cond = EnsembleCondition(
        IVP(x0, y0),
        IVP(x1, y0, y1),
    )

    # Column 0 is governed by IVP(x0, y0).
    x = x0 * ones
    y = cond.enforce(net12, x)
    ya = y[:, 0:1]
    assert all_close(ya, y0), "y(x_0) != y_0"

    # Column 1 is governed by IVP(x1, y0, y1): value y0 and slope y1 at x1.
    x = x1 * ones
    y = cond.enforce(net12, x)
    yb = y[:, 1:2]
    assert all_close(yb, y0), "y(x_1) != y_0"
    assert all_close(diff(yb, x), y1), "y'(x_1) != y'_1"

    # Single-member ensemble on a one-output network. Use a fresh local name
    # instead of rebinding the multi-output ``net12`` fixture.
    net_single = FCNN(1, 1)
    cond = EnsembleCondition(
        IVP(x0, y0),
    )
    x = x0 * ones
    y = cond.enforce(net_single, x)
    assert all_close(y, y0), "y(x_0) != y_0"
def test_ivp(x0, y0, y1, ones, net11):
    """Check that IVP pins the network's value, and optionally its slope, at x0."""
    x = x0 * ones

    # Value-only initial condition: y(x0) == y0.
    value_only = IVP(x0, y0)
    y = value_only.enforce(net11, x)
    assert torch.isclose(y, y0 * ones).all(), "y(x_0) != y_0"

    # Value-and-slope initial condition: y(x0) == y0 and y'(x0) == y1.
    value_and_slope = IVP(x0, y0, y1)
    y = value_and_slope.enforce(net11, x)
    assert all_close(y, y0), "y(x_0) != y_0"
    assert all_close(diff(y, x), y1), "y'(x_0) != y'_0"
def test_ivp_module_level_inputs():
    """Run the IVP value/slope checks with a freshly built FCNN(1, 1).

    Previously this function was also named ``test_ivp``. Because it comes after
    the fixture-based ``test_ivp`` defined earlier in the file, it rebound that
    module attribute. pytest then collected only this copy, and the
    fixture-based test never ran. Giving it a distinct name lets both run.

    Reads the module-level names ``x0``, ``y0``, ``y1`` and ``ones``; none of
    them is defined in this function.
    """
    x = x0 * ones
    net = FCNN(1, 1)

    # Value-only initial condition: y(x0) == y0.
    cond = IVP(x0, y0)
    y = cond.enforce(net, x)
    assert torch.isclose(y, y0 * ones).all(), "y(x_0) != y_0"

    # Value-and-slope initial condition: y(x0) == y0 and y'(x0) == y1.
    cond = IVP(x0, y0, y1)
    y = cond.enforce(net, x)
    assert all_close(y, y0), "y(x_0) != y_0"
    assert all_close(diff(y, x), y1), "y'(x_0) != y'_0"
def test_ibvp_1d():
    """Check that IBVP1D honors the initial condition and every supported x-boundary pairing.

    Random boundary data are built from small FCNNs constrained with IVP /
    DirichletBVP, chosen to agree with the initial profile ``ut0`` at ``t0``.
    Each Dirichlet/Neumann combination is tested at the left (``x0``) and
    right (``x1``) edges. For each one, the output enforced on a random
    ``FCNN(2, 1)`` is compared with ``ut0`` along ``t = t0``, and with the
    boundary data along ``x = x0`` and ``x = x1``. Pairings IBVP1D does not
    support are expected to raise ``NotImplementedError``.

    Reads the module-level names ``x0``, ``x1``, ``ones`` and ``N_SAMPLES``;
    none of them is defined in this function.
    """
    t0, t1 = random.random(), random.random()
    # Only the two edge values of the initial profile are used.
    u00, u10 = random.random(), random.random()

    # Initial condition ut0(x) = u(x, t0), pinned to u00 at x0 and u10 at x1.
    net_ut0 = FCNN(1, 1)
    cond_ut0 = DirichletBVP(x0, u00, x1, u10)
    ut0 = lambda x: cond_ut0.enforce(net_ut0, x)

    # Dirichlet data g(t) = u(x0, t) and h(t) = u(x1, t). Starting them at
    # u00 / u10 when t = t0 keeps them consistent with ut0.
    net_g, net_h = FCNN(1, 1), FCNN(1, 1)
    cond_g = IVP(t0, u00)
    cond_h = IVP(t0, u10)
    g = lambda t: cond_g.enforce(net_g, t)
    h = lambda t: cond_h.enforce(net_h, t)

    # Neumann data p(t) = u'_x(x0, t) and q(t) = u'_x(x1, t). Their values at
    # t0 are ut0's slopes at the edges, so they are consistent with ut0.
    x = x0 * ones
    p0 = diff(ut0(x), x)[0, 0].item()
    x = x1 * ones
    q0 = diff(ut0(x), x)[0, 0].item()
    p1, q1 = random.random(), random.random()
    net_p, net_q = FCNN(1, 1), FCNN(1, 1)
    cond_p = DirichletBVP(t0, p0, t1, p1)
    cond_q = DirichletBVP(t0, q0, t1, q1)
    p = lambda t: cond_p.enforce(net_p, t)
    q = lambda t: cond_q.enforce(net_q, t)

    # Random network whose output each condition reparameterizes.
    net = FCNN(2, 1)

    def _check(condition, left_kind, left_target, right_kind, right_target):
        """Assert `condition` enforces ut0 at t0 and the given BC at each x-edge.

        A "Dirichlet" BC compares the enforced value with ``target(t)``; a
        "Neumann" BC compares its derivative with respect to x instead.
        """
        x = torch.linspace(x0, x1, N_SAMPLES, requires_grad=True).view(-1, 1)
        t = t0 * ones
        assert all_close(condition.enforce(net, x, t), ut0(x)), "initial condition not satisfied"

        edges = (
            ("left", x0, left_kind, left_target),
            ("right", x1, right_kind, right_target),
        )
        for side, x_edge, kind, target in edges:
            x = x_edge * ones
            t = torch.linspace(t0, t1, N_SAMPLES, requires_grad=True).view(-1, 1)
            u = condition.enforce(net, x, t)
            if kind == "Neumann":
                u = diff(u, x)
            assert all_close(u, target(t)), f"{side} {kind} BC not satisfied"

    _check(IBVP1D(x0, x1, t0, ut0, x_min_val=g, x_max_val=h), "Dirichlet", g, "Dirichlet", h)
    _check(IBVP1D(x0, x1, t0, ut0, x_min_val=g, x_max_prime=q), "Dirichlet", g, "Neumann", q)
    _check(IBVP1D(x0, x1, t0, ut0, x_min_prime=p, x_max_val=h), "Neumann", p, "Dirichlet", h)
    _check(IBVP1D(x0, x1, t0, ut0, x_min_prime=p, x_max_prime=q), "Neumann", p, "Neumann", q)

    # Boundary pairings that IBVP1D is expected to reject.
    unsupported = [
        # no condition at either edge
        dict(x_min_val=None, x_min_prime=None, x_max_val=None, x_max_prime=None),
        # both a value and a derivative at x_min, nothing at x_max
        dict(x_min_val=lambda t: 0, x_min_prime=lambda t: 0, x_max_val=None, x_max_prime=None),
        # only a derivative at x_min, nothing at x_max
        dict(x_min_val=None, x_min_prime=lambda t: 0, x_max_val=None, x_max_prime=None),
    ]
    for boundary_kwargs in unsupported:
        with raises(NotImplementedError):
            IBVP1D(t_min=0, t_min_val=lambda x: 0, x_min=0, x_max=1, **boundary_kwargs)
import numpy as np import torch from neurodiffeq import diff from neurodiffeq.networks import FCNN from neurodiffeq.conditions import IVP, NoCondition from neurodiffeq.generators import Generator1D from neurodiffeq.solvers import GenericSolver, GenericSolution, BaseSolver from neurodiffeq.solvers import Solver1D, Solver2D, SolverSpherical from neurodiffeq.solvers import SolutionSphericalHarmonics SPECIFIC_SOLVERS = [Solver1D, Solver2D, SolverSpherical] N_SAMPLES = 64 T_MIN, T_MAX = 0.0, 1.0 DIFF_EQS = lambda u, t: [diff(u, t) + u] CONDITIONS = [IVP(0, 1)] @pytest.fixture def generators(): return dict( train=Generator1D(64, t_min=T_MIN, t_max=T_MAX, method='uniform'), valid=Generator1D(64, t_min=T_MIN, t_max=T_MAX, method='equally-spaced'), ) @pytest.fixture def solver(generators): return GenericSolver( diff_eqs=DIFF_EQS, conditions=CONDITIONS,