Example #1
0
def test_dirichlet_bvp_spherical_basis():
    """Check that DirichletBVPSphericalBasis reproduces its boundary values.

    A two-ended condition (r_0 and r_1 given) must return R0 at r0 and R1 at
    r1 regardless of the network's output; a single-ended condition must
    return R2 at r2.
    """
    N_COMPONENTS = 25
    r0, r1 = x0, x1
    r2 = (r0 + r1) / 2

    # Random boundary coefficients; drawn in the order R0, R1, R2.
    R0, R1, R2 = (torch.rand(N_SAMPLES, N_COMPONENTS) for _ in range(3))

    two_ended = DirichletBVPSphericalBasis(r_0=r0, R_0=R0, r_1=r1, R_1=R1)
    net = FCNN(1, N_COMPONENTS)

    boundary_cases = (
        (r0, R0, "inner Dirichlet BC not satisfied"),
        (r1, R1, "outer Dirichlet BC not satisfied"),
    )
    for radius, expected, message in boundary_cases:
        assert all_close(two_ended.enforce(net, radius * ones), expected), message

    single_ended = DirichletBVPSphericalBasis(r_0=r2, R_0=R2)
    assert all_close(single_ended.enforce(net, r2 * ones), R2), "single ended BC not satisfied"


def test_electric_potential_gaussian_charged_density():
    """Solve Poisson's equation for a Gaussian charge cloud with two solvers.

    The potential of a spherically symmetric Gaussian charge density is solved
    on the shell r_0 <= r <= r_1, with Dirichlet values taken from the analytic
    solution. It is solved first with a plain network on (r, theta, phi), and
    then with a network that outputs coefficients of a real spherical-harmonics
    basis. Each solve runs for only two epochs, so ``validate`` checks output
    shapes, not accuracy.
    """
    # total charge
    Q = 1.
    # standard deviation of gaussian
    sigma = 1.
    # medium permittivity
    epsilon = 1.
    # Coulomb constant
    k = 1 / (4 * np.pi * epsilon)
    # normalizing coefficient of the 3-D gaussian, so the density integrates to Q
    gaussian_coeff = Q / (sigma**3) / np.power(2 * np.pi, 1.5)
    # distribution of charge
    rho_f = lambda r: gaussian_coeff * torch.exp(-r.pow(2) / (2 * sigma**2))
    # analytic solution, refer to https://en.wikipedia.org/wiki/Poisson%27s_equation
    analytic_solution = lambda r, th, ph: (k * Q / r) * torch.erf(r / (np.sqrt(
        2) * sigma))

    # interior and exterior radius
    r_0, r_1 = 0.1, 3.
    # values at interior and exterior boundary, taken from the analytic solution
    v_0 = (k * Q / r_0) * erf(r_0 / (np.sqrt(2) * sigma))
    v_1 = (k * Q / r_1) * erf(r_1 / (np.sqrt(2) * sigma))

    def validate(solution):
        # Only shapes are compared: two training epochs are not expected to
        # produce an accurate solution.
        generator = GeneratorSpherical(512, r_min=r_0, r_max=r_1)
        rs, thetas, phis = generator.get_examples()
        us = solution(rs, thetas, phis, to_numpy=True)
        vs = analytic_solution(rs, thetas, phis).detach().cpu().numpy()
        assert us.shape == vs.shape

    # solving the problem using normal network (subject to the influence of polar singularity of laplacian operator)

    pde1 = lambda u, r, th, ph: laplacian_spherical(u, r, th, ph) + rho_f(
        r) / epsilon
    condition1 = DirichletBVPSpherical(r_0, lambda th, ph: v_0, r_1,
                                       lambda th, ph: v_1)
    monitor1 = MonitorSpherical(r_0, r_1, check_every=50)
    # solve_spherical must emit a FutureWarning here; pytest.warns fails otherwise.
    # The metrics history returned alongside the solution is not used.
    with pytest.warns(FutureWarning):
        solution1, _ = solve_spherical(
            pde1,
            condition1,
            r_0,
            r_1,
            max_epochs=2,
            return_best=True,
            analytic_solution=analytic_solution,
            monitor=monitor1,
        )
    validate(solution1)

    # solving the problem using spherical harmonics (laplacian computation is optimized)
    max_degree = 2
    # number of spherical harmonics of degree 0..max_degree, i.e. network outputs
    n_components = (max_degree + 1)**2
    harmonics_fn = RealSphericalHarmonics(max_degree=max_degree)
    harmonic_laplacian = HarmonicsLaplacian(max_degree=max_degree)
    pde2 = lambda R, r, th, ph: harmonic_laplacian(R, r, th, ph) + rho_f(
        r) / epsilon
    # Only the first coefficient is non-zero because the boundary values do not
    # depend on angle. NOTE(review): presumably component 0 is the degree-0
    # harmonic and the factor 2 offsets its normalization — confirm against
    # RealSphericalHarmonics.
    R_0 = torch.tensor([v_0 * 2] + [0] * (n_components - 1))
    R_1 = torch.tensor([v_1 * 2] + [0] * (n_components - 1))

    def analytic_solution2(r, th, ph):
        # Analytic solution expressed as basis coefficients, scaled like R_0 / R_1.
        sol = torch.zeros(r.shape[0], n_components)
        sol[:, 0:1] = 2 * analytic_solution(r, th, ph)
        return sol

    condition2 = DirichletBVPSphericalBasis(r_0=r_0,
                                            R_0=R_0,
                                            r_1=r_1,
                                            R_1=R_1,
                                            max_degree=max_degree)
    monitor2 = MonitorSphericalHarmonics(r_0,
                                         r_1,
                                         check_every=50,
                                         harmonics_fn=harmonics_fn)
    net2 = FCNN(n_input_units=1, n_output_units=n_components)
    with pytest.warns(FutureWarning):
        solution2, _ = solve_spherical(
            pde2,
            condition2,
            r_0,
            r_1,
            net=net2,
            max_epochs=2,
            return_best=True,
            analytic_solution=analytic_solution2,
            monitor=monitor2,
            harmonics_fn=harmonics_fn,
        )
    validate(solution2)