def wrapper(*args, **kwargs):
    """Call ``project`` and record the projection on the working tape.

    A projection performs an equation solve, so it has to be annotated for
    pyadjoint to build the adjoint and tangent linear models automatically.
    Pass :py:data:`annotate=False` to switch the annotation off, e.g. when
    the projection is purely diagnostic (such as moving a field to another
    function space for visualisation).

    ``args[0]`` is the quantity being projected and ``args[1]`` is either the
    function receiving the result or whatever else ``project`` accepts as its
    second argument. Returns the object returned by ``project``.
    """
    recording = annotate_tape(kwargs)

    if recording:
        boundary_conditions = kwargs.get("bcs", [])
        block_options = ProjectBlock.pop_kwargs(kwargs)
        writes_into_function = isinstance(args[1], function.Function)
        if writes_into_function:
            # Build the block *before* projecting: the target may also be an
            # input of the projection and then needs checkpointing first.
            target = args[1]
            project_block = ProjectBlock(
                args[0],
                target.function_space(),
                target,
                boundary_conditions,
                **block_options,
            )

    with stop_annotating():
        result = project(*args, **kwargs)

    if not recording:
        return result

    working_tape = get_working_tape()
    if not writes_into_function:
        # No pre-existing target function: the block can only be built once
        # the projection has produced its output.
        project_block = ProjectBlock(
            args[0], args[1], result, boundary_conditions, **block_options
        )
    working_tape.add_block(project_block)
    project_block.add_output(result.create_block_variable())
    return result
def wrapper(*args, **kwargs):
    """Call ``project``, overload its result and record a projection block.

    A projection performs an equation solve, so it has to be annotated for
    pyadjoint to build the adjoint and tangent linear models automatically.
    Pass :py:data:`annotate=False` to switch the annotation off, e.g. when
    the projection is purely diagnostic (such as moving a field to another
    function space for visualisation).

    The value returned is the result of ``project`` after it has been passed
    through ``create_overloaded_object``; this conversion happens whether or
    not annotation is enabled.

    NOTE(review): ``bcs`` and the block keyword arguments are removed from
    ``kwargs`` only after ``project`` has already been called with them.
    """
    recording = annotate_tape(kwargs)

    with stop_annotating():
        raw_result = project(*args, **kwargs)
    result = create_overloaded_object(raw_result)

    if not recording:
        return result

    boundary_conditions = kwargs.pop("bcs", [])
    block_options = ProjectBlock.pop_kwargs(kwargs)
    project_block = ProjectBlock(
        args[0], args[1], result, boundary_conditions, **block_options
    )
    get_working_tape().add_block(project_block)
    project_block.add_output(result.block_variable)
    return result
def wrapper(self, b, *args, **kwargs):
    """Project ``b`` onto ``self`` and record the operation on the tape.

    ``ad_block_tag`` is removed from ``kwargs`` and handed to the block that
    is created; the remaining arguments are forwarded to ``project``
    (resolved from the enclosing scope, not shown here).

    When annotation is enabled the block is created and added to the working
    tape *before* the projection runs. A ``SupermeshProjectBlock`` is used
    when ``b`` is a ``firedrake.Function`` whose ``ufl_domain()`` differs
    from the mesh of ``self.function_space()``; otherwise a ``ProjectBlock``
    is used.

    Returns whatever ``project`` returns.
    """
    block_tag = kwargs.pop("ad_block_tag", None)
    recording = annotate_tape(kwargs)

    if recording:
        boundary_conditions = kwargs.get("bcs", [])
        crosses_meshes = (
            isinstance(b, firedrake.Function)
            and b.ufl_domain() != self.function_space().mesh()
        )
        block_cls = SupermeshProjectBlock if crosses_meshes else ProjectBlock
        project_block = block_cls(
            b,
            self.function_space(),
            self,
            boundary_conditions,
            ad_block_tag=block_tag,
        )
        get_working_tape().add_block(project_block)

    with stop_annotating():
        result = project(self, b, *args, **kwargs)

    if recording:
        # The output is attached only once the projection has produced it.
        project_block.add_output(result.create_block_variable())
    return result
def wrapper(self, b, *args, **kwargs):
    """Project ``b`` onto ``self`` and record a ``ProjectBlock`` on the tape.

    The projection itself (``project``, resolved from the enclosing scope and
    not shown here) runs inside a ``stop_annotating()`` context. When
    ``annotate_tape(kwargs)`` reports that annotation is enabled, a block is
    built from ``b``, ``self.function_space()``, the projection result and
    the ``bcs`` keyword argument (default ``[]``), added to the working tape,
    and given the result's block variable as its output.

    Returns whatever ``project`` returns.
    """
    recording = annotate_tape(kwargs)

    with stop_annotating():
        result = project(self, b, *args, **kwargs)

    if not recording:
        return result

    boundary_conditions = kwargs.pop("bcs", [])
    project_block = ProjectBlock(
        b, self.function_space(), result, boundary_conditions
    )
    get_working_tape().add_block(project_block)
    project_block.add_output(result.create_block_variable())
    return result