Example 1
def optimizations_and_results(
    smirnoff_force_field,
) -> Tuple[List[Optimization], List[OptimizationResult]]:

    optimization_1 = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [
            create_evaluator_target("evaluator-target-1", ["data-set-1"]),
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
        ],
    )
    optimization_1.name = "Optimization 1"
    optimization_1.force_field = ForceField.from_openff(smirnoff_force_field)
    optimization_2 = create_optimization(
        "project-1",
        "study-1",
        "optimization-2",
        [
            create_evaluator_target("evaluator-target-1", ["data-set-1"]),
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
        ],
    )
    optimization_2.force_field = ForceField.from_openff(smirnoff_force_field)
    optimization_2.name = "Optimization 2"

    # Perturb the vdW parameters so that the first refit force field differs
    # from the original.
    smirnoff_force_field.get_parameter_handler("vdW").parameters["[#6:1]"].epsilon *= 2
    smirnoff_force_field.get_parameter_handler("vdW").parameters["[#6:1]"].sigma *= 3

    optimization_result_1 = create_optimization_result(
        "project-1",
        "study-1",
        "optimization-1",
        ["evaluator-target-1"],
        ["recharge-target-1"],
    )
    optimization_result_1.refit_force_field = ForceField.from_openff(
        smirnoff_force_field
    )

    # Perturb the parameters again so that the second refit force field is
    # distinct from the first.
    smirnoff_force_field.get_parameter_handler("vdW").parameters["[#6:1]"].epsilon /= 4
    smirnoff_force_field.get_parameter_handler("vdW").parameters["[#6:1]"].sigma /= 6

    optimization_result_2 = create_optimization_result(
        "project-1",
        "study-1",
        "optimization-2",
        ["evaluator-target-1"],
        ["recharge-target-1"],
    )
    optimization_result_2.refit_force_field = ForceField.from_openff(
        smirnoff_force_field
    )

    return (
        [optimization_1, optimization_2],
        [optimization_result_1, optimization_result_2],
    )
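
The helper above looks intended to back a pytest fixture. A minimal usage
sketch, assuming it is registered as a fixture and that a smirnoff_force_field
fixture exists elsewhere in the suite (neither is shown in this excerpt):

import pytest


@pytest.fixture
def optimizations_and_results_fixture(smirnoff_force_field):
    # Hypothetical wrapper: expose the helper above under a fixture name.
    return optimizations_and_results(smirnoff_force_field)


def test_fixture_shape(optimizations_and_results_fixture):
    optimizations, results = optimizations_and_results_fixture
    # Two optimizations are created, each paired with a refit result.
    assert len(optimizations) == len(results) == 2
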
def test_prepare_restart_finished(caplog):

    optimization = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
            create_recharge_target("recharge-target-2", ["qc-data-set-1"]),
        ],
    )

    with temporary_cd():

        directories = [
            os.path.join("optimize.tmp", "recharge-target-1", "iter_0000"),
            os.path.join("optimize.tmp", "recharge-target-1", "iter_0001"),
            os.path.join("optimize.tmp", "recharge-target-2", "iter_0000"),
            os.path.join("optimize.tmp", "recharge-target-2", "iter_0001"),
        ]

        for directory in directories:

            os.makedirs(directory)

            for file_name in [
                    "mvals.txt", "force-field.offxml", "objective.p"
            ]:

                with open(os.path.join(directory, file_name), "w") as file:
                    file.write("")

        assert len(glob(os.path.join("optimize.tmp", "recharge-target-1",
                                     "*"))) == 2
        assert len(glob(os.path.join("optimize.tmp", "recharge-target-2",
                                     "*"))) == 2

        with caplog.at_level(logging.INFO):
            _prepare_restart(optimization)

        assert len(glob(os.path.join("optimize.tmp", "recharge-target-1",
                                     "*"))) == 2
        assert len(glob(os.path.join("optimize.tmp", "recharge-target-2",
                                     "*"))) == 2

        assert (
            "2 iterations had previously been completed. The optimization will be "
            "restarted from iteration 0002") in caplog.text
Example 3
def mock_target(tmpdir) -> Tuple[Optimization, RechargeTarget, str]:
    """Create a mock recharge target directory which is populated with a dummy
    set of results.

    Returns
    -------
        A tuple of the parent optimization, the mock target and the path to the
        directory in which the files were created.
    """

    with temporary_cd(str(tmpdir)):

        # Mock the target to analyze.
        target = create_recharge_target("recharge-target-1", ["qc-data-set-1"])

        optimization = create_optimization("project-1", "study-1",
                                           "optimization-1", [target])
        optimization.analysis_environments = [
            ChemicalEnvironment.Alkane,
            ChemicalEnvironment.Alcohol,
        ]

        # Create a dummy set of residuals.
        with open("residuals.json", "w") as file:
            json.dump({"C": 9.0, "CO": 4.0}, file)

        lp_dump({"X": 1.0}, "objective.p")

    return optimization, target, str(tmpdir)
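
A hedged sketch of how a test might consume this fixture, assuming mock_target
is registered with pytest.fixture (the decorator is not shown in this excerpt):

import os


def test_mock_target_files(mock_target):
    optimization, target, directory = mock_target

    # The dummy residuals and objective files are written into the temporary
    # directory that the fixture returns.
    assert os.path.isfile(os.path.join(directory, "residuals.json"))
    assert os.path.isfile(os.path.join(directory, "objective.p"))
    assert target in optimization.targets
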
Example 4
def test_analysis_n_iteration(monkeypatch, force_field):
    """Test that the correction exception is raised in the case where a refit
    force field is found but no target outputs are."""

    optimization = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [
            create_evaluator_target("evaluator-target-1", ["data-set-1"]),
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
        ],
    )
    optimization.force_field = force_field

    with temporary_cd():

        # Save the mock optimization file.
        with open("optimization.json", "w") as file:
            file.write(optimization.json())

        # Mock successfully reading a refit force field.
        monkeypatch.setattr(OptimizationAnalysisFactory,
                            "_load_refit_force_field", lambda: force_field)

        with pytest.raises(RuntimeError) as error_info:
            OptimizationAnalysisFactory.analyze(True)

        assert "No iteration results could be found" in str(error_info.value)
Example 5
def test_analysis_missing_result(monkeypatch, force_field):
    """Test that the correction exception is raised in the case where a the
    expected results of a target are missing."""

    optimization = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [
            create_evaluator_target("evaluator-target-1", ["data-set-1"]),
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
        ],
    )
    optimization.force_field = force_field

    with temporary_cd():

        # Save the expected results files.
        os.makedirs(os.path.join("result", "optimize"))

        for target in optimization.targets:
            os.makedirs(os.path.join("targets", target.id))
            os.makedirs(os.path.join("optimize.tmp", target.id, "iter_0000"))

            lp_dump(
                {"X": 1.0},
                os.path.join("optimize.tmp", target.id, "iter_0000",
                             "objective.p"),
            )

        with open("optimization.json", "w") as file:
            file.write(optimization.json())

        monkeypatch.setattr(OptimizationAnalysisFactory,
                            "_load_refit_force_field", lambda: force_field)

        # Mock a missing target result.
        monkeypatch.setattr(
            EvaluatorAnalysisFactory,
            "analyze",
            lambda *args, **kwargs: EvaluatorTargetResult(
                objective_function=1.0, statistic_entries=[]),
        )
        monkeypatch.setattr(
            RechargeAnalysisFactory,
            "analyze",
            lambda *args, **kwargs: None,
        )

        with pytest.raises(RuntimeError) as error_info:
            OptimizationAnalysisFactory.analyze(True)

        assert "The results of the recharge-target-1 target could not be found" in str(
            error_info.value)
Example 6
    def create_model(cls, include_children=False, index=1):

        optimization = create_optimization(
            "project-1",
            "study-1",
            f"optimization-{index}",
            targets=[
                create_evaluator_target("evaluator-target-1", ["data-set-1"]),
                create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
            ],
        )

        return optimization
def optimization(force_field) -> Optimization:
    optimization = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [
            create_evaluator_target("evaluator-target-1", ["data-set-1"]),
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
        ],
    )
    optimization.force_field = force_field

    return optimization
def test_run_command(restart: bool, create_save: bool, runner, monkeypatch):

    from nonbonded.cli.projects.optimization import run

    monkeypatch.setattr(run, "_remove_previous_files", lambda: print("REMOVE"))
    monkeypatch.setattr(run, "_prepare_restart",
                        lambda *args: print("PREPARE"))
    monkeypatch.setattr(subprocess, "check_call", lambda *args, **kwargs: None)

    optimization = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [create_recharge_target("recharge-target-1", ["qc-data-set-1"])],
    )

    # Save a copy of the result model.
    with temporary_cd():

        with open("optimization.json", "w") as file:
            file.write(optimization.json())

        if create_save:

            with open("optimize.sav", "w") as file:
                file.write("")

        arguments = [] if not restart else ["--restart", True]

        result = runner.invoke(run_command(), arguments)

        if restart and create_save:
            assert "REMOVE" not in result.output
            assert "PREPARE" in result.output

        elif restart and not create_save:
            assert "REMOVE" in result.output
            assert "PREPARE" not in result.output

        if not restart:
            assert "REMOVE" in result.output
            assert "PREPARE" not in result.output

    if result.exit_code != 0:
        raise result.exception
Example 10
def commit_optimization(
    db: Session,
) -> Tuple[Project, Study, Optimization, DataSetCollection,
           QCDataSetCollection]:
    """Commits a new project and study to the current session and appends an
    empty optimization onto it. Additionally, this function commits two data sets
    to the session to use as the training set.

    Parameters
    ----------
    db
        The current data base session.
    """

    training_set = commit_data_set_collection(db)
    training_set_ids = [x.id for x in training_set.data_sets]

    qc_data_set = commit_qc_data_set_collection(db)
    qc_data_set_ids = [x.id for x in qc_data_set.data_sets]

    study = create_study("project-1", "study-1")
    study.optimizations = [
        create_optimization(
            "project-1",
            "study-1",
            "optimization-1",
            [
                create_evaluator_target("evaluator-target-1",
                                        training_set_ids),
                create_recharge_target("recharge-target-1", qc_data_set_ids),
            ],
        )
    ]

    project = create_project(study.project_id)
    project.studies = [study]

    db_project = ProjectCRUD.create(db, project)
    db.add(db_project)
    db.commit()

    project = ProjectCRUD.db_to_model(db_project)
    return project, study, study.optimizations[0], training_set, qc_data_set
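
A minimal usage sketch, assuming the test suite provides a SQLAlchemy session
through a db fixture (an assumption; that fixture is not shown here):

def test_commit_optimization(db: Session):
    # Commit the project, study, optimization and training sets in one call.
    project, study, optimization, training_set, qc_data_set = commit_optimization(db)

    assert project.id == "project-1"
    assert study.id == "study-1"
    # The optimization trains against both an evaluator and a recharge target.
    assert len(optimization.targets) == 2
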
def test_plot_target_rmse(tmpdir):

    initial_result = RechargeTargetResult(
        objective_function=0.5,
        statistic_entries=[
            Statistic(
                statistic_type=StatisticType.RMSE,
                value=1.0,
                lower_95_ci=0.95,
                upper_95_ci=1.05,
                category="Alcohol",
            )
        ],
    )
    final_result = RechargeTargetResult(
        objective_function=0.5,
        statistic_entries=[
            Statistic(
                statistic_type=StatisticType.RMSE,
                value=0.5,
                lower_95_ci=0.4,
                upper_95_ci=0.6,
                category="Alcohol",
            )
        ],
    )

    figures = plot_target_rmse(
        [create_recharge_target("target-1", ["data-set-1"])] * 2,
        [initial_result, final_result],
        ["Initial", "Final"],
    )

    assert "esp" in figures
    figure = figures["esp"]

    assert len(figure.subplots) == 1
    assert len(figure.subplots[0].traces) == 2

    assert figure is not None
    assert figure.to_plotly() is not None
Example 12
def valid_optimization_kwargs(valid_sub_study_kwargs):

    force_field = ForceField(inner_content=" ")
    parameters = [Parameter(handler_type=" ", smirks=" ", attribute_name=" ")]

    return {
        **valid_sub_study_kwargs,
        "max_iterations": 1,
        "engine": ForceBalance(priors={" ": 1.0}),
        "targets": [
            create_evaluator_target("evaluator-target", ["data-set-1"]),
            create_recharge_target("recharge-target", ["qc-data-set-1"]),
        ],
        "force_field": force_field,
        "parameters_to_train": parameters,
        "analysis_environments": [],
    }
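
These keyword arguments look designed to construct a valid Optimization model.
A hedged sketch of that round trip, assuming valid_optimization_kwargs is
exposed as a pytest fixture and that Optimization accepts these fields
directly (an assumption based on the names above):

def test_optimization_from_kwargs(valid_optimization_kwargs):
    # Building the model from a fully valid set of kwargs should not raise.
    optimization = Optimization(**valid_optimization_kwargs)

    assert optimization.max_iterations == 1
    assert len(optimization.targets) == 2
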
from nonbonded.tests.utilities.factory import (
    create_evaluator_target,
    create_optimization,
    create_recharge_target,
)


@pytest.mark.parametrize(
    "optimization, server_config, expected_raises",
    [
        (
            create_optimization(
                "project-1",
                "study-1",
                "optimization-1",
                [create_recharge_target("recharge-target", ["qc-data-set-1"])],
            ),
            None,
            does_not_raise(),
        ),
        (
            create_optimization(
                "project-1",
                "study-1",
                "optimization-1",
                [create_evaluator_target("evaluator-target", ["data-set-1"])],
            ),
            None,
            pytest.raises(RuntimeError),
        ),
        (
Example 14
def test_analysis(monkeypatch, force_field, dummy_conda_env):

    optimization = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [
            create_evaluator_target("evaluator-target-1", ["data-set-1"]),
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
        ],
    )
    optimization.force_field = force_field

    with temporary_cd(os.path.dirname(dummy_conda_env)):

        # Save the expected results files.
        os.makedirs(os.path.join("result", "optimize"))

        for target in optimization.targets:
            os.makedirs(os.path.join("targets", target.id))

            os.makedirs(os.path.join("optimize.tmp", target.id, "iter_0000"))
            os.makedirs(os.path.join("optimize.tmp", target.id, "iter_0001"))

            # Add enough output files to make it look like only one full iteration has
            # finished.
            lp_dump(
                {"X": 1.0},
                os.path.join("optimize.tmp", target.id, "iter_0000",
                             "objective.p"),
            )

        lp_dump(
            {"X": 1.0},
            os.path.join("optimize.tmp", optimization.targets[0].id,
                         "iter_0001", "objective.p"),
        )

        with open("optimization.json", "w") as file:
            file.write(optimization.json())

        optimization.force_field.to_openff().to_file(
            os.path.join("result", "optimize", "force-field.offxml"))

        # Mock the already tested functions.
        monkeypatch.setattr(OptimizationAnalysisFactory,
                            "_load_refit_force_field", lambda: force_field)
        monkeypatch.setattr(
            EvaluatorAnalysisFactory,
            "analyze",
            lambda *args, **kwargs: EvaluatorTargetResult(
                objective_function=1.0, statistic_entries=[]),
        )
        monkeypatch.setattr(
            RechargeAnalysisFactory,
            "analyze",
            lambda *args, **kwargs: RechargeTargetResult(
                objective_function=1.0, statistic_entries=[]),
        )

        OptimizationAnalysisFactory.analyze(True)

        for target in optimization.targets:

            assert os.path.isfile(
                os.path.join("analysis", target.id, "iteration-0.json"))
            assert not os.path.isfile(
                os.path.join("analysis", target.id, "iteration-1.json"))

        result = OptimizationResult.parse_file(
            os.path.join("analysis", "optimization-results.json"))

        assert len(result.target_results) == 1
        assert all(target.id in result.target_results[0]
                   for target in optimization.targets)
        assert result.refit_force_field.inner_content == force_field.inner_content
class TestOptimizationInputFactory:
    def test_prepare_force_field(self, optimization):
        """Test that the correct cosmetic attributes are attached to the FF, especially
        in the special case of BCC handlers."""

        optimization.parameters_to_train.append(
            Parameter(
                handler_type="ChargeIncrementModel",
                smirks="[#6:1]-[#6:2]",
                attribute_name="charge_increment1",
            ))
        optimization.parameters_to_train.append(
            Parameter(
                handler_type="vdW",
                smirks=None,
                attribute_name="scale14",
            ))

        with temporary_cd():

            OptimizationInputFactory._prepare_force_field(optimization, None)

            assert os.path.isfile(
                os.path.join("forcefield", "force-field.offxml"))

            off_force_field = OFFForceField(
                os.path.join("forcefield", "force-field.offxml"),
                allow_cosmetic_attributes=True,
            )

        vdw_handler = off_force_field.get_parameter_handler("vdW")
        assert vdw_handler._parameterize == "scale14"

        assert len(vdw_handler.parameters) == 1
        parameter = vdw_handler.parameters["[#6:1]"]
        assert parameter._parameterize == "epsilon, sigma"

        bcc_handler = off_force_field.get_parameter_handler(
            "ChargeIncrementModel")
        assert len(bcc_handler.parameters) == 1
        parameter = bcc_handler.parameters["[#6:1]-[#6:2]"]
        assert len(parameter.charge_increment) == 1
        assert parameter._parameterize == "charge_increment1"

    def test_generate_force_balance_input(self, optimization):

        with temporary_cd():
            OptimizationInputFactory._generate_force_balance_input(
                optimization)
            assert os.path.isfile("optimize.in")

    @pytest.mark.parametrize("allow_reweighting", [False, True])
    def test_generate_request_options_default(self, allow_reweighting):

        training_set = create_data_set("data-set-1", 1)

        target = create_evaluator_target("evaluator-target-1", ["data-set-1"])
        target.allow_direct_simulation = True
        target.allow_reweighting = allow_reweighting

        request_options = OptimizationInputFactory._generate_request_options(
            target, training_set.to_evaluator())

        if allow_reweighting:
            assert request_options.calculation_layers == [
                "ReweightingLayer",
                "SimulationLayer",
            ]
        else:
            assert request_options.calculation_layers == ["SimulationLayer"]

        assert request_options.calculation_schemas == UNDEFINED

    def test_generate_request_options(self):

        training_set = create_data_set("data-set-1", 1)
        target = create_evaluator_target("evaluator-target-1",
                                         [training_set.id])

        target.allow_direct_simulation = True
        target.allow_reweighting = True
        target.n_molecules = 512
        target.n_effective_samples = 10

        request_options = OptimizationInputFactory._generate_request_options(
            target, training_set.to_evaluator())

        assert request_options.calculation_layers == [
            "ReweightingLayer",
            "SimulationLayer",
        ]

        assert request_options.calculation_schemas != UNDEFINED

        expected_simulation_schema = Density.default_simulation_schema(
            n_molecules=512)
        expected_reweighting_schema = Density.default_reweighting_schema(
            n_effective_samples=10)

        assert (
            request_options.calculation_schemas["Density"]
            ["SimulationLayer"].json() == expected_simulation_schema.json())
        assert (
            request_options.calculation_schemas["Density"]
            ["ReweightingLayer"].json() == expected_reweighting_schema.json())

    def test_generate_evaluator_target(self, requests_mock):

        data_set = create_data_set("data-set-1")
        mock_get_data_set(requests_mock, data_set)

        target = create_evaluator_target("evaluator-target-1", [data_set.id])

        with temporary_cd():

            OptimizationInputFactory._generate_evaluator_target(
                target, 8000, None)

            assert os.path.isfile("training-set.json")
            off_data_set = PhysicalPropertyDataSet.from_json(
                "training-set.json")
            assert off_data_set.json() == data_set.to_evaluator().json()

            assert os.path.isfile("options.json")

    def test_generate_recharge_target(self, requests_mock):

        qc_data_set = create_qc_data_set("qc-data-set-1")
        mock_get_qc_data_set(requests_mock, qc_data_set)

        target = create_recharge_target("recharge-target-1", [qc_data_set.id])

        with temporary_cd():

            OptimizationInputFactory._generate_recharge_target(target, None)

            with open("training-set.json") as file:
                training_entries = json.load(file)

            assert training_entries == qc_data_set.entries

            with open("grid-settings.json") as file:
                assert file.read() == target.grid_settings.json()

    @pytest.mark.parametrize(
        "target",
        [
            create_evaluator_target("evaluator-target-1", ["data-set-1"]),
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
        ],
    )
    def test_generate_target(self, target, caplog, monkeypatch):

        monkeypatch.setattr(
            OptimizationInputFactory,
            "_generate_evaluator_target",
            lambda *args: logging.info("EvaluatorTarget"),
        )
        monkeypatch.setattr(
            OptimizationInputFactory,
            "_generate_recharge_target",
            lambda *args: logging.info("RechargeTarget"),
        )

        with caplog.at_level(logging.INFO):

            with temporary_cd():
                OptimizationInputFactory._generate_target(target, 8000, None)
                assert os.path.isdir(os.path.join("targets", target.id))

        assert target.__class__.__name__ in caplog.text

    def test_retrieve_results(self, optimization, requests_mock):

        result = create_optimization_result(
            optimization.project_id,
            optimization.study_id,
            optimization.id,
            ["evaluator-target-1"],
            [],
        )
        mock_get_optimization_result(requests_mock, result)

        with temporary_cd():

            OptimizationInputFactory._retrieve_results(optimization)

            stored_result = OptimizationResult.parse_file(
                os.path.join("analysis", "optimization-results.json"))
            assert stored_result.json() == result.json()

    def test_generate(self, optimization, monkeypatch):

        logging.basicConfig(level=logging.INFO)

        # Mock the already tested functions
        monkeypatch.setattr(OptimizationInputFactory, "_prepare_force_field",
                            lambda *args: None)
        monkeypatch.setattr(
            OptimizationInputFactory,
            "_generate_force_balance_input",
            lambda *args: None,
        )
        monkeypatch.setattr(OptimizationInputFactory, "_generate_target",
                            lambda *args: None)
        monkeypatch.setattr(OptimizationInputFactory,
                            "_generate_submission_script", lambda *args: None)
        monkeypatch.setattr(OptimizationInputFactory, "_retrieve_results",
                            lambda *args: None)

        with temporary_cd():

            OptimizationInputFactory.generate(optimization, "env", "01:00",
                                              "lilac-local", 8000, 1, True)
Example 16
def create_dependencies(db: Session, dependencies: List[str]):
    """Create any dependencies such as parent studies, projects, or data sets and
    commit them to the database.

    Parameters
    ----------
    db
        The current database session.
    dependencies
        The required dependencies.
    """

    project = None
    data_set_ids = []
    qc_data_set_ids = []

    if "data-set" in dependencies:
        data_set_ids.append("data-set-1")
    if "qc-data-set" in dependencies:
        qc_data_set_ids.append("qc-data-set-1")

    for data_set_id in data_set_ids:
        data_set = create_data_set(data_set_id)
        db_data_set = DataSetCRUD.create(db, data_set)
        db.add(db_data_set)

    for qc_data_set_id in qc_data_set_ids:
        qc_data_set = create_qc_data_set(qc_data_set_id)
        db_qc_data_set = QCDataSetCRUD.create(db, qc_data_set)
        db.add(db_qc_data_set)

    db.commit()

    if ("project" in dependencies or "study" in dependencies
            or "evaluator-target" in dependencies
            or "recharge-target" in dependencies
            or "benchmark" in dependencies):
        project = create_project("project-1")

    if ("study" in dependencies or "evaluator-target" in dependencies
            or "recharge-target" in dependencies
            or "benchmark" in dependencies):
        project.studies = [create_study(project.id, "study-1")]

    if "evaluator-target" in dependencies or "recharge-target" in dependencies:

        targets = []

        if "evaluator-target" in dependencies:
            targets.append(
                create_evaluator_target("evaluator-target-1", ["data-set-1"]))
        if "recharge-target" in dependencies:
            targets.append(
                create_recharge_target("recharge-target-1", ["qc-data-set-1"]))

        optimization = create_optimization(project.id, project.studies[0].id,
                                           "optimization-1", targets)

        project.studies[0].optimizations = [optimization]

    if "benchmark" in dependencies:
        benchmark = create_benchmark(
            project.id,
            project.studies[0].id,
            "benchmark-1",
            ["data-set-1"],
            None,
            create_force_field(),
        )

        project.studies[0].benchmarks = [benchmark]

    if project is not None:
        db_project = ProjectCRUD.create(db, project)
        db.add(db_project)
        db.commit()
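
A hedged sketch of how a test might call this helper, again assuming a db
session fixture that is not shown in this excerpt:

def test_create_recharge_dependencies(db: Session):
    # Commit the QC data set plus the parent project, study and optimization
    # that a recharge target depends upon.
    create_dependencies(db, ["recharge-target"])

    assert db.query(models.RechargeTarget.id).count() == 1
    assert db.query(models.EvaluatorTarget.id).count() == 0
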
Example 17
def optimization_model_perturbations():

    updated_engine_delete_prior = ForceBalance(
        priors={"vdW/Atom/epsilon": 0.1})
    updated_engine_update_prior = ForceBalance(priors={
        "vdW/Atom/epsilon": 0.2,
        "vdW/Atom/sigma": 2.0
    }, )
    updated_engine_add_prior = ForceBalance(priors={
        "vdW/Atom/epsilon": 0.1,
        "vdW/Atom/sigma": 2.0,
        "vdW/Atom/r_min": 2.0
    }, )

    invalid_evaluator_target = create_evaluator_target("evaluator-target-1",
                                                       ["data-set-999"])
    invalid_recharge_target = create_recharge_target("recharge-target-1",
                                                     ["qc-data-set-999"])

    return [
        ({
            "name": "updated"
        }, lambda db: [], does_not_raise()),
        ({
            "description": "updated"
        }, lambda db: [], does_not_raise()),
        ({
            "max_iterations": 999
        }, lambda db: [], does_not_raise()),
        (
            {
                "analysis_environments": [ChemicalEnvironment.Hydroxy]
            },
            lambda db: [],
            does_not_raise(),
        ),
        # Test updating the force field.
        (
            {
                "force_field": create_force_field("updated")
            },
            lambda db: [
                "updated" in db.query(models.ForceField.inner_content).first()[
                    0],
                db.query(models.ForceField.id).count() == 1,
            ],
            does_not_raise(),
        ),
        # Test updating the parameters to train.
        (
            {
                "parameters_to_train": [
                    Parameter(handler_type="vdW",
                              smirks="[#6:1]",
                              attribute_name="epsilon"),
                ]
            },
            lambda db: [db.query(models.Parameter.id).count() == 1],
            does_not_raise(),
        ),
        (
            {
                "parameters_to_train": [
                    Parameter(handler_type="vdW",
                              smirks="[#6:1]",
                              attribute_name="epsilon"),
                    Parameter(handler_type="vdW",
                              smirks="[#6:1]",
                              attribute_name="sigma"),
                    Parameter(handler_type="vdW",
                              smirks="[#1:1]",
                              attribute_name="sigma"),
                ]
            },
            lambda db: [db.query(models.Parameter.id).count() == 3],
            does_not_raise(),
        ),
        # Test updating an engine's priors
        (
            {
                "engine": updated_engine_delete_prior
            },
            lambda db: [db.query(models.ForceBalancePrior.id).count() == 1],
            does_not_raise(),
        ),
        (
            {
                "engine": updated_engine_update_prior
            },
            lambda db: [db.query(models.ForceBalancePrior.id).count() == 2],
            does_not_raise(),
        ),
        (
            {
                "engine": updated_engine_add_prior
            },
            lambda db: [db.query(models.ForceBalancePrior.id).count() == 3],
            does_not_raise(),
        ),
        # Test deleting a target
        (
            {
                "targets": [
                    create_evaluator_target("evaluator-target-1",
                                            ["data-set-1"])
                ]
            },
            lambda db: [
                db.query(models.EvaluatorTarget.id).count() == 1,
                db.query(models.RechargeTarget.id).count() == 0,
            ],
            does_not_raise(),
        ),
        (
            {
                "targets": [
                    create_recharge_target("recharge-target-1",
                                           ["qc-data-set-1"])
                ]
            },
            lambda db: [
                db.query(models.EvaluatorTarget.id).count() == 0,
                db.query(models.RechargeTarget.id).count() == 1,
            ],
            does_not_raise(),
        ),
        # Test adding a target
        (
            {
                "targets": [
                    create_evaluator_target("evaluator-target-1",
                                            ["data-set-1"]),
                    create_evaluator_target("evaluator-target-2",
                                            ["data-set-1"]),
                    create_recharge_target("recharge-target-1",
                                           ["qc-data-set-1"]),
                ]
            },
            lambda db: [
                db.query(models.EvaluatorTarget.id).count() == 2,
                db.query(models.RechargeTarget.id).count() == 1,
            ],
            does_not_raise(),
        ),
        (
            {
                "targets": [
                    create_evaluator_target("evaluator-target-1",
                                            ["data-set-1"]),
                    create_recharge_target("recharge-target-1",
                                           ["qc-data-set-1"]),
                    create_recharge_target("recharge-target-2",
                                           ["qc-data-set-1"]),
                ]
            },
            lambda db: [
                db.query(models.EvaluatorTarget.id).count() == 1,
                db.query(models.RechargeTarget.id).count() == 2,
            ],
            does_not_raise(),
        ),
        # Test invalidly updating a target's training set
        (
            {
                "targets": [invalid_evaluator_target]
            },
            lambda db: [],
            pytest.raises(DataSetNotFoundError),
        ),
        (
            {
                "targets": [invalid_recharge_target]
            },
            lambda db: [],
            pytest.raises(QCDataSetNotFoundError),
        ),
    ]
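
A hedged sketch of how these perturbation tuples might drive a parametrized
update test; the OptimizationCRUD.update call is an assumption standing in for
whatever entry point the real suite exercises:

@pytest.mark.parametrize(
    "perturbation, db_checks, expected_raises",
    optimization_model_perturbations(),
)
def test_update_optimization(perturbation, db_checks, expected_raises, db):

    project, study, optimization, *_ = commit_optimization(db)

    # Apply the perturbed fields to the committed model.
    for field, value in perturbation.items():
        setattr(optimization, field, value)

    with expected_raises:
        OptimizationCRUD.update(db, optimization)  # hypothetical CRUD call
        assert all(db_checks(db))
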
def test_prepare_restart_unfinished(partial_restart, caplog):

    optimization = create_optimization(
        "project-1",
        "study-1",
        "optimization-1",
        [
            create_recharge_target("recharge-target-1", ["qc-data-set-1"]),
            create_recharge_target("recharge-target-2", ["qc-data-set-1"]),
        ],
    )

    with temporary_cd():

        directories = [
            os.path.join("optimize.tmp", "recharge-target-1", "iter_0000"),
            os.path.join("optimize.tmp", "recharge-target-2", "iter_0000"),
            os.path.join("optimize.tmp", "recharge-target-1", "iter_0001"),
            os.path.join("optimize.tmp", "recharge-target-2", "iter_0001"),
        ]

        for index, directory in enumerate(directories):

            os.makedirs(directory)

            expected_files = ["mvals.txt"]

            if index < 3:
                expected_files.append("objective.p")
            if index < (3 if not partial_restart else 4):
                expected_files.append("force-field.offxml")

            for file_name in expected_files:

                with open(os.path.join(directory, file_name), "w") as file:
                    file.write("")

        assert len(glob(os.path.join("optimize.tmp", "recharge-target-1",
                                     "*"))) == 2
        assert len(glob(os.path.join("optimize.tmp", "recharge-target-2",
                                     "*"))) == 2

        with caplog.at_level(logging.INFO):
            _prepare_restart(optimization)

        expected_directories = 2 if partial_restart else 1

        assert (len(
            glob(os.path.join("optimize.tmp", "recharge-target-1",
                              "*"))) == expected_directories)
        assert (len(
            glob(os.path.join("optimize.tmp", "recharge-target-2",
                              "*"))) == expected_directories)

        if not partial_restart:
            assert (
                f"Removing the {directories[2]} directory which was not expected to be "
                f"present") in caplog.text
            assert (
                f"Removing the {directories[3]} directory which was not expected to be "
                f"present") in caplog.text
        else:
            assert "Removing the" not in caplog.text

        assert (
            "1 iterations had previously been completed. The optimization will be "
            f"restarted from iteration {'0000' if not partial_restart else '0001'}"
        ) in caplog.text