Пример #1
0
    def wrapper(*args, **kwargs):
        """Run ``project`` and, unless disabled, record it on the working tape.

        A projection performs an equation solve, so it has to be annotated for
        pyadjoint to be able to build the adjoint and tangent linear models
        automatically.

        Pass :py:data:`annotate=False` to skip the annotation. That is handy when
        the solve is irrelevant or purely diagnostic as far as the adjoint
        computation is concerned (for example, projecting a field into another
        function space just to visualise it).

        The first two positional arguments are forwarded to ``ProjectBlock``
        together with the projected result, the ``bcs`` keyword (default ``[]``)
        and whatever ``ProjectBlock.pop_kwargs`` extracts from ``kwargs``.

        Returns the overloaded object wrapping the result of ``project``.
        """
        annotate = annotate_tape(kwargs)

        # The underlying projection itself must not be recorded on the tape.
        with stop_annotating():
            result = project(*args, **kwargs)
        result = create_overloaded_object(result)

        if not annotate:
            return result

        # These keywords are only removed after ``project`` has already seen them.
        bcs = kwargs.pop("bcs", [])
        block_kwargs = ProjectBlock.pop_kwargs(kwargs)
        block = ProjectBlock(args[0], args[1], result, bcs, **block_kwargs)

        get_working_tape().add_block(block)
        block.add_output(result.block_variable)

        return result
Пример #2
0
        def wrapper(self, b, *args, **kwargs):
            """Project ``b`` onto ``self`` and record the operation when annotating.

            The ``ad_block_tag`` keyword (default ``None``) is removed from
            ``kwargs`` and handed to the block; ``bcs`` (default ``[]``) is read
            from ``kwargs`` but left in place for ``project``.

            A ``SupermeshProjectBlock`` is recorded when ``b`` is a
            ``firedrake.Function`` whose ``ufl_domain()`` differs from the mesh of
            ``self.function_space()``; otherwise a ``ProjectBlock`` is recorded.

            Returns whatever ``project(self, b, *args, **kwargs)`` returns.
            """
            ad_block_tag = kwargs.pop("ad_block_tag", None)
            annotate = annotate_tape(kwargs)

            if annotate:
                bcs = kwargs.get("bcs", [])
                crosses_meshes = (
                    isinstance(b, firedrake.Function)
                    and b.ufl_domain() != self.function_space().mesh()
                )
                block_type = SupermeshProjectBlock if crosses_meshes else ProjectBlock

                # The block is built and put on the tape before the projection runs.
                block = block_type(b,
                                   self.function_space(),
                                   self,
                                   bcs,
                                   ad_block_tag=ad_block_tag)
                get_working_tape().add_block(block)

            # The projection itself is executed with annotation switched off.
            with stop_annotating():
                output = project(self, b, *args, **kwargs)

            if annotate:
                block.add_output(output.create_block_variable())

            return output
Пример #3
0
        def wrapper(self, b, *args, **kwargs):
            """Project ``b`` via ``project`` and record a ``ProjectBlock`` if annotating.

            The projection runs with annotation switched off. When annotation is
            enabled, a ``ProjectBlock`` built from ``b``, ``self.function_space()``,
            the projection result and the ``bcs`` keyword (default ``[]``) is added
            to the working tape, with a fresh block variable of the result as its
            output.

            Returns whatever ``project(self, b, *args, **kwargs)`` returns.
            """
            annotate = annotate_tape(kwargs)

            with stop_annotating():
                result = project(self, b, *args, **kwargs)

            if not annotate:
                return result

            # NOTE(review): the block is constructed after the projection has run;
            # if ``b`` can depend on ``self``, confirm that recording the
            # dependencies at this point is what is intended.
            bcs = kwargs.pop("bcs", [])
            block = ProjectBlock(b, self.function_space(), result, bcs)

            get_working_tape().add_block(block)
            block.add_output(result.create_block_variable())

            return result
Пример #4
0
    def wrapper(*args, **kwargs):
        """Run ``project`` and, unless disabled, record it on the working tape.

        A projection performs an equation solve, so it has to be annotated for
        pyadjoint to be able to build the adjoint and tangent linear models
        automatically.

        Pass :py:data:`annotate=False` to skip the annotation. That is handy when
        the solve is irrelevant or purely diagnostic as far as the adjoint
        computation is concerned (for example, projecting a field into another
        function space just to visualise it).

        When the second positional argument is a ``function.Function`` the
        ``ProjectBlock`` is built from it (and its function space) before the
        projection runs; otherwise the block is built afterwards from the second
        positional argument and the projection result.

        Returns whatever ``project(*args, **kwargs)`` returns.
        """
        annotate = annotate_tape(kwargs)

        if annotate:
            bcs = kwargs.get("bcs", [])
            block_kwargs = ProjectBlock.pop_kwargs(kwargs)
            target_is_function = isinstance(args[1], function.Function)
            if target_is_function:
                # Build the block ahead of the projection: the output may also be
                # an input that needs checkpointing.
                target = args[1]
                space = target.function_space()
                block = ProjectBlock(args[0], space, target, bcs, **block_kwargs)

        # The projection itself is executed with annotation switched off.
        with stop_annotating():
            output = project(*args, **kwargs)

        if not annotate:
            return output

        tape = get_working_tape()
        if not target_is_function:
            block = ProjectBlock(args[0], args[1], output, bcs,
                                 **block_kwargs)
        tape.add_block(block)
        block.add_output(output.create_block_variable())

        return output