Example #1
0
    def testExperimentObjectiveUpdates(self):
        """Check that persisted metric rows track objective changes.

        After an initial save, the experiment is re-saved twice:
        once after swapping in a minimizing copy of the stock objective
        (expected to be an in-place update), and once after replacing the
        objective with a brand-new ``Metric`` named "objective". After every
        save the number of ``SQAMetric`` rows in the session must equal
        ``len(experiment.metrics)``. Finally, the experiment is reloaded by
        name and compared for equality with the in-memory object.
        """

        def check_metric_row_count():
            # Persisted metric rows must stay in step with in-memory metrics.
            persisted_rows = get_session().query(SQAMetric).count()
            self.assertEqual(persisted_rows, len(experiment.metrics))

        experiment = get_experiment_with_batch_trial()
        save_experiment(experiment)
        check_metric_row_count()

        # Step 1: same objective metric, now minimized. The original intent
        # is that this is stored as an update in place.
        opt_config = get_optimization_config()
        minimized_objective = get_objective()
        minimized_objective.minimize = True
        opt_config.objective = minimized_objective
        experiment.optimization_config = opt_config
        save_experiment(experiment)
        check_metric_row_count()

        # Step 2: swap in an entirely new objective metric. The original
        # intent is that the previous objective metric is retained as a
        # tracking metric; only the row count is asserted here.
        replacement_objective = Objective(metric=Metric(name="objective"))
        opt_config.objective = replacement_objective
        experiment.optimization_config = opt_config
        save_experiment(experiment)
        check_metric_row_count()

        # Round-trip: the reloaded experiment must equal the in-memory one.
        self.assertEqual(experiment, load_experiment(experiment.name))
Example #2
0
from unittest.mock import patch

import pandas as pd
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.data import Data
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_arms, get_experiment, get_objective


# Single-row fixture: arm "0_0" in trial 0 reporting mean=1.0 and sem=2.0
# for the metric that backs the objective returned by ``get_objective()``.
TEST_DATA = Data(
    df=pd.DataFrame(
        [
            dict(
                arm_name="0_0",
                metric_name=get_objective().metric.name,
                mean=1.0,
                sem=2.0,
                trial_index=0,
            )
        ]
    )
)


class TrialTest(TestCase):
    def setUp(self):
        """Create an experiment holding one trial that contains a single arm.

        Sets on the test instance:
          * ``self.experiment`` -- the experiment returned by ``get_experiment()``.
          * ``self.trial`` -- a new trial created via ``self.experiment.new_trial()``.
          * ``self.arm`` -- the first arm from ``get_arms()``; it is added to
            ``self.trial``.
        """
        self.experiment = get_experiment()
        self.trial = self.experiment.new_trial()
        self.arm = get_arms()[0]
        # The trial must exist before the arm can be attached to it.
        self.trial.add_arm(self.arm)