def vjp_solve_eval_impl(
    g: np.array,
    fenics_solution: fenics.Function,
    fenics_residual: ufl.Form,
    fenics_inputs: List[FenicsVariable],
    bcs: List[fenics.DirichletBC],
) -> Tuple[np.array]:
    """Computes the gradients of the output with respect to the inputs.

    Adjoint method on the residual ``F(u; m)``: assemble the adjoint of the
    Jacobian ``dF/du``, apply homogenized Dirichlet conditions, solve for the
    adjoint variable ``u_adj`` with right-hand side ``g``, then assemble
    ``-(dF/dm)^* u_adj`` for every input ``m``.

    Args:
        g: Cotangent of the solution, converted onto ``fenics_solution`` by
            ``numpy_to_fenics`` (assumed to match its DOF layout).
        fenics_solution: Solution ``u`` at which ``F`` is linearised.
        fenics_residual: Residual form ``F`` (presumably rank-1 in a test
            function, since its ``u``-derivative is transposed with ``adjoint``).
        fenics_inputs: Inputs to differentiate with respect to; each is a
            ``fenics.Function`` or a ``fenics.Constant``.
        bcs: Dirichlet conditions of the forward problem. Homogenized *copies*
            are used for the adjoint solve; the caller's objects are not
            modified.

    Returns:
        Tuple with one numpy array (via ``fenics_to_numpy``) per input.
    """
    # Convert tangent covector (adjoint) to a FEniCS vector
    adj_value = numpy_to_fenics(g, fenics_solution).vector()

    F = fenics_residual
    u = fenics_solution
    V = u.function_space()

    # Adjoint of the Jacobian dF/du, re-expressed with a trial function in V
    # so it assembles to a matrix.
    dFdu = fenics.derivative(F, u)
    adFdu = ufl.adjoint(
        dFdu, reordered_arguments=ufl.algorithms.extract_arguments(dFdu)
    )
    u_adj = fenics.Function(V)
    adj_F = ufl.action(adFdu, u_adj)
    adj_F = ufl.replace(adj_F, {u_adj: fenics.TrialFunction(V)})
    adj_F_assembled = fenics.assemble(adj_F)

    # The adjoint problem needs homogeneous Dirichlet conditions. Homogenize
    # copies so the caller's boundary conditions keep their prescribed values
    # (homogenizing ``bcs`` in place would silently zero them for later solves).
    hbcs = [fenics.DirichletBC(bc) for bc in bcs]
    for bc in hbcs:
        bc.homogenize()
        bc.apply(adj_F_assembled)
        bc.apply(adj_value)

    fenics.solve(adj_F_assembled, u_adj.vector(), adj_value)

    fenics_grads = []
    for fenics_input in fenics_inputs:
        # Direction space for the derivative, chosen per input: the input's own
        # space for a Function, the solution space otherwise. Using a fresh name
        # keeps a Constant from inheriting a preceding Function input's space.
        if isinstance(fenics_input, fenics.Function):
            V_m = fenics_input.function_space()
        else:
            V_m = V
        dFdm = fenics.derivative(F, fenics_input, fenics.TrialFunction(V_m))
        adFdm = fenics.adjoint(dFdm)
        result = fenics.assemble(-adFdm * u_adj)
        if isinstance(fenics_input, fenics.Constant):
            # Sum of nodal entries; presumably relies on V's basis forming a
            # partition of unity (true for Lagrange spaces) -- confirm otherwise.
            fenics_grad = fenics.Constant(result.sum())
        else:  # fenics.Function
            fenics_grad = fenics.Function(V_m, result)
        fenics_grads.append(fenics_grad)

    # Convert FEniCS gradients to jax array representation
    return tuple(
        None if fg is None else np.asarray(fenics_to_numpy(fg))
        for fg in fenics_grads
    )
def construct_forward_func_and_derivatives(
    z_func_space,
    x_func_space,
    define_forms,
    define_boundary_conditions,
    observation_coordinates,
    prior_covar_sqrt=None,
    bilinear_form_is_symmetric=False,
):
    """Build forward-model closures (and derivatives) for a linear PDE problem.

    The unknown ``v`` is mapped to a PDE coefficient ``z = prior_covar_sqrt @ v``
    (stored in a Function on ``z_func_space``). The linear variational problem
    ``A(z) x = b`` with Dirichlet conditions is solved for ``x`` on
    ``x_func_space``. The observations ``y`` are the entries of ``x`` at the DOFs
    nearest to each row of ``observation_coordinates``.

    Args:
        z_func_space: Function space of the coefficient ``z``.
        x_func_space: Function space of the solution ``x``.
        define_forms: Callable ``(z, x_test, x_trial) -> (A, b)`` returning the
            bilinear form ``A`` and the linear form ``b``.
        define_boundary_conditions: Callable ``(x_func_space) -> iterable of
            DirichletBC``. It is called twice so that one independent set can be
            homogenized for the adjoint solves.
        observation_coordinates: Array whose first axis indexes observations.
            Presumably of shape ``(dim_y, gdim)``, matching the DOF coordinates
            (inferred from the broadcasting below -- confirm).
        prior_covar_sqrt: Operator supporting ``@`` and ``.T``. Defaults to an
            identity ``LinearOperator`` of size ``dim_z``.
        bilinear_form_is_symmetric: If true, ``A`` is assembled with
            ``assemble_system`` into a single LU solver, which is also reused
            for the adjoint solves.

    Returns:
        Tuple ``(solution_func, forward_func, vjp_forward_func,
        jacob_forward_func, mhp_forward_func, (dim_x, dim_y, dim_z))``:

        - ``solution_func(v)``: full solution DOF array.
        - ``forward_func(v)``: ``y``.
        - ``vjp_forward_func(v)``: ``(vjp, y)``.
        - ``jacob_forward_func(v)``: ``(dy_dv, y)``.
        - ``mhp_forward_func(v)``: ``(mhp, dy_dv, y)``.

    NOTE(review): every closure writes into the same shared ``z``/``x``
    Functions. The returned ``vjp``/``mhp`` closures assemble forms that depend
    on ``z`` and ``x`` when *they* are called, so they must be called before any
    other forward evaluation. Otherwise they act on the wrong point.

    NOTE(review): only ``A`` is ever differentiated with respect to ``z``, so the
    derivatives assume ``b`` does not depend on ``z``.
    """
    # Get problem dimensions
    dim_x = x_func_space.dim()
    dim_y = observation_coordinates.shape[0]
    dim_z = z_func_space.dim()

    # Set up observation operator
    # (point observations at DOFs closest to observation coordinates):
    # argmin over squared Euclidean distance to every DOF coordinate.
    dof_coordinates = x_func_space.tabulate_dof_coordinates()
    observation_dof_indices = np.argmin(
        ((dof_coordinates[None, :, :] - observation_coordinates[:, None])**2).sum(-1),
        -1,
    )
    # Selection matrix: row i has a single 1 at observation_dof_indices[i].
    observation_matrix = np.zeros((dim_y, dim_x))
    observation_matrix[np.arange(dim_y), observation_dof_indices] = 1

    # Default to identity prior covariance if prior_covar_sqrt is None
    if prior_covar_sqrt is None:
        prior_covar_sqrt = LinearOperator(shape=(dim_z, dim_z), matvec=lambda x: x, rmatvec=lambda x: x)

    # Construct function objects (reused as scratch storage by all closures below)
    z = fenics.Function(z_func_space)
    x = fenics.Function(x_func_space)
    x_trial = fenics.TrialFunction(x_func_space)
    x_test = fenics.TestFunction(x_func_space)
    h = fenics.Function(x_func_space)
    v_H = fenics.Function(x_func_space)
    k = fenics.Function(x_func_space)
    dAx_dz_m = fenics.Function(x_func_space)
    # One scratch Function per observation, used by the Jacobian / mhp passes.
    m_list = [fenics.Function(z_func_space) for _ in range(dim_y)]
    h_A_inv_list = [fenics.Function(x_func_space) for _ in range(dim_y)]
    A_inv_dAx_dz_m_list = [fenics.Function(x_func_space) for _ in range(dim_y)]

    # Construct bilinear and linear forms defining variational form of problem
    A, b = define_forms(z, x_test, x_trial)
    boundary_conditions = define_boundary_conditions(x_func_space)

    # Precompute additional forms required for calculating first-derivatives.
    # Calling A(u, w, coefficients={}) substitutes A's (test, trial) arguments.
    if not bilinear_form_is_symmetric:
        adjoint_A = fenics.adjoint(A)
    dAx_dz = fenics.derivative(A(x_test, x, coefficients={}), z)
    adjoint_dAx_dz = fenics.adjoint(dAx_dz)
    k_dAx_dz = fenics.derivative(A(k, x, coefficients={}), z)

    # Create homogenized boundary conditions for solving in adjoint pass
    # (a separate set, so the forward conditions keep their values)
    homogenized_boundary_conditions = define_boundary_conditions(x_func_space)
    for boundary_condition in homogenized_boundary_conditions:
        boundary_condition.homogenize()

    # Preallocate numpy array for storing Jacobian.
    # NOTE(review): every Jacobian call fills and returns this same array, so
    # earlier results held by callers are overwritten -- copy if they are kept.
    dy_dv = np.full((dim_y, dim_z), np.nan)

    # Precompute forms required for calculating second-derivatives.
    # These are summed over all dim_y scratch Functions; their values are filled
    # in later by _jacob_forward_func (h_A_inv) and mhp (m, A_inv_dAx_dz_m).
    g_1, g_2, g_3 = 0, 0, 0
    for h_A_inv, m, A_inv_dAx_dz_m in zip(h_A_inv_list, m_list, A_inv_dAx_dz_m_list):
        g_1 += fenics.derivative(A(h_A_inv, A_inv_dAx_dz_m, coefficients={}), z)
        g_2 += fenics.derivative(A(h_A_inv, x_trial, coefficients={}), z, m)
        g_3 -= fenics.derivative(
            fenics.derivative(A(h_A_inv, x, coefficients={}), z, m), z)

    def solution_func(v_array):
        """Solve the forward problem at v and return a copy of x's local DOFs."""
        z_array = prior_covar_sqrt @ v_array
        z.vector().set_local(z_array)
        fenics.solve(A == b, x, bcs=boundary_conditions)
        return x.vector().get_local()

    def forward_func(v_array):
        """Return the observations y at v."""
        return solution_func(v_array)[observation_dof_indices]

    def _get_solvers_and_y():
        """Solve for x at the current z; return (A_solver, adjoint_A_solver, y)."""
        if not bilinear_form_is_symmetric:
            # Homogenized and original boundary conditions have equivalent effect on matrix
            # corresponding to assembled bilinear form
            A_solver = construct_solver(A, boundary_conditions)
            b_vector = fenics.assemble(b)
            adjoint_A_solver = construct_solver(
                adjoint_A, homogenized_boundary_conditions)
            solve_with_boundary_conditions(A_solver, boundary_conditions, b_vector, x)
        else:
            A_matrix, b_vector = fenics.assemble_system(
                A, b, boundary_conditions)
            A_solver = fenics.LUSolver(A_matrix)
            A_solver.parameters["symmetric"] = True
            adjoint_A_solver = A_solver  # A is symmetric therefore adjoint(A) == A
            A_solver.solve(x.vector(), b_vector)
        y = (x.vector()).get_local()[observation_dof_indices]
        return A_solver, adjoint_A_solver, y

    def vjp_forward_func(v_array):
        """Solve at v; return (vjp, y), where vjp(v) gives v^T dy/dv."""
        z_array = prior_covar_sqrt @ v_array
        z.vector().set_local(z_array)
        _, adjoint_A_solver, y = _get_solvers_and_y()

        def vjp(v):
            # Adjoint solve A^T k = H^T v, then -L^T (d(A x)/dz)^T k.
            # k_dAx_dz is assembled here, so it uses the current z and x.
            v_H.vector().set_local(v @ observation_matrix)
            solve_with_boundary_conditions(adjoint_A_solver, homogenized_boundary_conditions, v_H.vector(), k)
            return -prior_covar_sqrt.T @ fenics.assemble(k_dAx_dz).get_local()

        return vjp, y

    def _jacob_forward_func(v_array):
        """Solve at v and fill dy_dv row by row with one adjoint solve per observation.

        Also returns the solvers and assembled adjoint_dAx_dz matrix for reuse
        by mhp; the h_A_inv_list Functions are left holding the adjoint solutions.
        """
        z_array = prior_covar_sqrt @ v_array
        z.vector().set_local(z_array)
        A_solver, adjoint_A_solver, y = _get_solvers_and_y()
        adjoint_dAx_dz_matrix = fenics.assemble(adjoint_dAx_dz)
        for dy_dv_row, h_arr, h_A_inv in zip(dy_dv, observation_matrix, h_A_inv_list):
            h.vector().set_local(h_arr)
            solve_with_boundary_conditions(adjoint_A_solver, homogenized_boundary_conditions, h.vector(), h_A_inv)
            dy_dv_row[:] = (-prior_covar_sqrt.T @ (
                adjoint_dAx_dz_matrix * h_A_inv.vector()).get_local())
        return dy_dv, y, A_solver, adjoint_A_solver, adjoint_dAx_dz_matrix

    def jacob_forward_func(v_array):
        """Return (dy_dv, y) at v; dy_dv is the shared preallocated array."""
        return _jacob_forward_func(v_array)[:2]

    def mhp_forward_func(v_array):
        """Return (mhp, dy_dv, y) at v.

        mhp(matrix) combines the second-derivative forms g_1..g_3 with the rows
        of ``matrix`` (presumably a matrix-Hessian product -- TODO confirm).
        """
        (
            dy_dv,
            y,
            A_solver,
            adjoint_A_solver,
            adjoint_dAx_dz_matrix,
        ) = _jacob_forward_func(v_array)

        def mhp(matrix):
            # NOTE(review): zip() truncates to dim_y rows, while g_1..g_3 always
            # sum over all dim_y terms -- matrix is expected to have dim_y rows.
            for m_arr, m, A_inv_dAx_dz_m in zip(matrix, m_list, A_inv_dAx_dz_m_list):
                m.vector().set_local(prior_covar_sqrt @ m_arr)
                adjoint_dAx_dz_matrix.transpmult(m.vector(), dAx_dz_m.vector())
                solve_with_boundary_conditions(
                    A_solver,
                    homogenized_boundary_conditions,
                    dAx_dz_m.vector(),
                    A_inv_dAx_dz_m,
                )
            # solve_with_boundary_conditions is passed a function space here and
            # its return value's .vector() is used (assumed to return a Function).
            g = (fenics.assemble(g_1 + g_3) + adjoint_dAx_dz_matrix * solve_with_boundary_conditions(
                adjoint_A_solver,
                homogenized_boundary_conditions,
                fenics.assemble(g_2),
                x_func_space,
            ).vector())
            return prior_covar_sqrt.T @ g.get_local()

        return mhp, dy_dv, y

    return (
        solution_func,
        forward_func,
        vjp_forward_func,
        jacob_forward_func,
        mhp_forward_func,
        (dim_x, dim_y, dim_z),
    )