Exemplo n.º 1
0
def test_additional_loss_term():
    """Smoke-test ``solve_system`` when an ``additional_loss_term`` is given.

    The system solved is::

        -(1/2) * y1'' - 3 - y2 * y1 = 0
        y2' = 0

    Each unknown gets the Dirichlet conditions u(0) = 0 and u(2) = 0 on
    t in [0, 2]. The extra loss term is the sum of squares of y2. Only 10
    epochs are run, so the assertions are structural only:

    - a ``Solution`` is returned;
    - the loss history is a dict holding 'train_loss' and 'valid_loss' lists;
    - those two lists have the same length.
    """
    def particle_squarewell(y1, y2, t):
        # One residual per unknown, in the same order as the conditions list.
        # NOTE(review): the name suggests y2 plays the role of a constant energy
        # eigenvalue in a square-well problem; confirm with the test's author.
        residual_y1 = (-1 / 2) * diff(y1, t, order=2) - 3 - y2 * y1
        residual_y2 = diff(y2, t)
        return [residual_y1, residual_y2]

    def zero_y2(y1, y2, t):
        # Extra penalty: sum of squares of y2 over the evaluated points.
        return torch.sum(y2 ** 2)

    # Build two *separate* condition objects, one per unknown.
    boundary_conditions = [
        DirichletBVP(t_0=0, u_0=0, t_1=2, u_1=0) for _ in range(2)
    ]

    solution, history = solve_system(
        ode_system=particle_squarewell,
        conditions=boundary_conditions,
        additional_loss_term=zero_y2,
        t_min=0.0,
        t_max=2.0,
        max_epochs=10,
    )

    assert isinstance(solution, Solution)
    assert isinstance(history, dict)
    for metric in ('train_loss', 'valid_loss'):
        assert metric in history
        assert isinstance(history[metric], list)
    assert len(history['train_loss']) == len(history['valid_loss'])
Exemplo n.º 2
0
def test_additional_loss_term():
    """Check that an ``additional_loss_term`` penalizing y2 drives y2 to zero.

    The system is::

        -(1/2) * y1'' - 3 - y2 * y1 = 0
        y2' = 0

    Each unknown gets the Dirichlet conditions u(0) = 0 and u(2) = 0 on
    t in [0, 2]. The extra loss term is the sum of squares of y2. After
    1000 epochs the solution is sampled at 100 evenly spaced points in
    [0, 2]. Every y2 sample must lie within an absolute tolerance of 0.02
    of zero.
    """
    def particle_squarewell(y1, y2, t):
        # One residual per unknown, in the same order as the conditions list.
        residual_y1 = (-1 / 2) * diff(y1, t, order=2) - 3 - y2 * y1
        residual_y2 = diff(y2, t)
        return [residual_y1, residual_y2]

    def zero_y2(y1, y2, t):
        # Extra penalty: sum of squares of y2 over the evaluated points.
        return torch.sum(y2 ** 2)

    # Two distinct condition objects, one per unknown (x_0/x_1 keyword API).
    boundary_conditions = [
        DirichletBVP(t_0=0, x_0=0, t_1=2, x_1=0) for _ in range(2)
    ]

    solution, _ = solve_system(
        ode_system=particle_squarewell,
        conditions=boundary_conditions,
        additional_loss_term=zero_y2,
        t_min=0.0,
        t_max=2.0,
        max_epochs=1000,
    )

    sample_ts = np.linspace(0.0, 2.0, 100)
    _y1_values, y2_values = solution(sample_ts, as_type='np')
    near_zero = isclose(y2_values, np.zeros_like(y2_values), atol=0.02)
    assert near_zero.all()
Exemplo n.º 3
0
def test_lotka_volterra():
    """Compare a network solution of Lotka-Volterra against ``odeint``.

    The system, with every coefficient set to 1, is::

        x' = alpha * x - beta * x * y      (prey)
        y' = delta * x * y - gamma * y     (predator)

    Initial values are x(0) = 1.5 and y(0) = 1.0, on t in [0, 12]. Each
    unknown gets its own ``FCNN(hidden_units=(32, 32), actv=SinActv)``.
    Training runs for 12000 epochs with a ``Monitor`` attached. Both
    components must match a numerical reference from ``odeint`` within
    atol=0.1 at 100 evenly spaced points.
    """
    alpha, beta, delta, gamma = 1, 1, 1, 1

    def lotka_volterra(x, y, t):
        # Residuals x' - f(x, y) and y' - g(x, y), one per unknown.
        return [
            diff(x, t) - (alpha * x - beta * x * y),
            diff(y, t) - (delta * x * y - gamma * y),
        ]

    initial_conditions = [
        IVP(t_0=0.0, x_0=1.5),
        IVP(t_0=0.0, x_0=1.0),
    ]
    # Two independently constructed networks, created in order (prey, predator).
    networks = [FCNN(hidden_units=(32, 32), actv=SinActv) for _ in range(2)]

    solution, _ = solve_system(
        ode_system=lotka_volterra,
        conditions=initial_conditions,
        t_min=0.0,
        t_max=12,
        nets=networks,
        max_epochs=12000,
        # NOTE(review): presumably used for periodic progress plots; confirm
        # that it is acceptable in a headless test run.
        monitor=Monitor(t_min=0.0, t_max=12, check_every=100),
    )

    sample_ts = np.linspace(0, 12, 100)
    prey_net, pred_net = solution(sample_ts, as_type='np')

    def lv_rhs(state, t):
        # Right-hand side in the (state, t) signature that odeint is called with.
        prey, pred = state[0], state[1]
        return [prey * alpha - beta * prey * pred, delta * prey * pred - gamma * pred]

    reference = odeint(lv_rhs, [1.5, 1.0], sample_ts)
    prey_ref, pred_ref = reference[:, 0], reference[:, 1]

    assert isclose(prey_net, prey_ref, atol=0.1).all()
    assert isclose(pred_net, pred_ref, atol=0.1).all()
    print('Lotka Volterra test passed.')
Exemplo n.º 4
0
def test_ode_system():
    """Solve the parametric-circle system and compare it to sin/cos.

    The system is::

        x1' = x2
        x2' = -x1

    with x1(0) = 0 and x2(0) = 1. Its exact solution is
    x1 = sin(t), x2 = cos(t). After 5000 epochs on t in [0, 2*pi], both
    components must match the exact solution within atol=0.1 at 100
    evenly spaced points.
    """
    def parametric_circle(x1, x2, t):
        # Residuals of x1' - x2 = 0 and x2' + x1 = 0.
        return [diff(x1, t) - x2, diff(x2, t) + x1]

    initial_conditions = [IVP(t_0=0.0, x_0=0.0), IVP(t_0=0.0, x_0=1.0)]

    solution, _ = solve_system(
        ode_system=parametric_circle,
        conditions=initial_conditions,
        t_min=0.0,
        t_max=2 * np.pi,
        max_epochs=5000,
    )

    sample_ts = np.linspace(0, 2 * np.pi, 100)
    x1_net, x2_net = solution(sample_ts, as_type='np')
    comparisons = ((x1_net, np.sin(sample_ts)), (x2_net, np.cos(sample_ts)))
    for approx, exact in comparisons:
        assert isclose(approx, exact, atol=0.1).all()
    print('solve_system basic test passed.')
Exemplo n.º 5
0
def test_ode_system():
    """Smoke-test ``solve_system`` on the parametric-circle system.

    The system is::

        u1' = u2
        u2' = -u1

    with u1(0) = 0 and u2(0) = 1, on t in [0, 2*pi]. Only 10 epochs are
    run, so the assertions are structural only:

    - a ``Solution`` is returned;
    - the loss history is a dict holding 'train_loss' and 'valid_loss' lists;
    - those two lists have the same length.
    """
    def parametric_circle(u1, u2, t):
        # Residuals of u1' - u2 = 0 and u2' + u1 = 0.
        return [diff(u1, t) - u2, diff(u2, t) + u1]

    initial_conditions = [IVP(t_0=0.0, u_0=0.0), IVP(t_0=0.0, u_0=1.0)]

    solution, history = solve_system(
        ode_system=parametric_circle,
        conditions=initial_conditions,
        t_min=0.0,
        t_max=2 * np.pi,
        max_epochs=10,
    )

    assert isinstance(solution, Solution)
    assert isinstance(history, dict)
    for metric in ('train_loss', 'valid_loss'):
        assert metric in history
        assert isinstance(history[metric], list)
    assert len(history['train_loss']) == len(history['valid_loss'])