Exemplo n.º 1
0
    def local_solver_calls(self, A, rhs, x, elim_fields):
        """Build assembly callbacks that recover eliminated fields via local solves.

        For every local system produced by ``backward_solve``, the local
        left-hand-side tensor is inverted against its right-hand side
        (``decomposition="PartialPivLU"``). The resulting expression is
        wrapped in an assembly callable that writes into the matching
        component of ``x``.

        :arg A: A Slate Tensor containing the mixed bilinear form.
        :arg rhs: A firedrake function for the right-hand side.
        :arg x: A firedrake function for the solution.
        :arg elim_fields: An iterable of eliminated field indices
                          to recover.
        :returns: A list of assembly callables, one per local system,
                  in the order ``backward_solve`` yields them.
        """

        from firedrake.slate.static_condensation.la_utils import backward_solve
        from firedrake.assemble import create_assembly_callable

        # Components of the solution; each callback assembles into one of them.
        subfunctions = x.split()
        local_systems = backward_solve(A, rhs, x, reconstruct_fields=elim_fields)

        def make_callback(system):
            """Wrap the local solve of ``system`` in an assembly callable."""
            lhs_tensor = system.lhs
            rhs_tensor = system.rhs
            # Each local system targets exactly one field; unpacking enforces that.
            (field_index,) = system.field_idx
            inverse_action = lhs_tensor.solve(rhs_tensor, decomposition="PartialPivLU")
            return create_assembly_callable(
                inverse_action,
                tensor=subfunctions[field_index],
                form_compiler_parameters=self.cxt.fc_params,
            )

        return [make_callback(system) for system in local_systems]
Exemplo n.º 2
0
    def local_solver_calls(self, A, rhs, x, elim_fields):
        """Build assembly callbacks that recover eliminated fields via local solves.

        For every local system produced by ``backward_solve``, the local
        left-hand-side tensor is inverted against its right-hand side
        (``decomposition="PartialPivLU"``). The resulting expression is
        wrapped in an assembly callable that writes into the matching
        component of ``x``.

        :arg A: A Slate Tensor containing the mixed bilinear form.
        :arg rhs: A firedrake function for the right-hand side.
        :arg x: A firedrake function for the solution.
        :arg elim_fields: An iterable of eliminated field indices
                          to recover.
        :returns: A list of assembly callables, one per local system,
                  in the order ``backward_solve`` yields them.
        """

        from firedrake.slate.static_condensation.la_utils import backward_solve
        from firedrake.assemble import create_assembly_callable

        # Components of the solution; each callback assembles into one of them.
        subfunctions = x.split()
        local_systems = backward_solve(A, rhs, x, reconstruct_fields=elim_fields)

        def make_callback(system):
            """Wrap the local solve of ``system`` in an assembly callable."""
            lhs_tensor = system.lhs
            rhs_tensor = system.rhs
            # Each local system targets exactly one field; unpacking enforces that.
            (field_index,) = system.field_idx
            inverse_action = lhs_tensor.solve(rhs_tensor, decomposition="PartialPivLU")
            return create_assembly_callable(
                inverse_action,
                tensor=subfunctions[field_index],
                form_compiler_parameters=self.cxt.fc_params,
            )

        return [make_callback(system) for system in local_systems]