Beispiel #1
0
def pytest_runtest_teardown(item, nextitem):
    """
    Pytest teardown hook run after every test.

    Empties the class-level ``_cache`` of Firedrake's ``TSFCKernel`` and
    PyOP2's ``GlobalKernel``, then clears the current pyadjoint working tape.

    :arg item: the test item that just ran (unused; part of the hook signature)
    :arg nextitem: the next scheduled test item (unused)
    """
    from firedrake.tsfc_interface import TSFCKernel
    from pyop2.global_kernel import GlobalKernel
    from pyadjoint import get_working_tape

    # HACK: these class-cached kernels are otherwise never garbage collected,
    # and they can grow very large (e.g. with bendy meshes enabled).
    for kernel_cls in (TSFCKernel, GlobalKernel):
        kernel_cls._cache.clear()

    # Reset the annotation tape so later tests start from a clean slate.
    working_tape = get_working_tape()
    working_tape.clear_tape()
Beispiel #2
0
def handle_taping():
    """
    Hand control to the test, then clear the pyadjoint working tape.

    Nothing is done before the ``yield``; the tape is cleared only once the
    generator is resumed after the test body has run.

    **Disclaimer:** copied from
    firedrake/tests/regression/test_adjoint_operators.py
    """
    yield
    # Imported lazily, after the test, exactly as in the original fixture.
    from pyadjoint import get_working_tape

    get_working_tape().clear_tape()
Beispiel #3
0
def run_action(params_loc=None):
    """
    Run a complete WindSE simulation, including any requested optimization.

    Modules whose names contain ``windse``, ``dolfin_adjoint`` or
    ``fenics_adjoint`` are purged from ``sys.modules`` so that fresh copies
    are imported, and the working adjoint tape (if any) is cleared first.
    The problem is then set up, solved, and optionally differentiated,
    Taylor-tested and optimized depending on ``params``.

    :param params_loc: forwarded unchanged to ``SetupSimulation``
        (presumably the location of the parameter file -- confirm).
    :return: wall-clock runtime of the whole action, in seconds.
    """
    tick = time.time()

    ### Clean up other module references ###
    # Substrings identifying modules that must be re-imported fresh.
    stale_markers = ("windse", "dolfin_adjoint", "fenics_adjoint")
    # Take a snapshot of the module names before removing anything, so that
    # sys.modules is never mutated while it is being iterated.
    mods_to_remove = [
        name for name in list(sys.modules)
        if any(marker in name for marker in stale_markers)
    ]
    for name in mods_to_remove:
        sys.modules.pop(name, None)

    ### Clean tape if available ###
    tape = get_working_tape()
    if tape is not None:
        tape.clear_tape()

    ### Import fresh version of windse ###
    import windse
    try:
        from .driver_functions import SetupSimulation
    except ImportError:
        # Not loaded as part of a package, so the relative import is not
        # available: fall back to a top-level import. Only ImportError is
        # caught so genuine errors inside driver_functions are not masked.
        from driver_functions import SetupSimulation

    ### Setup everything ###
    params, problem, solver = SetupSimulation(params_loc)

    ### run the solver ###
    solver.Solve()

    ### Perform Optimization ###
    if params.performing_opt_calc:
        opt = windse.Optimizer(solver)
        # Gradients are also computed whenever debug mode is enabled.
        if params["optimization"]["gradient"] or params["general"]["debug_mode"]:
            opt.Gradient()
        if params["optimization"]["taylor_test"]:
            opt.TaylorTest()
        if params["optimization"]["optimize"]:
            opt.Optimize()

    tock = time.time()
    runtime = tock - tick
    # Only the root rank reports, to avoid duplicated output under MPI.
    if params.rank == 0:
        print("Run Complete: {:1.2f} s".format(runtime))

    params.comm.Barrier()

    return runtime
Beispiel #4
0
 def fin():
     """Check that the pyadjoint working tape, if one exists, holds no blocks."""
     working_tape = pyadjoint.get_working_tape()
     if working_tape is None:
         # No tape at all: nothing to check.
         return
     assert len(working_tape.get_blocks()) == 0
Beispiel #5
0
    def get_solve_blocks(self, field, subinterval=0, has_adj_sol=True):
        """
        Get all blocks of the tape corresponding to
        solve steps for prognostic solution ``field``
        on a given ``subinterval``.

        :arg field: name of the prognostic field; a solve block is kept if its
            ``tag`` is not ``None`` and ``field`` is contained in that tag
        :kwarg subinterval: index of the subinterval, used to look up the
            function space and the number of timesteps
        :kwarg has_adj_sol: if ``True``, every kept block whose ``adj_sol`` is
            ``None`` is given a fresh (zero) ``Function`` on the field's space
        :return: the list of matching solve blocks; an empty list is returned
            (with a warning) if the tape, its solve blocks, or the blocks
            tagged for ``field`` are empty
        :raises ValueError: if the kept blocks use different finite elements,
            or if their count is not a whole multiple of the number of
            timesteps on ``subinterval``
        """
        from firedrake.adjoint.solving import get_solve_blocks
        from pyadjoint import get_working_tape

        # Bail out early if nothing at all has been annotated.
        tape_blocks = get_working_tape().get_blocks()
        if len(tape_blocks) == 0:
            self.warning("Tape has no blocks!")
            return tape_blocks

        # Keep only the solve blocks.
        candidates = get_solve_blocks()
        if len(candidates) == 0:
            self.warning("Tape has no solve blocks!")
            return candidates

        def _belongs_to_field(blk):
            # Blocks are attributed to a field through their ad_block_tag.
            return blk.tag is not None and field in blk.tag

        matching = list(filter(_belongs_to_field, candidates))
        if not matching:
            self.warning(f"No solve blocks associated with field '{field}'.\n"
                         "Has ad_block_tag been used correctly?")
            return matching
        self.debug(
            f"Field '{field}' on subinterval {subinterval} has {len(matching)} solve blocks"
        )

        # Replace missing adjoint solutions by zero Functions rather than None.
        if has_adj_sol:
            unset = [blk for blk in matching if blk.adj_sol is None]
            if len(unset) == len(matching):
                self.warning(
                    "No block has an adjoint solution. Has the adjoint equation been solved?"
                )
            for blk in unset:
                blk.adj_sol = firedrake.Function(
                    self.function_spaces[field][subinterval], name=field)

        # Every solve block for one field must share the same finite element.
        reference_element = matching[0].function_space.ufl_element()
        for blk in matching:
            blk_element = blk.function_space.ufl_element()
            if reference_element != blk_element:
                raise ValueError(
                    f"Solve block list for field {field} contains mismatching elements"
                    f" ({reference_element} vs. {blk_element})")

        # Each timestep must account for a whole number of solve blocks.
        num_timesteps = self.time_partition[subinterval].num_timesteps
        blocks_per_step = len(matching) / num_timesteps
        if not np.isclose(np.round(blocks_per_step), blocks_per_step):
            raise ValueError(
                f"Number of timesteps for field '{field}' does not divide number of solve"
                f" blocks ({num_timesteps} vs. {len(matching)}). If you are trying to"
                " use a multi-stage Runge-Kutta method, then this is not supported."
            )
        return matching
Beispiel #6
0
def test_adjoint_same_mesh(problem, qoi_type, debug=False):
    """
    Check that `solve_adjoint` gives the same
    result when applied on one or two subintervals.

    A reference adjoint is first computed directly with pyadjoint on a single
    subinterval. The problem is then re-solved with
    ``AdjointMeshSeq.solve_adjoint`` on one subinterval (and, for unsteady
    problems, also on two). The QoI and the adjoint solutions at the initial
    time are compared against the reference. For unsteady problems, the
    adjoint values at the initial time are compared as well.

    :arg problem: string denoting the test case of choice
    :arg qoi_type: is the QoI evaluated at the end time
        or as a time integral?
    :kwarg debug: toggle debugging mode
    :raises ValueError: if the QoIs, adjoint solutions or adjoint values
        do not match the reference
    """
    from firedrake_adjoint import pyadjoint

    # Debugging
    if debug:
        set_log_level(DEBUG)

    # Imports
    pyrint(f"\n--- Setting up {problem} test case with {qoi_type} QoI\n")
    test_case = importlib.import_module(problem)
    end_time = test_case.end_time
    steady = test_case.steady
    if steady:
        assert test_case.dt_per_export == 1
        assert np.isclose(end_time / test_case.dt, 1.0)
    if "solid_body_rotation" in problem:
        end_time /= 4  # Reduce testing time
    elif steady and qoi_type == "time_integrated":
        pytest.skip("n/a for steady case")

    def build_mesh_seq(num_subintervals):
        """Partition the time interval and build the matching AdjointMeshSeq."""
        time_partition = TimePartition(
            end_time,
            num_subintervals,
            test_case.dt,
            test_case.fields,
            timesteps_per_export=test_case.dt_per_export,
        )
        # Reference and loop mesh sequences share every argument, including
        # ``steady``, so that they really describe the same problem.
        mesh_seq = AdjointMeshSeq(
            time_partition,
            test_case.mesh,
            get_function_spaces=test_case.get_function_spaces,
            get_initial_condition=test_case.get_initial_condition,
            get_form=test_case.get_form,
            get_solver=test_case.get_solver,
            get_qoi=test_case.get_qoi,
            get_bcs=test_case.get_bcs,
            qoi_type=qoi_type,
            steady=steady,
        )
        return time_partition, mesh_seq

    # Partition time interval and create MeshSeq
    _, mesh_seq = build_mesh_seq(1)

    # Solve forward and adjoint without solve_adjoint
    pyrint("\n--- Adjoint solve on 1 subinterval using pyadjoint\n")
    ic = mesh_seq.initial_condition
    controls = [pyadjoint.Control(value) for value in ic.values()]
    sols = mesh_seq.solver(0, ic)
    qoi = mesh_seq.get_qoi(sols, 0)
    J = mesh_seq.J if qoi_type == "time_integrated" else qoi()
    m = pyadjoint.enlisting.Enlist(controls)
    tape = pyadjoint.get_working_tape()
    with pyadjoint.stop_annotating():
        with tape.marked_nodes(m):
            tape.evaluate_adj(markings=True)
    # FIXME: Using mixed Functions as Controls not correct
    J_expected = float(J)

    # Get expected adjoint solutions and values
    adj_sols_expected = {}
    adj_values_expected = {}
    for field, fs in mesh_seq._fs.items():
        solve_blocks = mesh_seq.get_solve_blocks(field)
        fwd_old_idx = mesh_seq.get_lagged_dependency_index(
            field, 0, solve_blocks)
        adj_sols_expected[field] = solve_blocks[0].adj_sol.copy(deepcopy=True)
        if not steady:
            adj_values_expected[field] = Function(
                fs[0],
                val=solve_blocks[0]._dependencies[fwd_old_idx].adj_value)

    # Loop over having one or two subintervals
    for N in range(1, 2 if steady else 3):
        pl = "" if N == 1 else "s"
        pyrint(f"\n--- Adjoint solve on {N} subinterval{pl} using pyroteus\n")

        # Solve forward and adjoint on each subinterval
        time_partition, mesh_seq = build_mesh_seq(N)
        solutions = mesh_seq.solve_adjoint(get_adj_values=not steady,
                                           test_checkpoint_qoi=True)

        # Check quantities of interest match
        if not np.isclose(J_expected, mesh_seq.J):
            raise ValueError(
                f"QoIs do not match ({J_expected} vs. {mesh_seq.J})")

        # Check adjoint solutions at initial time match
        for field in time_partition.fields:
            adj_sol_expected = adj_sols_expected[field]
            adj_sol_computed = solutions[field].adjoint[0][0]
            err = errornorm(adj_sol_expected,
                            adj_sol_computed) / norm(adj_sol_expected)
            if not np.isclose(err, 0.0):
                raise ValueError(
                    f"Adjoint solutions do not match at t=0 (error {err:.4e}.)"
                )

        # Check adjoint actions at initial time match
        if not steady:
            for field in time_partition.fields:
                adj_value_expected = adj_values_expected[field]
                adj_value_computed = solutions[field].adj_value[0][0]
                err = errornorm(adj_value_expected,
                                adj_value_computed) / norm(adj_value_expected)
                if not np.isclose(err, 0.0):
                    raise ValueError(
                        f"Adjoint values do not match at t=0 (error {err:.4e}.)"
                    )