def test_additional_loss_term():
    """Check that an additional loss term drives the second component to zero.

    The system couples a second-order equation in ``y1`` with ``y2' = 0``.
    Both components get the same two-point Dirichlet condition (value 0 at
    ``t = 0`` and ``t = 2``). ``zero_y2`` adds the sum of squares of ``y2`` to
    the training loss. The trained ``y2`` is then required to lie within 0.02
    of zero at 100 evenly spaced points on ``[0, 2]``.

    NOTE(review): another function named ``test_additional_loss_term`` is
    defined later in this file. The later definition rebinds the name, so
    pytest only collects that one. Consider renaming one of them.

    NOTE(review): the ``DirichletBVP`` calls here pass ``x_0``/``x_1``, while
    the later version of this test passes ``u_0``/``u_1``. Confirm which
    keyword names the installed neurodiffeq version expects.
    """
    def particle_squarewell(y1, y2, t):
        # Residuals of the coupled system, in order: the y1 equation, then y2' = 0.
        eq_y1 = -0.5 * diff(y1, t, order=2) - 3 - y2 * y1
        eq_y2 = diff(y2, t)
        return [eq_y1, eq_y2]

    def zero_y2(y1, y2, t):
        # Extra penalty: squared values of y2, summed.
        return torch.sum(y2 ** 2)

    # One freshly constructed condition per component; both are identical.
    boundary_conditions = [
        DirichletBVP(t_0=0, x_0=0, t_1=2, x_1=0)
        for _ in range(2)
    ]

    solution, _ = solve_system(
        ode_system=particle_squarewell,
        conditions=boundary_conditions,
        additional_loss_term=zero_y2,
        t_min=0.0,
        t_max=2.0,
        max_epochs=1000,
    )

    grid = np.linspace(0.0, 2.0, 100)
    _, y2_net = solution(grid, as_type='np')
    near_zero = isclose(y2_net, np.zeros_like(y2_net), atol=0.02)
    assert near_zero.all()
# Example #2
# 0
def test_additional_loss_term():
    """Smoke-test ``solve_system`` with an ``additional_loss_term`` supplied.

    Trains for only 10 epochs on a coupled two-component system. Both
    components use the same two-point Dirichlet condition (value 0 at
    ``t = 0`` and ``t = 2``). The test checks only the shape of the result:
    a ``Solution`` instance, plus a dict containing ``'train_loss'`` and
    ``'valid_loss'`` lists of equal length.

    NOTE(review): an earlier function in this file has the same name. This
    definition rebinds it, so the earlier test is never collected by pytest.
    """
    def particle_squarewell(y1, y2, t):
        # Residuals: second-order equation in y1 coupled to y2, and y2' = 0.
        return [
            (-1 / 2) * diff(y1, t, order=2) - 3 - y2 * y1,
            diff(y2, t),
        ]

    def zero_y2(y1, y2, t):
        # Extra penalty added to the training loss: sum of squares of y2.
        return torch.sum(y2 ** 2)

    boundary_conditions = [
        DirichletBVP(t_0=0, u_0=0, t_1=2, u_1=0),
        DirichletBVP(t_0=0, u_0=0, t_1=2, u_1=0),
    ]

    solution, history = solve_system(
        ode_system=particle_squarewell,
        conditions=boundary_conditions,
        additional_loss_term=zero_y2,
        t_min=0.0,
        t_max=2.0,
        max_epochs=10,
    )

    assert isinstance(solution, Solution)
    assert isinstance(history, dict)

    train_key, valid_key = 'train_loss', 'valid_loss'
    for key in (train_key, valid_key):
        assert key in history
        assert isinstance(history[key], list)
    # Training and validation losses are expected once per epoch each.
    assert len(history[train_key]) == len(history[valid_key])
def test_ode_bvp():
    """Solve a harmonic-oscillator BVP and compare it with the exact solution.

    The problem is ``x'' + x = 0`` with ``x(0) = 0`` and ``x(1.5*pi) = -1``.
    The general solution is ``A*sin(t) + B*cos(t)``. ``x(0) = 0`` gives
    ``B = 0``, and ``sin(1.5*pi) = -1`` then gives ``A = 1``, so the exact
    solution is ``sin(t)``. The network solution must match it within 0.1 at
    100 evenly spaced points on ``[0, 1.5*pi]``.

    NOTE(review): this test passes ``x_0``/``x_1`` to ``DirichletBVP``, while
    other code in this file passes ``u_0``/``u_1``. Confirm which keywords the
    installed neurodiffeq version accepts.
    """
    def oscillator(x, t):
        # Residual of x'' + x = 0. A named def rather than a lambda (PEP 8 E731).
        return diff(x, t, order=2) + x

    # Right end of the domain; shared by the condition, the solver and the grid.
    t_max = 1.5 * np.pi

    bound_val_ho = DirichletBVP(t_0=0.0, x_0=0.0, t_1=t_max, x_1=-1.0)
    solution_ho, _ = solve(
        ode=oscillator,
        condition=bound_val_ho,
        max_epochs=3000,
        t_min=0.0,
        t_max=t_max,
    )

    ts = np.linspace(0, t_max, 100)
    x_net = solution_ho(ts, as_type='np')
    x_ana = np.sin(ts)
    assert isclose(x_net, x_ana, atol=0.1).all()
    print('BVP basic test passed.')