Ejemplo n.º 1
0
def vjp_solve_eval_impl(
    g: np.array,
    fenics_solution: fenics.Function,
    fenics_residual: ufl.Form,
    fenics_inputs: List[FenicsVariable],
    bcs: List[fenics.DirichletBC],
) -> Tuple[np.array]:
    """Compute the gradients of the output with respect to the inputs.

    Uses the adjoint method: the adjoint of the residual's Jacobian with
    respect to the solution is assembled and solved against the incoming
    covector ``g``; the resulting adjoint solution is then contracted with
    the adjoint of the residual's derivative with respect to each input.

    Args:
        g: Covector (adjoint seed) associated with the solution, as a numpy
            array that ``numpy_to_fenics`` can convert using
            ``fenics_solution`` as the template.
        fenics_solution: Solution ``u`` of the problem ``F(u, inputs) = 0``.
        fenics_residual: Residual form ``F``.
        fenics_inputs: Inputs (``fenics.Function`` or ``fenics.Constant``)
            with respect to which gradients are computed.
        bcs: Dirichlet boundary conditions of the forward problem. They are
            NOT modified; homogenised copies are used for the adjoint solve.

    Returns:
        Tuple with one numpy array per entry of ``fenics_inputs``.
    """
    # Convert tangent covector (adjoint) to a FEniCS variable
    adj_value = numpy_to_fenics(g, fenics_solution).vector()

    F = fenics_residual
    u = fenics_solution
    V = u.function_space()
    dFdu = fenics.derivative(F, u)
    adFdu = ufl.adjoint(
        dFdu, reordered_arguments=ufl.algorithms.extract_arguments(dFdu)
    )

    # Build the adjoint operator as a bilinear form and assemble it
    u_adj = fenics.Function(V)
    adj_F = ufl.action(adFdu, u_adj)
    adj_F = ufl.replace(adj_F, {u_adj: fenics.TrialFunction(V)})
    adj_F_assembled = fenics.assemble(adj_F)

    # The adjoint problem needs homogeneous Dirichlet conditions. Homogenise
    # copies so that the caller's boundary conditions keep their values and
    # remain usable for subsequent forward solves.
    for bc in bcs:
        homogenized_bc = fenics.DirichletBC(bc)
        homogenized_bc.homogenize()
        homogenized_bc.apply(adj_F_assembled)
        homogenized_bc.apply(adj_value)

    fenics.solve(adj_F_assembled, u_adj.vector(), adj_value)

    fenics_grads = []
    for fenics_input in fenics_inputs:
        # NOTE(review): for Constant inputs V is not updated, so the trial
        # space is that of the solution or of the most recent Function input
        # - confirm this is intended.
        if isinstance(fenics_input, fenics.Function):
            V = fenics_input.function_space()
        dFdm = fenics.derivative(F, fenics_input, fenics.TrialFunction(V))
        adFdm = fenics.adjoint(dFdm)
        result = fenics.assemble(-adFdm * u_adj)
        if isinstance(fenics_input, fenics.Constant):
            # Collapse the assembled vector to a single scalar sensitivity
            fenics_grad = fenics.Constant(result.sum())
        else:  # fenics.Function
            fenics_grad = fenics.Function(V, result)
        fenics_grads.append(fenics_grad)

    # Convert FEniCS gradients to jax array representation
    return tuple(np.asarray(fenics_to_numpy(fg)) for fg in fenics_grads)
Ejemplo n.º 2
0
def construct_forward_func_and_derivatives(
    z_func_space,
    x_func_space,
    define_forms,
    define_boundary_conditions,
    observation_coordinates,
    prior_covar_sqrt=None,
    bilinear_form_is_symmetric=False,
):
    """Build forward-model callables and derivative callables for a FEniCS problem.

    The variational problem is given by the forms ``A, b`` returned by
    ``define_forms(z, x_test, x_trial)``, where ``z`` is a ``fenics.Function``
    on ``z_func_space`` whose degrees of freedom are set to
    ``prior_covar_sqrt @ v_array`` before each solve. The solution ``x`` lives
    on ``x_func_space`` and is observed at the degrees of freedom whose
    coordinates are closest to each row of ``observation_coordinates``.

    All returned callables share the FEniCS objects created here (``z``,
    ``x`` and the work functions), which are updated in place on every call.

    Args:
        z_func_space: Function space of the input field ``z``.
        x_func_space: Function space of the solution ``x``.
        define_forms: Callable ``(z, x_test, x_trial) -> (A, b)`` returning
            the bilinear and linear forms.
        define_boundary_conditions: Callable ``(x_func_space) -> iterable`` of
            boundary conditions supporting ``homogenize()``. It is called
            twice so that an independent homogenized set can be built.
        observation_coordinates: Array whose first axis indexes observations
            (presumably of shape ``(dim_y, spatial_dim)`` - TODO confirm).
        prior_covar_sqrt: Object supporting ``@`` with a 1D array and ``.T``;
            defaults to an identity ``LinearOperator`` of shape
            ``(dim_z, dim_z)``.
        bilinear_form_is_symmetric: If True, a single LU solver built with
            ``fenics.assemble_system`` is used for both forward and adjoint
            solves and ``fenics.adjoint(A)`` is never formed.

    Returns:
        Tuple ``(solution_func, forward_func, vjp_forward_func,
        jacob_forward_func, mhp_forward_func, (dim_x, dim_y, dim_z))``.
    """

    # Get problem dimensions
    dim_x = x_func_space.dim()
    dim_y = observation_coordinates.shape[0]
    dim_z = z_func_space.dim()

    # Set up observation operator
    # (point observations at DOFs closest to observation coordinates)
    dof_coordinates = x_func_space.tabulate_dof_coordinates()
    # Squared distance between every observation point and every DOF,
    # minimised over the DOF axis to pick one DOF index per observation
    observation_dof_indices = np.argmin(
        ((dof_coordinates[None, :, :] -
          observation_coordinates[:, None])**2).sum(-1),
        -1,
    )
    # Dense selection matrix: row i has a single 1 at the DOF of observation i
    observation_matrix = np.zeros((dim_y, dim_x))
    observation_matrix[np.arange(dim_y), observation_dof_indices] = 1

    # Default to identity prior covariance if prior_covar_sqrt is None
    if prior_covar_sqrt is None:
        prior_covar_sqrt = LinearOperator(shape=(dim_z, dim_z),
                                          matvec=lambda x: x,
                                          rmatvec=lambda x: x)

    # Construct function objects
    # (these act as in-place work buffers referenced symbolically by the forms
    # below; the closures overwrite their vectors before assembling/solving)
    z = fenics.Function(z_func_space)
    x = fenics.Function(x_func_space)
    x_trial = fenics.TrialFunction(x_func_space)
    x_test = fenics.TestFunction(x_func_space)
    h = fenics.Function(x_func_space)
    v_H = fenics.Function(x_func_space)
    k = fenics.Function(x_func_space)
    dAx_dz_m = fenics.Function(x_func_space)
    # One work function per observation for the second-derivative forms
    m_list = [fenics.Function(z_func_space) for _ in range(dim_y)]
    h_A_inv_list = [fenics.Function(x_func_space) for _ in range(dim_y)]
    A_inv_dAx_dz_m_list = [fenics.Function(x_func_space) for _ in range(dim_y)]

    # Construct bilinear and linear forms defining variational form of problem
    A, b = define_forms(z, x_test, x_trial)
    boundary_conditions = define_boundary_conditions(x_func_space)

    # Precompute additional forms required for calculating first-derivatives
    if not bilinear_form_is_symmetric:
        # Only needed (and only defined) in the non-symmetric case
        adjoint_A = fenics.adjoint(A)
    # Derivative with respect to z of the form A with its arguments replaced
    # by (x_test, x), and the same with the test slot replaced by k
    dAx_dz = fenics.derivative(A(x_test, x, coefficients={}), z)
    adjoint_dAx_dz = fenics.adjoint(dAx_dz)
    k_dAx_dz = fenics.derivative(A(k, x, coefficients={}), z)

    # Create homogenized boundary conditions for solving in adjoint pass
    # (a fresh set is requested so `boundary_conditions` is left untouched)
    homogenized_boundary_conditions = define_boundary_conditions(x_func_space)
    for boundary_condition in homogenized_boundary_conditions:
        boundary_condition.homogenize()

    # Preallocate numpy array for storing Jacobian
    # NOTE: this single buffer is written in place and returned by every call
    # of the Jacobian functions below, so a later call overwrites the array
    # handed out by an earlier one.
    dy_dv = np.full((dim_y, dim_z), np.nan)

    # Precompute forms required for calculating second-derivatives
    # (each is a sum over observations of forms in the per-observation
    # work functions, which `mhp` / `_jacob_forward_func` fill in later)
    g_1, g_2, g_3 = 0, 0, 0
    for h_A_inv, m, A_inv_dAx_dz_m in zip(h_A_inv_list, m_list,
                                          A_inv_dAx_dz_m_list):
        g_1 += fenics.derivative(A(h_A_inv, A_inv_dAx_dz_m, coefficients={}),
                                 z)
        g_2 += fenics.derivative(A(h_A_inv, x_trial, coefficients={}), z, m)
        g_3 -= fenics.derivative(
            fenics.derivative(A(h_A_inv, x, coefficients={}), z, m), z)

    def solution_func(v_array):
        """Solve ``A == b`` for ``z = prior_covar_sqrt @ v_array``; return the DOF values of ``x``."""
        z_array = prior_covar_sqrt @ v_array
        z.vector().set_local(z_array)
        fenics.solve(A == b, x, bcs=boundary_conditions)
        return x.vector().get_local()

    def forward_func(v_array):
        """Return the solution values at the observation DOFs."""
        return solution_func(v_array)[observation_dof_indices]

    def _get_solvers_and_y():
        """Build forward/adjoint solvers for the current ``z``, solve for ``x``, return observations.

        Callers must have set ``z`` beforehand. Returns
        ``(A_solver, adjoint_A_solver, y)``.
        """
        if not bilinear_form_is_symmetric:
            # Homogenized and original boundary conditions have equivalent effect on matrix
            # corresponding to assembled bilinear form
            A_solver = construct_solver(A, boundary_conditions)
            b_vector = fenics.assemble(b)
            adjoint_A_solver = construct_solver(
                adjoint_A, homogenized_boundary_conditions)
            solve_with_boundary_conditions(A_solver, boundary_conditions,
                                           b_vector, x)
        else:
            A_matrix, b_vector = fenics.assemble_system(
                A, b, boundary_conditions)
            A_solver = fenics.LUSolver(A_matrix)
            A_solver.parameters["symmetric"] = True
            adjoint_A_solver = A_solver  # A is symmetric therefore adjoint(A) == A
            A_solver.solve(x.vector(), b_vector)
        y = (x.vector()).get_local()[observation_dof_indices]
        return A_solver, adjoint_A_solver, y

    def vjp_forward_func(v_array):
        """Evaluate the observations at ``v_array`` and return ``(vjp, y)``.

        ``vjp(v)`` reuses the adjoint solver built here, so it corresponds to
        the state at this ``v_array`` only while ``z`` and ``x`` have not been
        overwritten by another call.
        """
        z_array = prior_covar_sqrt @ v_array
        z.vector().set_local(z_array)
        _, adjoint_A_solver, y = _get_solvers_and_y()

        def vjp(v):
            """Return the vector-Jacobian product for covector ``v`` over the observations."""
            # Map the observation-space covector into solution space
            v_H.vector().set_local(v @ observation_matrix)
            # Adjoint solve; result is stored in k, which k_dAx_dz references
            solve_with_boundary_conditions(adjoint_A_solver,
                                           homogenized_boundary_conditions,
                                           v_H.vector(), k)
            return -prior_covar_sqrt.T @ fenics.assemble(k_dAx_dz).get_local()

        return vjp, y

    def _jacob_forward_func(v_array):
        """Fill ``dy_dv`` row by row with one adjoint solve per observation.

        Returns ``(dy_dv, y, A_solver, adjoint_A_solver,
        adjoint_dAx_dz_matrix)``; the adjoint solutions are left in
        ``h_A_inv_list`` for use by the second-derivative forms.
        """
        z_array = prior_covar_sqrt @ v_array
        z.vector().set_local(z_array)
        A_solver, adjoint_A_solver, y = _get_solvers_and_y()
        adjoint_dAx_dz_matrix = fenics.assemble(adjoint_dAx_dz)
        # Iterating dy_dv yields row views, so the assignment below writes
        # directly into the shared preallocated array
        for dy_dv_row, h_arr, h_A_inv in zip(dy_dv, observation_matrix,
                                             h_A_inv_list):
            h.vector().set_local(h_arr)
            solve_with_boundary_conditions(adjoint_A_solver,
                                           homogenized_boundary_conditions,
                                           h.vector(), h_A_inv)
            dy_dv_row[:] = (-prior_covar_sqrt.T @ (
                adjoint_dAx_dz_matrix * h_A_inv.vector()).get_local())
        return dy_dv, y, A_solver, adjoint_A_solver, adjoint_dAx_dz_matrix

    def jacob_forward_func(v_array):
        """Return ``(dy_dv, y)``: the Jacobian buffer and the observations."""
        return _jacob_forward_func(v_array)[:2]

    def mhp_forward_func(v_array):
        """Return ``(mhp, dy_dv, y)`` evaluated at ``v_array``.

        ``mhp`` is presumably a matrix-Hessian product (name-based reading -
        TODO confirm); it is built from the precomputed second-derivative
        forms ``g_1``, ``g_2`` and ``g_3``.
        """
        (
            dy_dv,
            y,
            A_solver,
            adjoint_A_solver,
            adjoint_dAx_dz_matrix,
        ) = _jacob_forward_func(v_array)

        def mhp(matrix):
            """Contract the second-derivative forms with the rows of ``matrix``."""
            # One row of `matrix` per observation; zip stops at the shortest
            # of its arguments
            for m_arr, m, A_inv_dAx_dz_m in zip(matrix, m_list,
                                                A_inv_dAx_dz_m_list):
                m.vector().set_local(prior_covar_sqrt @ m_arr)
                # Transposed product written into dAx_dz_m, then a forward
                # solve whose result is stored in A_inv_dAx_dz_m
                adjoint_dAx_dz_matrix.transpmult(m.vector(), dAx_dz_m.vector())
                solve_with_boundary_conditions(
                    A_solver,
                    homogenized_boundary_conditions,
                    dAx_dz_m.vector(),
                    A_inv_dAx_dz_m,
                )
            # NOTE(review): a function space rather than a Function is passed
            # as the last argument below; presumably the helper then creates
            # and returns a new Function - confirm against its definition.
            g = (fenics.assemble(g_1 + g_3) +
                 adjoint_dAx_dz_matrix * solve_with_boundary_conditions(
                     adjoint_A_solver,
                     homogenized_boundary_conditions,
                     fenics.assemble(g_2),
                     x_func_space,
                 ).vector())
            return prior_covar_sqrt.T @ g.get_local()

        return mhp, dy_dv, y

    return (
        solution_func,
        forward_func,
        vjp_forward_func,
        jacob_forward_func,
        mhp_forward_func,
        (dim_x, dim_y, dim_z),
    )