Esempio n. 1
0
    def initialize(self, pc):
        """Set up the p-multigrid preconditioner.

        Makes a new DMShell, overrides its coarsen / interpolation /
        injection routines with this object's p-coarsening routines (plus the
        _SNESContext Jacobian, residual and operator callbacks), attaches it
        to the parent DM together with the fine application context, and
        builds the inner PC via ``self.configure_pmg``.

        :arg pc: the PETSc PC being initialised.  Its DM must carry an
            application context whose Jacobian ``J`` has identical test and
            trial spaces, and must have a parent DM.
        :raises NotImplementedError: if the test and trial spaces differ.
        """
        odm = pc.getDM()
        ctx = get_appctx(odm)

        test, trial = ctx.J.arguments()
        if test.function_space() != trial.function_space():
            raise NotImplementedError("test and trial spaces must be the same")

        # petsc4py returns None when the PC has no options prefix; treat that
        # as the empty prefix instead of failing on ``None + str``.
        prefix = pc.getOptionsPrefix() or ""
        options_prefix = prefix + "pmg_"
        pdm = PETSc.DMShell().create(comm=pc.comm)
        pdm.setOptionsPrefix(options_prefix)

        # Coarse degree from the option "<prefix>pmg_mg_coarse_degree" (default 1).
        self.coarse_degree = PETSc.Options(options_prefix).getInt("mg_coarse_degree", default=1)

        # Coarsen the element repeatedly until coarsen_element raises
        # ValueError; the number of elements collected fixes the p-level count.
        V = test.function_space()
        ele = V.ufl_element()
        elements = [ele]
        while True:
            try:
                ele_ = self.coarsen_element(ele)
                assert ele_.value_shape() == ele.value_shape()
                ele = ele_
            except ValueError:
                break
            elements.append(ele)

        sf = odm.getPointSF()
        section = odm.getDefaultSection()
        attach_hooks(pdm, level=len(elements)-1, sf=sf, section=section)
        # Now overwrite some routines on the DM
        pdm.setRefine(None)
        pdm.setCoarsen(self.coarsen)
        pdm.setCreateInterpolation(self.create_interpolation)
        # We need this for p-FAS
        pdm.setCreateInjection(self.create_injection)
        pdm.setSNESJacobian(_SNESContext.form_jacobian)
        pdm.setSNESFunction(_SNESContext.form_function)
        pdm.setKSPComputeOperators(_SNESContext.compute_operators)

        set_function_space(pdm, get_function_space(odm))

        # Make the parent DM and the fine appctx visible on the new DM for as
        # long as the parent's hooks are active (setup is run immediately).
        parent = get_parent(odm)
        assert parent is not None
        add_hook(parent,
                 setup=partial(push_parent, pdm, parent),
                 teardown=partial(pop_parent, pdm, parent),
                 call_setup=True)
        add_hook(parent,
                 setup=partial(push_appctx, pdm, ctx),
                 teardown=partial(pop_appctx, pdm, ctx),
                 call_setup=True)

        self.ppc = self.configure_pmg(pc, pdm)
        self.ppc.setFromOptions()
        self.ppc.setUp()
Esempio n. 2
0
def coarsen_function_space(V, self, coefficient_mapping=None):
    """Return the counterpart of ``V`` on the next-coarser mesh.

    ``self`` is the coarsening callable, invoked as ``self(obj, self)``.
    The result is memoised on ``V`` as ``V._coarse``, and the coarse space
    gets ``_fine`` pointing back at ``V``.  Indexed/component subspaces are
    handled by coarsening the root (parent-less) space and then re-applying
    the recorded ``sub`` indices.  ``coefficient_mapping`` is accepted for
    signature compatibility and is not used.

    The parent / transfer-operator hooks are attached here, which duplicates
    part of ``dmhooks.coarsen``.  With fieldsplit + MG and auxiliary
    coefficients on spaces other than the MG space, ``dm.coarsen`` is never
    called; code such as ``inject`` coarsens the function space directly
    instead.
    """
    try:
        return V._coarse
    except AttributeError:
        pass

    from firedrake.dmhooks import (get_transfer_operators, get_parent,
                                   push_transfer_operators,
                                   pop_transfer_operators, push_parent,
                                   pop_parent, add_hook)
    fine = V

    # Climb to the root space, remembering the index/component taken at
    # each level so the same subspace can be selected on the coarse side.
    path = []
    root = fine
    while True:
        path.extend(i for i in (root.index, root.component) if i is not None)
        if root.parent is None:
            break
        root = root.parent

    coarse_mesh = self(root.mesh(), self)
    coarse = firedrake.FunctionSpace(coarse_mesh, root.ufl_element())
    for i in reversed(path):
        coarse = coarse.sub(i)
    coarse._fine = fine
    fine._coarse = coarse

    coarse_dm = coarse.dm
    transfer = get_transfer_operators(fine.dm)
    parent = get_parent(fine.dm)
    hooks = (
        (push_parent, pop_parent, parent),
        (push_transfer_operators, pop_transfer_operators, transfer),
    )
    try:
        for push, pop, payload in hooks:
            add_hook(parent,
                     setup=partial(push, coarse_dm, payload),
                     teardown=partial(pop, coarse_dm, payload),
                     call_setup=True)
    except ValueError:
        # Not in an add_hooks context
        pass

    return coarse
Esempio n. 3
0
def coarsen_function(expr, self, coefficient_mapping=None):
    """Return a coarse-level counterpart of the Function ``expr``.

    If ``coefficient_mapping`` already maps ``expr``, that entry is returned.
    Otherwise the coarse space is obtained with ``self(space, self)``, a new
    Function named ``"coarse_<name>"`` is created on it, and it is filled by
    the injection operator. That operator is taken from the transfer
    operators of the parent DM of ``expr``'s root (parent-less) function
    space.  The new Function is not recorded in ``coefficient_mapping``.
    """
    mapping = {} if coefficient_mapping is None else coefficient_mapping
    coarse = mapping.get(expr)
    if coarse is not None:
        return coarse

    from firedrake.dmhooks import get_parent
    # The root (possibly mixed) space is the one carrying the appctx, and
    # hence the transfer manager when we are inside a solve.
    root = expr.function_space()
    while root.parent is not None:
        root = root.parent
    _, _, inject = firedrake.dmhooks.get_transfer_operators(get_parent(root.dm))

    coarse_space = self(expr.function_space(), self)
    coarse = firedrake.Function(coarse_space, name="coarse_%s" % expr.name())
    inject(expr, coarse)
    return coarse
Esempio n. 4
0
    def initialize(self, pc):
        """Set up the p-multigrid preconditioner.

        Makes a new DMShell with this object's coarsen and interpolation
        routines, attaches it (with the fine application context) to the
        parent DM via hooks, and builds an inner PC of type "mg" on it.  Both
        the DM and the inner PC use the options prefix "<pc prefix>pmg_".

        :arg pc: the PETSc PC being initialised.  Its DM must carry an
            application context whose Jacobian ``J`` has identical test and
            trial spaces, and must have a parent DM.
        :raises NotImplementedError: if the test and trial spaces differ.
        """
        odm = pc.getDM()
        ctx = get_appctx(odm)

        test, trial = ctx.J.arguments()
        if test.function_space() != trial.function_space():
            raise NotImplementedError("test and trial spaces must be the same")

        # Coarsen the element repeatedly until coarsen_element raises
        # ValueError; the number of elements collected fixes the p-level count.
        V = test.function_space()
        ele = V.ufl_element()
        elements = [ele]
        while True:
            try:
                ele_ = self.coarsen_element(ele)
                assert ele_.value_shape() == ele.value_shape()
                ele = ele_
            except ValueError:
                break
            elements.append(ele)

        # petsc4py returns None when the PC has no options prefix; treat that
        # as "" so the concatenation cannot fail.  Computed once, reused below.
        options_prefix = (pc.getOptionsPrefix() or "") + "pmg_"

        pdm = PETSc.DMShell().create(comm=pc.comm)
        sf = odm.getPointSF()
        section = odm.getDefaultSection()
        attach_hooks(pdm, level=len(elements) - 1, sf=sf, section=section)
        # Now overwrite some routines on the DM
        pdm.setRefine(None)
        pdm.setCoarsen(self.coarsen)
        pdm.setCreateInterpolation(self.create_interpolation)
        pdm.setOptionsPrefix(options_prefix)
        set_function_space(pdm, get_function_space(odm))

        # Make the parent DM and the fine appctx visible on the new DM for as
        # long as the parent's hooks are active (setup is run immediately).
        parent = get_parent(odm)
        assert parent is not None
        add_hook(parent,
                 setup=partial(push_parent, pdm, parent),
                 teardown=partial(pop_parent, pdm, parent),
                 call_setup=True)
        add_hook(parent,
                 setup=partial(push_appctx, pdm, ctx),
                 teardown=partial(pop_appctx, pdm, ctx),
                 call_setup=True)

        ppc = PETSc.PC().create(comm=pc.comm)
        ppc.setOptionsPrefix(options_prefix)
        ppc.setType("mg")
        ppc.setOperators(*pc.getOperators())
        ppc.setDM(pdm)
        ppc.incrementTabLevel(1, parent=pc)

        # PETSc unfortunately requires us to make an ugly hack.
        # We would like to use GMG for the coarse solve, at least
        # sometimes. But PETSc will use this p-DM's getRefineLevels()
        # instead of the getRefineLevels() of the MeshHierarchy to
        # decide how many levels it should use for PCMG applied to
        # the p-MG's coarse problem. So we need to set an option
        # for the user, if they haven't already; I don't know any
        # other way to get PETSc to know this at the right time.
        opts = PETSc.Options(options_prefix)
        if "mg_coarse_pc_mg_levels" not in opts:
            opts["mg_coarse_pc_mg_levels"] = odm.getRefineLevel() + 1

        ppc.setFromOptions()
        ppc.setUp()
        self.ppc = ppc
Esempio n. 5
0
    def coarsen(self, fdm, comm):
        """Build the p-coarsened DM for the fine DM ``fdm``.

        A function space on the same mesh is created with the element from
        ``self.coarsen_element``.  The fine Jacobian, residual, optional
        preconditioning Jacobian and Dirichlet BCs are rewritten on it and
        wrapped in a new _SNESContext.  Hooks on the parent DM push the
        parent and the coarse appctx onto the coarse DM and interpolate the
        fine state into the coarse state.  ``comm`` is not used.

        :returns: the coarse DM.
        """
        fine_ctx = get_appctx(fdm)
        test, trial = fine_ctx.J.arguments()
        fine_space = test.function_space()
        fine_state = fine_ctx._problem.u

        coarse_element = self.coarsen_element(fine_space.ufl_element())
        coarse_space = firedrake.FunctionSpace(fine_space.mesh(), coarse_element)
        cdm = coarse_space.dm

        coarse_state = firedrake.Function(coarse_space)
        # Pair the components of the fine and coarse states so the fine state
        # can be interpolated onto the coarse one.
        interpolators = tuple(
            firedrake.Interpolator(fine_part, coarse_part)
            for fine_part, coarse_part in zip(fine_state.split(),
                                              coarse_state.split()))

        def inject_state(interps):
            # Run every interpolator (fine state -> coarse state).
            for interp in interps:
                interp.interpolate()

        parent = get_parent(fdm)
        assert parent is not None
        add_hook(parent,
                 setup=partial(push_parent, cdm, parent),
                 teardown=partial(pop_parent, cdm, parent),
                 call_setup=True)

        substitutions = {
            fine_state: coarse_state,
            test: firedrake.TestFunction(coarse_space),
            trial: firedrake.TrialFunction(coarse_space)
        }
        cJ = replace(fine_ctx.J, substitutions)
        cF = replace(fine_ctx.F, substitutions)
        cJp = None if fine_ctx.Jp is None else replace(fine_ctx.Jp, substitutions)

        def coarse_bc(bc):
            # The value is irrelevant: these BCs are only used for killing
            # parts of the matrix, so zero is used.  This should be
            # generalised for p-FAS, if anyone ever wants to do that.
            sub_space = coarse_space
            for index in bc._indices:
                sub_space = sub_space.sub(index)
            return firedrake.DirichletBC(sub_space,
                                         firedrake.zero(sub_space.shape),
                                         bc.sub_domain,
                                         method=bc.method)

        cbcs = [coarse_bc(bc) for bc in fine_ctx._problem.bcs]

        cproblem = firedrake.NonlinearVariationalProblem(
            cF,
            coarse_state,
            cbcs,
            cJ,
            Jp=cJp,
            form_compiler_parameters=fine_ctx._problem.form_compiler_parameters,
            is_linear=fine_ctx._problem.is_linear)

        cctx = _SNESContext(
            cproblem,
            fine_ctx.mat_type,
            fine_ctx.pmat_type,
            appctx=fine_ctx.appctx,
            pre_jacobian_callback=fine_ctx._pre_jacobian_callback,
            pre_function_callback=fine_ctx._pre_function_callback,
            post_jacobian_callback=fine_ctx._post_jacobian_callback,
            post_function_callback=fine_ctx._post_function_callback,
            options_prefix=fine_ctx.options_prefix,
            transfer_manager=fine_ctx.transfer_manager)

        add_hook(parent,
                 setup=partial(push_appctx, cdm, cctx),
                 teardown=partial(pop_appctx, cdm, cctx),
                 call_setup=True)
        add_hook(parent,
                 setup=partial(inject_state, interpolators),
                 call_setup=True)

        cdm.setKSPComputeOperators(_SNESContext.compute_operators)
        cdm.setCreateInterpolation(self.create_interpolation)
        cdm.setOptionsPrefix(fdm.getOptionsPrefix())

        # Only hand the coarse DM our coarsen routine if its element can be
        # coarsened further.  At the coarsest p-level it is left alone so
        # geometric multigrid can be used for the p-coarse problem.
        try:
            self.coarsen_element(coarse_element)
            cdm.setCoarsen(self.coarsen)
        except ValueError:
            pass

        return cdm
Esempio n. 6
0
    def coarsen(self, fdm, comm):
        """Coarsen the _SNESContext of ``fdm`` in polynomial degree.

        Builds a coarse function space on the same mesh with the element
        returned by ``self.coarsen_element``.  It then builds a coarse
        NonlinearVariationalProblem, coarsening the forms, the Dirichlet BCs,
        the quadrature degrees and the appctx.  From that it creates a coarse
        context of the same type as the fine one, and wires the coarse DM up
        with hooks, injection of the initial state and coarsened nullspaces.
        If the fine context already caches a coarse context on the same mesh
        with the same coarse element, that context's DM is returned instead.

        :arg fdm: the fine DM; must carry an appctx and a parent DM.
        :arg comm: communicator passed by PETSc (not used here).
        :returns: the coarse DM.
        """
        fctx = get_appctx(fdm)
        parent = get_parent(fdm)
        assert parent is not None

        test, trial = fctx.J.arguments()
        fV = test.function_space()
        cele = self.coarsen_element(fV.ufl_element())

        # Have we already done this?
        cctx = fctx._coarse
        if cctx is not None:
            cV = cctx.J.arguments()[0].function_space()
            if (cV.ufl_element() == cele) and (cV.mesh() == fV.mesh()):
                return cV.dm

        cV = firedrake.FunctionSpace(fV.mesh(), cele)
        cdm = cV.dm

        fproblem = fctx._problem
        fu = fproblem.u
        cu = firedrake.Function(cV)

        def coarsen_quadrature(df, Nf, Nc):
            # Coarsen the "quadrature_degree" entry of a dictionary so that the
            # ratio of quadrature nodes to interpolation nodes (Nq+1)/(Nf+1)
            # is preserved (rounded up, and never below 2*Nc+1).  Anything
            # else, including a dict without that key, is returned as-is
            # (not copied).
            if isinstance(df, dict):
                Nq = df.get("quadrature_degree", None)
                if Nq is not None:
                    dc = dict(df)
                    dc["quadrature_degree"] = max(
                        2 * Nc + 1, ((Nq + 1) * (Nc + 1) + Nf) // (Nf + 1) - 1)
                    return dc
            return df

        def coarsen_form(form, Nf, Nc, replace_d):
            # Coarsen a form, by replacing the solution, test and trial functions, and
            # reconstructing each integral with a coarsened quadrature degree.
            # If form is not a Form, then return form.
            return Form([
                f.reconstruct(
                    metadata=coarsen_quadrature(f.metadata(), Nf, Nc))
                for f in replace(form, replace_d).integrals()
            ]) if isinstance(form, Form) else form

        def coarsen_bcs(fbcs):
            # Rebuild each (exactly) DirichletBC on the matching coarse
            # subspace; any other BC type is rejected.
            cbcs = []
            for bc in fbcs:
                cV_ = cV
                for index in bc._indices:
                    cV_ = cV_.sub(index)
                cbc_value = self.coarsen_bc_value(bc, cV_)
                if type(bc) == firedrake.DirichletBC:
                    cbcs.append(
                        firedrake.DirichletBC(cV_, cbc_value, bc.sub_domain))
                else:
                    raise NotImplementedError(
                        "Unsupported BC type, please get in touch if you need this"
                    )
            return cbcs

        Nf = PMGBase.max_degree(fV.ufl_element())
        Nc = PMGBase.max_degree(cV.ufl_element())

        # Replace dictionary with coarse state, test and trial functions
        replace_d = {
            fu: cu,
            test: firedrake.TestFunction(cV),
            trial: firedrake.TrialFunction(cV)
        }

        cF = coarsen_form(fctx.F, Nf, Nc, replace_d)
        cJ = coarsen_form(fctx.J, Nf, Nc, replace_d)
        cJp = coarsen_form(fctx.Jp, Nf, Nc, replace_d)
        fcp = coarsen_quadrature(fproblem.form_compiler_parameters, Nf, Nc)
        cbcs = coarsen_bcs(fproblem.bcs)

        # Coarsen the appctx: the user might want to provide solution-dependent expressions and forms
        cappctx = dict(fctx.appctx)
        for key in cappctx:
            val = cappctx[key]
            if isinstance(val, dict):
                cappctx[key] = coarsen_quadrature(val, Nf, Nc)
            elif isinstance(val, Expr):
                cappctx[key] = replace(val, replace_d)
            elif isinstance(val, Form):
                cappctx[key] = coarsen_form(val, Nf, Nc, replace_d)

        cmat_type = fctx.mat_type
        cpmat_type = fctx.pmat_type
        if Nc == self.coarse_degree:
            cmat_type = self.coarse_mat_type
            cpmat_type = self.coarse_mat_type
            if fcp is None:
                fcp = dict()
            elif fcp is fproblem.form_compiler_parameters:
                # coarsen_quadrature hands back the fine dict itself when it
                # has no "quadrature_degree"; copy it so that setting "mode"
                # does not leak into the fine problem's parameters.
                fcp = dict(fcp)
            fcp["mode"] = self.coarse_form_compiler_mode

        # Coarsen the problem and the _SNESContext
        cproblem = firedrake.NonlinearVariationalProblem(
            cF,
            cu,
            bcs=cbcs,
            J=cJ,
            Jp=cJp,
            form_compiler_parameters=fcp,
            is_linear=fproblem.is_linear)

        cctx = type(fctx)(cproblem,
                          cmat_type,
                          cpmat_type,
                          appctx=cappctx,
                          pre_jacobian_callback=fctx._pre_jacobian_callback,
                          pre_function_callback=fctx._pre_function_callback,
                          post_jacobian_callback=fctx._post_jacobian_callback,
                          post_function_callback=fctx._post_function_callback,
                          options_prefix=fctx.options_prefix,
                          transfer_manager=fctx.transfer_manager)

        # FIXME setting up the _fine attribute triggers gmg injection.
        # cctx._fine = fctx
        fctx._coarse = cctx

        add_hook(parent,
                 setup=partial(push_parent, cdm, parent),
                 teardown=partial(pop_parent, cdm, parent),
                 call_setup=True)
        add_hook(parent,
                 setup=partial(push_appctx, cdm, cctx),
                 teardown=partial(pop_appctx, cdm, cctx),
                 call_setup=True)

        cdm.setOptionsPrefix(fdm.getOptionsPrefix())
        cdm.setKSPComputeOperators(_SNESContext.compute_operators)
        cdm.setCreateInterpolation(self.create_interpolation)
        cdm.setCreateInjection(self.create_injection)

        # If we're the coarsest grid of the p-hierarchy, don't
        # overwrite the coarsen routine; this is so that you can
        # use geometric multigrid for the p-coarse problem
        try:
            self.coarsen_element(cele)
            cdm.setCoarsen(self.coarsen)
        except ValueError:
            pass

        # injection of the initial state: cu <- mat^T fu
        def inject_state(mat):
            with cu.dat.vec_wo as xc, fu.dat.vec_ro as xf:
                mat.multTranspose(xf, xc)

        injection = self.create_injection(cdm, fdm)
        add_hook(parent,
                 setup=partial(inject_state, injection),
                 call_setup=True)

        # restrict the nullspace basis
        def coarsen_nullspace(coarse_V, mat, fine_nullspace):
            # Map a (Mixed)VectorSpaceBasis onto coarse_V by applying
            # mat^T to each basis vector (then orthonormalizing); mixed
            # bases are handled block-by-block via the nest submatrices.
            # Anything else (e.g. None) is returned unchanged.
            if isinstance(fine_nullspace, MixedVectorSpaceBasis):
                if mat.type == 'python':
                    mat = mat.getPythonContext()
                submats = [
                    mat.getNestSubMatrix(i, i) for i in range(len(coarse_V))
                ]
                coarse_bases = []
                for fs, submat, basis in zip(coarse_V, submats,
                                             fine_nullspace._bases):
                    if isinstance(basis, VectorSpaceBasis):
                        coarse_bases.append(
                            coarsen_nullspace(fs, submat, basis))
                    else:
                        coarse_bases.append(coarse_V.sub(basis.index))
                return MixedVectorSpaceBasis(coarse_V, coarse_bases)
            elif isinstance(fine_nullspace, VectorSpaceBasis):
                coarse_vecs = []
                for xf in fine_nullspace._petsc_vecs:
                    wc = firedrake.Function(coarse_V)
                    with wc.dat.vec_wo as xc:
                        mat.multTranspose(xf, xc)
                    coarse_vecs.append(wc)
                vsb = VectorSpaceBasis(coarse_vecs,
                                       constant=fine_nullspace._constant)
                vsb.orthonormalize()
                return vsb
            else:
                return fine_nullspace

        I, _ = self.create_interpolation(cdm, fdm)
        ises = cV._ises
        cctx._nullspace = coarsen_nullspace(cV, I, fctx._nullspace)
        cctx.set_nullspace(cctx._nullspace, ises, transpose=False, near=False)
        cctx._nullspace_T = coarsen_nullspace(cV, I, fctx._nullspace_T)
        cctx.set_nullspace(cctx._nullspace_T, ises, transpose=True, near=False)
        cctx._near_nullspace = coarsen_nullspace(cV, I, fctx._near_nullspace)
        cctx.set_nullspace(cctx._near_nullspace,
                           ises,
                           transpose=False,
                           near=True)
        return cdm
Esempio n. 7
0
def coarsen_snescontext(context, self, coefficient_mapping=None):
    """Return the coarse-grid counterpart of the _SNESContext ``context``.

    The result is cached on ``context._coarse`` (with ``_fine`` pointing
    back), so repeated calls return the same object.  The problem and every
    appctx entry except ``"state"`` are coarsened with ``self``.  Entries
    that raise CoarseningError are copied unchanged.  The coarse context is
    pushed as appctx onto the DMs of coarsened Functions (values of
    ``coefficient_mapping`` and the coarse BC ``function_arg``s) that do not
    yet have one, with a matching teardown hook on the fine parent DM.  The
    nullspace, transpose nullspace and near nullspace are then coarsened and
    installed.

    :arg context: the fine _SNESContext.
    :arg self: the coarsening callable, invoked as ``self(obj, self, ...)``.
    :arg coefficient_mapping: optional dict of already-coarsened
        coefficients, passed through to (and presumably populated by) the
        recursive ``self`` calls.
    :returns: the coarse context.
    """
    if coefficient_mapping is None:
        coefficient_mapping = {}

    # Have we already done this?
    coarse = context._coarse
    if coarse is not None:
        return coarse

    problem = self(context._problem,
                   self,
                   coefficient_mapping=coefficient_mapping)
    appctx = context.appctx
    new_appctx = {}
    for k in sorted(appctx.keys()):
        if k == "state":
            # Constructor makes this one.
            continue
        v = appctx[k]
        try:
            new_appctx[k] = self(v,
                                 self,
                                 coefficient_mapping=coefficient_mapping)
        except CoarseningError:
            # Assume not something that needs coarsening (e.g. float)
            new_appctx[k] = v
    coarse = type(context)(problem,
                           mat_type=context.mat_type,
                           pmat_type=context.pmat_type,
                           appctx=new_appctx,
                           transfer_manager=context.transfer_manager)
    coarse._fine = context
    context._coarse = coarse

    # Now that we have the coarse snescontext, push it to the coarsened DMs
    # Otherwise they won't have the right transfer manager when they are
    # coarsened in turn
    from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx
    from firedrake.dmhooks import add_hook, get_parent
    from itertools import chain
    # Loop-invariant: the parent of the fine solution's DM, which owns the
    # teardown hooks.
    parentdm = get_parent(context._problem.u.function_space().dm)
    for val in chain(coefficient_mapping.values(),
                     (bc.function_arg for bc in problem.bcs)):
        if isinstance(val, firedrake.function.Function):
            coarseneddm = val.function_space().dm

            # Now attach the hook to the parent DM
            if get_appctx(coarseneddm) is None:
                push_appctx(coarseneddm, coarse)
                teardown = partial(pop_appctx, coarseneddm, coarse)
                add_hook(parentdm, teardown=teardown)

    ises = problem.J.arguments()[0].function_space()._ises
    coarse._nullspace = self(context._nullspace,
                             self,
                             coefficient_mapping=coefficient_mapping)
    coarse.set_nullspace(coarse._nullspace, ises, transpose=False, near=False)
    coarse._nullspace_T = self(context._nullspace_T,
                               self,
                               coefficient_mapping=coefficient_mapping)
    coarse.set_nullspace(coarse._nullspace_T, ises, transpose=True, near=False)
    coarse._near_nullspace = self(context._near_nullspace,
                                  self,
                                  coefficient_mapping=coefficient_mapping)
    coarse.set_nullspace(coarse._near_nullspace,
                         ises,
                         transpose=False,
                         near=True)

    return coarse
Esempio n. 8
0
    def coarsen(self, fdm, comm):
        """Return the DM of the p-coarsened _SNESContext of ``fdm``.

        The coarse context is cached on the fine one as ``_coarse``; if it
        already exists, its DM is returned directly.  Otherwise the state,
        test and trial functions are moved to a space on the same mesh with
        the element from ``self.coarsen_element``.  Quadrature degrees are
        scaled down, Dirichlet BCs and the appctx are coarsened, and the new
        DM is wired up with hooks on the parent DM.  ``comm`` is not used.
        """
        fine_ctx = get_appctx(fdm)

        # Have we already done this?
        existing = fine_ctx._coarse
        if existing is not None:
            return existing.J.arguments()[0].function_space().dm

        parent = get_parent(fdm)
        assert parent is not None

        test, trial = fine_ctx.J.arguments()
        fine_space = test.function_space()
        coarse_element = self.coarsen_element(fine_space.ufl_element())
        coarse_space = firedrake.FunctionSpace(fine_space.mesh(), coarse_element)
        cdm = coarse_space.dm

        def max_degree(element):
            # Largest polynomial degree, recursing into mixed elements and
            # taking the max of tuple-valued (tensor-product) degrees.
            if isinstance(element, MixedElement):
                return max(max_degree(sub) for sub in element.sub_elements())
            degree = element.degree()
            try:
                return max(degree)
            except TypeError:
                return degree

        def coarsen_quadrature(params, Nf, Nc):
            # Scale "quadrature_degree" so the ratio of quadrature nodes to
            # interpolation nodes (Nq+1)/(Nf+1) is preserved (rounded up,
            # never below 2*Nc+1).  Non-dicts and dicts without the key come
            # back untouched.
            if not isinstance(params, dict):
                return params
            Nq = params.get("quadrature_degree", None)
            if Nq is None:
                return params
            scaled = dict(params)
            scaled["quadrature_degree"] = max(
                2 * Nc + 1, ((Nq + 1) * (Nc + 1) + Nf) // (Nf + 1) - 1)
            return scaled

        def coarsen_form(form, Nf, Nc, substitutions):
            # Substitute the coarse state/test/trial functions and rebuild
            # every integral with a coarsened quadrature degree; non-Forms
            # (e.g. None) pass straight through.
            if not isinstance(form, Form):
                return form
            integrals = [
                itg.reconstruct(
                    metadata=coarsen_quadrature(itg.metadata(), Nf, Nc))
                for itg in replace(form, substitutions).integrals()
            ]
            return Form(integrals)

        def coarsen_bc(bc):
            # Rebuild an (exactly) DirichletBC on the matching coarse subspace.
            sub_space = coarse_space
            for index in bc._indices:
                sub_space = sub_space.sub(index)
            value = self.coarsen_bc_value(bc, sub_space)
            if type(bc) != firedrake.DirichletBC:
                raise NotImplementedError(
                    "Unsupported BC type, please get in touch if you need this"
                )
            return firedrake.DirichletBC(sub_space,
                                         value,
                                         bc.sub_domain,
                                         method=bc.method)

        Nf = max_degree(fine_space.ufl_element())
        Nc = max_degree(coarse_space.ufl_element())

        fine_problem = fine_ctx._problem
        fine_state = fine_problem.u
        coarse_state = firedrake.Function(coarse_space)

        # Coarse replacements for the state, test and trial functions.
        substitutions = {
            fine_state: coarse_state,
            test: firedrake.TestFunction(coarse_space),
            trial: firedrake.TrialFunction(coarse_space)
        }

        cF = coarsen_form(fine_ctx.F, Nf, Nc, substitutions)
        cJ = coarsen_form(fine_ctx.J, Nf, Nc, substitutions)
        cJp = coarsen_form(fine_ctx.Jp, Nf, Nc, substitutions)
        fcp = coarsen_quadrature(fine_problem.form_compiler_parameters, Nf, Nc)
        cbcs = [coarsen_bc(bc) for bc in fine_problem.bcs]

        def coarsen_appctx_entry(val):
            # The user might provide solution-dependent expressions and forms.
            if isinstance(val, dict):
                return coarsen_quadrature(val, Nf, Nc)
            if isinstance(val, Expr):
                return replace(val, substitutions)
            if isinstance(val, Form):
                return coarsen_form(val, Nf, Nc, substitutions)
            return val

        cappctx = {key: coarsen_appctx_entry(val)
                   for key, val in dict(fine_ctx.appctx).items()}

        # Coarsen the problem and the _SNESContext
        cproblem = firedrake.NonlinearVariationalProblem(
            cF,
            coarse_state,
            bcs=cbcs,
            J=cJ,
            Jp=cJp,
            form_compiler_parameters=fcp,
            is_linear=fine_problem.is_linear)

        coarse_ctx = type(fine_ctx)(
            cproblem,
            fine_ctx.mat_type,
            fine_ctx.pmat_type,
            appctx=cappctx,
            pre_jacobian_callback=fine_ctx._pre_jacobian_callback,
            pre_function_callback=fine_ctx._pre_function_callback,
            post_jacobian_callback=fine_ctx._post_jacobian_callback,
            post_function_callback=fine_ctx._post_function_callback,
            options_prefix=fine_ctx.options_prefix,
            transfer_manager=fine_ctx.transfer_manager)

        # FIXME setting up the _fine attribute triggers gmg injection.
        # coarse_ctx._fine = fine_ctx
        fine_ctx._coarse = coarse_ctx

        add_hook(parent,
                 setup=partial(push_parent, cdm, parent),
                 teardown=partial(pop_parent, cdm, parent),
                 call_setup=True)
        add_hook(parent,
                 setup=partial(push_appctx, cdm, coarse_ctx),
                 teardown=partial(pop_appctx, cdm, coarse_ctx),
                 call_setup=True)

        cdm.setOptionsPrefix(fdm.getOptionsPrefix())
        cdm.setKSPComputeOperators(_SNESContext.compute_operators)
        cdm.setCreateInterpolation(self.create_interpolation)
        cdm.setCreateInjection(self.create_injection)

        # Only hand the coarse DM our coarsen routine if its element can be
        # coarsened further.  At the coarsest p-level it is left alone so
        # geometric multigrid can be used for the p-coarse problem.
        try:
            self.coarsen_element(coarse_element)
            cdm.setCoarsen(self.coarsen)
        except ValueError:
            pass

        def inject_state(mat):
            # Inject the current fine state into the coarse state: xc = mat^T xf.
            with coarse_state.dat.vec_wo as xc, fine_state.dat.vec_ro as xf:
                mat.multTranspose(xf, xc)

        add_hook(parent,
                 setup=partial(inject_state, self.create_injection(cdm, fdm)),
                 call_setup=True)

        return cdm