コード例 #1
0
    def test_save_updated_trial(self):
        """A saved CANDIDATE trial that is later marked RUNNING is updated in DB.

        The trial is created on a freshly loaded copy of the experiment
        (`exp`), saved, marked running, and re-saved via
        `_save_or_update_trial_in_db_if_possible`. Reloading the experiment
        must then yield exactly one trial, with RUNNING status.
        """
        experiment, _ = self.init_experiment_and_generation_strategy(
            save_generation_strategy=False
        )
        db_settings = self.with_db_settings.db_settings

        exp = _load_experiment(experiment.name, decoder=db_settings.decoder)
        trial = exp.new_trial()
        # Save the trial against the experiment object that actually owns it
        # (`exp`), the same object passed to the update call below.
        _save_or_update_trials(
            experiment=exp,
            trials=[trial],
            encoder=db_settings.encoder,
            decoder=db_settings.decoder,
        )
        self.assertEqual(trial.status, TrialStatus.CANDIDATE)

        trial.mark_running(no_runner_required=True)
        saved = self.with_db_settings._save_or_update_trial_in_db_if_possible(
            exp, trial
        )
        self.assertTrue(saved)
        exp = _load_experiment(experiment.name, decoder=db_settings.decoder)
        self.assertEqual(len(exp.trials), 1)
        self.assertEqual(exp.trials[0].status, TrialStatus.RUNNING)
コード例 #2
0
    def test_save_new_trial(self):
        """A brand-new trial saved through the mixin is persisted as CANDIDATE."""
        experiment, _ = self.init_experiment_and_generation_strategy(
            save_generation_strategy=False
        )
        decoder = self.with_db_settings.db_settings.decoder

        reloaded = _load_experiment(experiment.name, decoder=decoder)
        new_trial = reloaded.new_trial()
        self.assertTrue(
            self.with_db_settings._save_new_trial_to_db_if_possible(
                reloaded, new_trial
            )
        )

        # Reload to observe what was actually written to the DB.
        reloaded = _load_experiment(experiment.name, decoder=decoder)
        self.assertEqual(len(reloaded.trials), 1)
        self.assertEqual(reloaded.trials[0].status, TrialStatus.CANDIDATE)
コード例 #3
0
 def test_save_experiment(self):
     """An experiment saved through the mixin round-trips through the DB."""
     experiment = self.get_random_experiment()
     self.assertTrue(
         self.with_db_settings._save_experiment_to_db_if_possible(experiment)
     )
     decoder = self.with_db_settings.db_settings.decoder
     round_tripped = _load_experiment(experiment.name, decoder=decoder)
     self.assertIsNotNone(round_tripped)
     self.assertEqual(experiment, round_tripped)
コード例 #4
0
 def test_update_experiment_properties_in_db(self):
     """Properties updated in memory are persisted by the properties update."""
     experiment, _ = self.init_experiment_and_generation_strategy(
         save_generation_strategy=False
     )
     experiment._properties["test_property"] = True
     self.with_db_settings._update_experiment_properties_in_db(
         experiment_with_updated_properties=experiment
     )
     expected_properties = {"test_property": True}
     reloaded = _load_experiment(
         experiment.name, decoder=self.with_db_settings.db_settings.decoder
     )
     self.assertEqual(reloaded._properties, expected_properties)
コード例 #5
0
    def test_updated_trials_mini_batch(self):
        """Saving/updating trials in mini-batches persists every trial.

        First a single new trial is saved. Then five new trials plus the first
        trial (now marked RUNNING) are saved together. The number of trials is
        chosen so it is not a multiple of the mini-batch size. Persisted
        statuses are checked on the reloaded experiment.
        """
        experiment, _ = self.init_experiment_and_generation_strategy(
            save_generation_strategy=False
        )
        decoder = self.with_db_settings.db_settings.decoder

        # Check with 1 trial.
        trial = experiment.new_trial()
        self.assertIsNone(trial.db_id)
        self.with_db_settings._save_or_update_trials_in_db_if_possible(
            experiment=experiment,
            trials=[trial],
        )
        loaded_experiment = _load_experiment(experiment.name, decoder=decoder)
        self.assertEqual(
            loaded_experiment.trials.get(trial.index).status, TrialStatus.CANDIDATE
        )
        self.assertIsNotNone(trial.db_id)

        # Check with multiple trials, where their number % mini batch size is not 0.
        trials = [experiment.new_trial() for _ in range(5)]
        for t in trials:
            self.assertIsNone(t.db_id)

        trial.mark_running(no_runner_required=True)
        trials.append(trial)

        self.with_db_settings._save_or_update_trials_in_db_if_possible(
            experiment=experiment,
            trials=trials,
        )
        loaded_experiment = _load_experiment(experiment.name, decoder=decoder)
        # Verify against the *reloaded* experiment. The in-memory `trials`
        # would pass trivially and prove nothing about what was persisted.
        # All trials except the one marked as running should be candidates.
        self.assertEqual(len(loaded_experiment.trials), len(trials))
        for t in trials:
            self.assertIsNotNone(t.db_id)
            if t.index == trial.index:
                expected_status = TrialStatus.RUNNING
            else:
                expected_status = TrialStatus.CANDIDATE
            loaded_trial = loaded_experiment.trials.get(t.index)
            self.assertIsNotNone(loaded_trial)
            self.assertEqual(loaded_trial.status, expected_status)
コード例 #6
0
ファイル: storage.py プロジェクト: yanpei18345156216/Ax
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """
    Load the experiment called ``name`` from the db for the Service API.

    The Service API only works with plain `Experiment` objects. Any other
    loaded object, including a simple experiment, is rejected.

    Args:
        name: Experiment name.
        db_settings: Defines behavior for loading/saving experiment to/from db.

    Returns:
        ax.core.Experiment: Loaded experiment.

    Raises:
        ValueError: If the loaded object is not an `Experiment`, or is a
            simple experiment.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    loaded = _load_experiment(name, decoder=db_settings.decoder)
    is_supported = isinstance(loaded, Experiment) and not loaded.is_simple_experiment
    if not is_supported:
        raise ValueError("Service API only supports Experiment")
    return loaded
コード例 #7
0
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """
    Fetch a stored experiment by name for use with the Service API.

    Args:
        name: Experiment name.
        db_settings: Specifies decoder and xdb tier where experiment is stored.

    Returns:
        ax.core.Experiment: Loaded experiment.

    Raises:
        ValueError: If the stored object is not an `Experiment`.
    """
    initialize_db(db_settings)
    loaded = _load_experiment(name, db_settings.decoder)
    if isinstance(loaded, Experiment):
        return loaded
    raise ValueError("Service API only supports Experiment")
コード例 #8
0
ファイル: with_db_settings_base.py プロジェクト: mengwa41/Ax
    def _load_experiment_and_generation_strategy(
        self, experiment_name: str
    ) -> Tuple[Optional[Experiment], Optional[GenerationStrategy]]:
        """Loads experiment and its corresponding generation strategy from database
        if DB settings are set on this `WithDBSettingsBase` instance.

        Args:
            experiment_name: Name of the experiment to load, used as unique
                identifier by which to find the experiment.

        Returns:
            - Tuple of `Experiment` and `None` if experiment exists but does not
              have a generation strategy attached to it.
            - Tuple of `Experiment` and `GenerationStrategy` if experiment exists
              and has a generation strategy attached to it.

            NOTE: the return annotation marks the experiment as `Optional`, but
            this method never returns `(None, None)`. Any exception
            `_load_experiment` raises (e.g. for an unknown name) propagates
            uncaught. A non-`Experiment` result (such as `None`) is rejected
            with `ValueError` below.

        Raises:
            ValueError: If DB settings are not set; if the loaded object is not
                an `Experiment` or is a simple experiment; or if loading the
                generation strategy raises a `ValueError` for any reason other
                than the experiment not having one.
        """
        if not self.db_settings_set:
            raise ValueError("Cannot load from DB in absence of DB settings.")

        start_time = time.time()
        experiment = _load_experiment(
            experiment_name, decoder=self.db_settings.decoder
        )
        if not isinstance(experiment, Experiment) or experiment.is_simple_experiment:
            raise ValueError("Service API only supports `Experiment`.")
        logger.debug(
            f"Loaded experiment {experiment_name} in "
            f"{_round_floats_for_logging(time.time() - start_time)} seconds."
        )

        start_time = time.time()
        # Keep the `try` body to the single call whose `ValueError` we classify.
        try:
            generation_strategy = _load_generation_strategy_by_experiment_name(
                experiment_name=experiment_name, decoder=self.db_settings.decoder
            )
        except ValueError as err:
            # A missing generation strategy is an expected outcome. Any other
            # `ValueError` signals a real problem and is re-raised.
            if "does not have a generation strategy" in str(err):
                return experiment, None
            raise
        logger.debug(
            f"Loaded generation strategy for experiment {experiment_name} in "
            f"{_round_floats_for_logging(time.time() - start_time)} seconds."
        )
        return experiment, generation_strategy
コード例 #9
0
    def test_update_reduced_state_generator_runs(self):
        """Reduced-state saving keeps large model attributes only on the last GR.

        Five trials, each with a freshly generated generator run, are saved
        with `reduce_state_generator_runs=True`. On both the reloaded
        experiment and the reloaded generation strategy, only the most recent
        generator run should retain its large model attributes.
        """
        experiment, generation_strategy = self.init_experiment_and_generation_strategy(
            save_generation_strategy=True
        )

        trials = [experiment.new_trial() for _ in range(5)]
        grs = []
        for t in trials:
            gr = generation_strategy.gen(experiment)
            grs.append(gr)
            t.add_generator_run(gr)

        self.with_db_settings._save_or_update_trials_and_generation_strategy_if_possible(  # noqa: E501
            experiment=experiment,
            trials=trials,
            generation_strategy=generation_strategy,
            new_generator_runs=grs,
            reduce_state_generator_runs=True,
        )

        decoder = self.with_db_settings.db_settings.decoder
        # Private attribute names holding large model state on a generator run.
        # Built once instead of inside every inner loop.
        large_attr_names = [f"_{attr.key}" for attr in GR_LARGE_MODEL_ATTRS]

        loaded_experiment = _load_experiment(experiment.name, decoder=decoder)
        # Guard against a vacuous pass if no trials were persisted.
        self.assertEqual(len(loaded_experiment.trials), len(trials))

        # Only the last trial's generator run should have large model attributes
        last_trial_idx = len(loaded_experiment.trials) - 1
        for idx, trial in loaded_experiment.trials.items():
            for key in large_attr_names:
                if idx < last_trial_idx:
                    self.assertIsNone(getattr(trial.generator_run, key))
                else:
                    self.assertIsNotNone(getattr(trial.generator_run, key))

        loaded_generation_strategy = _load_generation_strategy_by_experiment_name(
            experiment.name, decoder=decoder
        )
        loaded_gr_runs = loaded_generation_strategy._generator_runs
        # Every generator run produced above should be on the reloaded strategy.
        self.assertGreaterEqual(len(loaded_gr_runs), len(grs))

        # Only the last generator run should have large model attributes
        last_gr_idx = len(loaded_gr_runs) - 1
        for idx, gr in enumerate(loaded_gr_runs):
            for key in large_attr_names:
                if idx < last_gr_idx:
                    self.assertIsNone(getattr(gr, key))
                else:
                    self.assertIsNotNone(getattr(gr, key))
コード例 #10
0
def load_experiment(name: str, db_settings: DBSettings) -> Experiment:
    """
    Load an experiment by name from the db, logging how long the load took.

    Only plain `Experiment` objects are accepted by the Service API.

    Args:
        name: Experiment name.
        db_settings: Defines behavior for loading/saving experiment to/from db.

    Returns:
        ax.core.Experiment: Loaded experiment.

    Raises:
        ValueError: If the loaded object is not an `Experiment`, or is a
            simple experiment.
    """
    init_engine_and_session_factory(creator=db_settings.creator, url=db_settings.url)
    started_at = time.time()
    loaded = _load_experiment(name, decoder=db_settings.decoder)
    if isinstance(loaded, Experiment) and not loaded.is_simple_experiment:
        elapsed = _round_floats_for_logging(time.time() - started_at)
        logger.debug(f"Loaded experiment {name} in {elapsed} seconds.")
        return loaded
    raise ValueError("Service API only supports `Experiment`.")
コード例 #11
0
    def _load_experiment_and_generation_strategy(
        self,
        experiment_name: str,
        reduced_state: bool = False,
    ) -> Tuple[Optional[Experiment], Optional[GenerationStrategy]]:
        """Loads experiment and its corresponding generation strategy from database
        if DB settings are set on this `WithDBSettingsBase` instance.

        Args:
            experiment_name: Name of the experiment to load, used as unique
                identifier by which to find the experiment.
            reduced_state: Whether to load experiment and generation strategy
                with a slightly reduced state (without abandoned arms on experiment
                and model state on each generator run in experiment and generation
                strategy; last generator run on generation strategy will still
                have its model state).

        Returns:
            - Tuple of `Experiment` and `None` if experiment exists but does not
              have a generation strategy attached to it.
            - Tuple of `Experiment` and `GenerationStrategy` if experiment exists
              and has a generation strategy attached to it.

            NOTE: the return annotation marks the experiment as `Optional`, but
            this method never returns `(None, None)`. Any exception
            `_load_experiment` raises (e.g. for an unknown name) propagates
            uncaught. A non-`Experiment` result (such as `None`) is rejected
            with `ValueError` below.

        Raises:
            ValueError: If DB settings are not set; if the loaded object is not
                an `Experiment`; or if loading the generation strategy raises a
                `ValueError` for any reason other than the experiment not
                having one.
        """
        if not self.db_settings_set:
            raise ValueError("Cannot load from DB in absence of DB settings.")

        logger.info(
            "Loading experiment and generation strategy (with reduced state: "
            f"{reduced_state})..."
        )
        start_time = time.time()
        experiment = _load_experiment(
            experiment_name,
            decoder=self.db_settings.decoder,
            reduced_state=reduced_state,
            load_trials_in_batches_of_size=LOADING_MINI_BATCH_SIZE,
        )
        if not isinstance(experiment, Experiment):
            raise ValueError("Service API only supports `Experiment`.")
        logger.info(
            f"Loaded experiment {experiment_name} in "
            f"{_round_floats_for_logging(time.time() - start_time)} seconds, "
            f"loading trials in mini-batches of {LOADING_MINI_BATCH_SIZE}."
        )

        start_time = time.time()
        # Keep the `try` body to the single call whose `ValueError` we classify.
        try:
            generation_strategy = _load_generation_strategy_by_experiment_name(
                experiment_name=experiment_name,
                decoder=self.db_settings.decoder,
                experiment=experiment,
                reduced_state=reduced_state,
            )
        except ValueError as err:
            # A missing generation strategy is an expected outcome. Any other
            # `ValueError` signals a real problem and is re-raised.
            if "does not have a generation strategy" in str(err):
                return experiment, None
            raise
        logger.info(
            f"Loaded generation strategy for experiment {experiment_name} in "
            f"{_round_floats_for_logging(time.time() - start_time)} seconds."
        )
        return experiment, generation_strategy