예제 #1
0
def test_dirichlet_bvp_spherical():
    """Check that DirichletBVPSpherical reproduces its boundary functions.

    Two configurations are exercised:

    * two-sided: ``f`` is imposed at ``r_0`` and ``g`` at ``r_1``;
    * single-ended: only ``f`` is imposed at ``r_0``.

    In each case the enforced network output evaluated on the boundary radius
    must match the corresponding boundary function.
    """
    r0, r1 = x0, x1
    r2 = (r0 + r1) / 2

    no_condition = NoCondition()
    # B.C. for the interior boundary (r_min)
    net_f = FCNN(2, 1)
    f = lambda th, ph: no_condition.enforce(net_f, th, ph)

    # B.C. for the exterior boundary (r_max)
    net_g = FCNN(2, 1)
    g = lambda th, ph: no_condition.enforce(net_g, th, ph)

    condition = DirichletBVPSpherical(r_0=r0, f=f, r_1=r1, g=g)

    net = FCNN(3, 1)
    theta = torch.rand(N_SAMPLES, 1) * np.pi
    phi = torch.rand(N_SAMPLES, 1) * 2 * np.pi

    # On r = r0 the enforced solution must equal f.
    r = r0 * ones
    assert all_close(condition.enforce(net, r, theta, phi),
                     f(theta, phi)), "inner Dirichlet BC not satisfied"

    # On r = r1 the enforced solution must equal g; the message names the
    # outer boundary so a failure here is not mistaken for the inner one.
    r = r1 * ones
    assert all_close(condition.enforce(net, r, theta, phi),
                     g(theta, phi)), "outer Dirichlet BC not satisfied"

    # Single-ended variant: only one boundary (r_0 = r2) is constrained.
    condition = DirichletBVPSpherical(r_0=r2, f=f)
    r = r2 * ones
    assert all_close(condition.enforce(net, r, theta, phi),
                     f(theta, phi)), "single ended BC not satisfied"
예제 #2
0
def test_no_condition():
    """NoCondition.enforce must give exactly the bare network's output.

    Network sizes are taken pairwise from ``range(1, 5)`` for both the input
    and output dimension, so only equal input/output widths are exercised.
    """
    max_inputs = 5
    max_outputs = 5

    input_sizes = range(1, max_inputs)
    output_sizes = range(1, max_outputs)
    for n_in, n_out in zip(input_sizes, output_sizes):
        # One column tensor per input coordinate.
        xs = []
        for _ in range(n_in):
            xs.append(torch.rand(N_SAMPLES, 1, requires_grad=True))
        net = FCNN(n_in, n_out)

        enforced = NoCondition().enforce(net, *xs)
        stacked_inputs = torch.cat(xs, dim=1)
        raw = net(stacked_inputs)
        assert (enforced == raw).all()
예제 #3
0
def solver():
    """Return a Solver1D on t in [0.0, 1.0] with a single NoCondition.

    The ODE system has one residual, ``u + diff(u, t)``.
    """
    def ode_system(u, t):
        return [u + diff(u, t)]

    unconstrained = [NoCondition()]
    return Solver1D(ode_system=ode_system, conditions=unconstrained, t_min=0.0, t_max=1.0)
예제 #4
0
def test_monitor_2d(solution_style, history):
    """Smoke test: Monitor2D.check runs for N_FUNCTIONS unconstrained 2-input networks.

    ``solution_style`` and ``history`` are supplied by the caller (presumably
    pytest parametrization/fixtures) and forwarded unchanged.
    """
    corner_min, corner_max = (0, 0), (1, 1)
    monitor = Monitor2D(corner_min, corner_max, check_every=100, solution_style=solution_style)

    nets = []
    for _ in range(N_FUNCTIONS):
        nets.append(FCNN(2, 1))

    conditions = []
    for _ in range(N_FUNCTIONS):
        conditions.append(NoCondition())

    monitor.check(nets, conditions, history)
예제 #5
0
def test_legacy_max_degree_in_solution_spherical_harmonics():
    """Constructing SolutionSphericalHarmonics with ``max_degree`` must emit a FutureWarning."""
    with pytest.warns(FutureWarning):
        # Arguments are built inside the context so every warning raised while
        # constructing them is observed, exactly as with an inline call.
        legacy_kwargs = dict(
            nets=[FCNN(1, 10)],
            conditions=[NoCondition()],
            max_degree=4,
        )
        SolutionSphericalHarmonics(**legacy_kwargs)
예제 #6
0
def test_stream_plot_monitor(mask_fn, nx, ny, specify_field_names):
    """Exercise StreamPlotMonitor2D construction and two ``check`` calls.

    When ``specify_field_names`` is truthy, one name per entry of ``pairs`` is
    supplied, and a second construction whose ``field_names`` length (3)
    differs from its ``pairs`` length (2) is expected to raise ValueError.
    """
    nets = []
    for _ in range(5):
        nets.append(FCNN(2, 1, hidden_units=(3, )))
    conditions = [NoCondition() for _ in range(len(nets))]
    pairs = [(0, 1), (2, 3), (0, 3), 4, 2]

    field_names = list(map(str, range(len(pairs)))) if specify_field_names else None

    monitor = StreamPlotMonitor2D(
        xy_min=(-1, -1),
        xy_max=(1, 1),
        nx=nx,
        ny=ny,
        pairs=pairs,
        mask_fn=mask_fn,
        equal_aspect=True,
        field_names=field_names,
    )

    if specify_field_names:
        mismatched = dict(
            xy_min=(-1, -1),
            xy_max=(1, 1),
            pairs=[0, 1],
            field_names=['a', 'b', 'c'],
        )
        with pytest.raises(ValueError):
            StreamPlotMonitor2D(**mismatched)

    # Run once with the networks in order, then once with them reversed.
    for net_order in (nets, nets[::-1]):
        monitor.check(net_order, conditions, history=None)
예제 #7
0
def test_inf_dirichlet_bvp_spherical():
    """Check InfDirichletBVPSpherical at the inner radius and at a very large radius.

    ``f`` is the boundary function at ``r_0``; ``g`` is the boundary function
    for r -> infinity, which is probed numerically at ``r = 1e15``.
    """
    r0 = random.random()
    r1 = 1e15  # stand-in for "infinity"
    no_condition = NoCondition()
    net_f, net_g = FCNN(2, 1), FCNN(2, 1)

    def f(th, ph):
        # B.C. for the interior boundary (r=r_min)
        return no_condition.enforce(net_f, th, ph)

    def g(th, ph):
        # B.C. for the exterior boundary (r=infinity)
        return no_condition.enforce(net_g, th, ph)

    net = FCNN(3, 1)
    condition = InfDirichletBVPSpherical(r_0=r0, f=f, g=g, order=1)
    theta = torch.rand(10, 1) * np.pi
    phi = torch.rand(10, 1) * (2 * np.pi)

    boundary_checks = (
        (r0, f, "inner DirichletBC not satisfied"),
        (r1, g, "Infinity DirichletBC not satisfied"),
    )
    for radius, boundary_fn, message in boundary_checks:
        r = radius * ones
        assert all_close(condition.enforce(net, r, theta, phi), boundary_fn(theta, phi)), message
예제 #8
0
def test_train_generator_spherical():
    """Validate GeneratorSpherical sample ranges, its use in solve_spherical, and bad arguments.

    Checks that sampled angles lie in [0, pi] x [0, 2*pi], that radii respect
    the requested bounds, that the generators can be passed to
    ``solve_spherical`` for one epoch, and that invalid constructor arguments
    raise ValueError.
    """
    def _assert_angles_in_range(th, ph):
        assert (0. <= th.min()) and (th.max() <= np.pi)
        assert (0. <= ph.min()) and (ph.max() <= 2 * np.pi)

    pde = laplacian_spherical
    condition = NoCondition()

    # Training points: radii strictly inside (0, 1).
    train_generator = GeneratorSpherical(
        size=64,
        r_min=0.,
        r_max=1.,
        method='equally-spaced-noisy',
    )
    r, th, ph = train_generator.get_examples()
    assert (0. < r.min()) and (r.max() < 1.)
    _assert_angles_in_range(th, ph)

    # Validation points: r_min == r_max, so every radius must be exactly 1.
    valid_generator = GeneratorSpherical(
        size=64,
        r_min=1.,
        r_max=1.,
        method='equally-radius-noisy',
    )
    r, th, ph = valid_generator.get_examples()
    assert (r == 1).all()
    _assert_angles_in_range(th, ph)

    solve_spherical(
        pde,
        condition,
        0.0,
        1.0,
        train_generator=train_generator,
        valid_generator=valid_generator,
        max_epochs=1,
    )

    # Each of these argument sets is expected to be rejected.
    invalid_kwargs = (
        {'method': 'bad_generator'},
        {'r_min': -1.0},
        {'r_min': 1.0, 'r_max': 0.0},
    )
    for bad_kwargs in invalid_kwargs:
        with raises(ValueError):
            _ = GeneratorSpherical(64, **bad_kwargs)
예제 #9
0
def U(x):
    """Return a 3-tuple of outputs from three fresh ``FCNN(3, 1)`` networks.

    ``x`` is unpacked into the positional inputs of ``NoCondition.enforce``.
    """
    cond = NoCondition()
    # Build all three networks first, then evaluate them in order.
    nets = [FCNN(3, 1) for _ in range(3)]
    outputs = []
    for net in nets:
        outputs.append(cond.enforce(net, *x))
    return tuple(outputs)
예제 #10
0
def u(x):
    """Return the output of one fresh ``FCNN(3, 1)`` passed through NoCondition.enforce.

    ``x`` is unpacked into the positional inputs of ``enforce``.
    """
    return NoCondition().enforce(FCNN(3, 1), *x)
예제 #11
0
def scalar_field(x):
    """Return a single-network field: a fresh ``FCNN(3, 1)`` evaluated via NoCondition.enforce.

    ``x`` is unpacked into the positional inputs of ``enforce``.
    """
    condition = NoCondition()
    network = FCNN(3, 1)
    return condition.enforce(network, *x)
예제 #12
0
def vector_field(x):
    """Return a 3-tuple of components, each from its own fresh ``FCNN(3, 1)``.

    ``x`` is unpacked into the positional inputs of ``NoCondition.enforce``.
    """
    cond = NoCondition()
    components = []
    for _ in range(3):
        # A new network is created immediately before each evaluation.
        components.append(cond.enforce(FCNN(3, 1), *x))
    return tuple(components)
예제 #13
0
def test_monitor_1d(history):
    """Smoke test: Monitor1D.check runs for N_FUNCTIONS default FCNNs with no conditions imposed."""
    monitor = Monitor1D(0, 1)

    nets = []
    for _ in range(N_FUNCTIONS):
        nets.append(FCNN())

    conditions = []
    for _ in range(N_FUNCTIONS):
        conditions.append(NoCondition())

    monitor.check(nets, conditions, history=history)
예제 #14
0
def test_metrics_mointor(history):
    """Smoke test: MetricsMonitor.check runs for N_FUNCTIONS unconstrained ``FCNN(2, 1)`` networks."""
    monitor = MetricsMonitor(check_every=10)

    nets = []
    for _ in range(N_FUNCTIONS):
        nets.append(FCNN(2, 1))

    conditions = []
    for _ in range(N_FUNCTIONS):
        conditions.append(NoCondition())

    monitor.check(nets, conditions, history)