コード例 #1
0
ファイル: load.py プロジェクト: facebook/Ax
def _get_trials_sqa(
    experiment_id: int,
    trial_sqa_class: Type[SQATrial],
    load_trials_in_batches_of_size: Optional[int] = None,
    trials_query_options: Optional[List[Any]] = None,
) -> List[SQATrial]:
    """Obtains SQLAlchemy trial objects for given experiment ID from DB,
    optionally in mini-batches and with specified query options.

    Trial DB IDs are fetched first in one short session; the full trial
    objects are then loaded in separate sessions, one per mini-batch of IDs.

    Args:
        experiment_id: DB ID of the experiment whose trials to load.
        trial_sqa_class: SQLAlchemy class to query trials from.
        load_trials_in_batches_of_size: Number of trials to load per session.
            If ``None``, all trials are loaded in a single session.
        trials_query_options: Optional SQLAlchemy query options (e.g. loader
            strategies) applied to every trial query.

    Returns:
        List of SQLAlchemy trial objects; empty if the experiment has none.

    Raises:
        ValueError: If ``load_trials_in_batches_of_size`` is not positive
            (0 would otherwise divide by zero, and a negative value would
            silently load no trials).
    """
    if (
        load_trials_in_batches_of_size is not None
        and load_trials_in_batches_of_size <= 0
    ):
        raise ValueError(
            "`load_trials_in_batches_of_size` must be a positive integer, got "
            f"{load_trials_in_batches_of_size}."
        )

    with session_scope() as session:
        id_rows = (
            session.query(trial_sqa_class.id)
            .filter_by(experiment_id=experiment_id)
            .all()
        )
        # Each row is a 1-tuple containing only the trial's DB ID.
        trial_db_ids = [row[0] for row in id_rows]

    if not trial_db_ids:
        return []

    batch_size = (
        len(trial_db_ids)
        if load_trials_in_batches_of_size is None
        else load_trials_in_batches_of_size
    )

    sqa_trials: List[SQATrial] = []
    for start in range(0, len(trial_db_ids), batch_size):
        mini_batch_db_ids = trial_db_ids[start:start + batch_size]
        # A fresh session per batch keeps each DB round trip short.
        with session_scope() as session:
            query = session.query(trial_sqa_class).filter(
                trial_sqa_class.id.in_(mini_batch_db_ids)  # pyre-ignore[16]
            )
            if trials_query_options is not None:
                query = query.options(*trials_query_options)
            sqa_trials.extend(query.all())

    return sqa_trials
コード例 #2
0
def _save_experiment(experiment: Experiment, encoder: Encoder) -> None:
    """Persist ``experiment`` to the DB via the supplied ``Encoder``.

    Steps:
    1) Look up any experiment already stored under the same name.
    2) Validate the experiment's metadata against that stored record.
    3) Encode the Ax experiment into its SQLAlchemy counterpart.
    4) If a stored record exists, merge the freshly encoded object into it so
       SQLAlchemy performs an update; otherwise the new object is inserted.
    """
    sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        stored = (
            session.query(sqa_class)
            .filter_by(name=experiment.name)
            .one_or_none()
        )

    encoder.validate_experiment_metadata(
        experiment, existing_sqa_experiment=stored
    )
    # Encoding happens outside of any session scope to avoid timeouts.
    encoded = encoder.experiment_to_sqa(experiment)

    if stored is None:
        to_persist = encoded
    else:
        # `stored` is detached from its session but still carries a database
        # identity marker, so the `session.add` below updates the existing row
        # instead of inserting a new one. Merging outside of session scope
        # avoids timeouts.
        stored.update(encoded)
        to_persist = stored

    with session_scope() as session:
        session.add(to_persist)
コード例 #3
0
ファイル: test_sqa_store.py プロジェクト: bitnot/Ax
    def testParameterValidation(self):
        """An SQAParameter must belong to exactly one of an experiment or a
        generator run; adding it in any other state raises ValueError."""

        def _add_in_new_session(obj):
            with session_scope() as session:
                session.add(obj)

        # Owned by neither an experiment nor a generator run: rejected.
        orphan = SQAParameter(
            domain_type=DomainType.RANGE,
            parameter_type=ParameterType.FLOAT,
            name="foobar",
        )
        with self.assertRaises(ValueError):
            _add_in_new_session(orphan)

        # Owned by an experiment only: accepted.
        orphan.experiment_id = 0
        _add_in_new_session(orphan)
        # Owned by an experiment and a generator run at once: rejected.
        with self.assertRaises(ValueError):
            orphan.generator_run_id = 0
            _add_in_new_session(orphan)

        # Owned by a generator run only: accepted; adding an experiment as
        # well: rejected.
        gr_owned = SQAParameter(
            domain_type=DomainType.RANGE,
            parameter_type=ParameterType.FLOAT,
            name="foobar",
            generator_run_id=0,
        )
        _add_in_new_session(gr_owned)
        with self.assertRaises(ValueError):
            gr_owned.experiment_id = 0
            _add_in_new_session(gr_owned)
コード例 #4
0
ファイル: test_sqa_store.py プロジェクト: bitnot/Ax
    def testParameterConstraintValidation(self):
        """An SQAParameterConstraint must belong to exactly one of an
        experiment or a generator run; otherwise adding it raises ValueError."""

        def _add_in_new_session(obj):
            with session_scope() as session:
                session.add(obj)

        # Owned by neither an experiment nor a generator run: rejected.
        constraint = SQAParameterConstraint(
            bound=0, constraint_dict={}, type=ParameterConstraintType.LINEAR)
        with self.assertRaises(ValueError):
            _add_in_new_session(constraint)

        # Owned by an experiment only: accepted.
        constraint.experiment_id = 0
        _add_in_new_session(constraint)
        # Owned by an experiment and a generator run at once: rejected.
        with self.assertRaises(ValueError):
            constraint.generator_run_id = 0
            _add_in_new_session(constraint)

        # Owned by a generator run only: accepted; adding an experiment as
        # well: rejected.
        gr_constraint = SQAParameterConstraint(
            bound=0,
            constraint_dict={},
            type=ParameterConstraintType.LINEAR,
            generator_run_id=0,
        )
        _add_in_new_session(gr_constraint)
        with self.assertRaises(ValueError):
            gr_constraint.experiment_id = 0
            _add_in_new_session(gr_constraint)
コード例 #5
0
ファイル: test_sqa_store.py プロジェクト: bitnot/Ax
    def testMetricValidation(self):
        """An SQAMetric must belong to exactly one of an experiment or a
        generator run; adding it in any other state raises ValueError."""

        def _add_in_new_session(obj):
            with session_scope() as session:
                session.add(obj)

        # Owned by neither an experiment nor a generator run: rejected.
        metric = SQAMetric(
            name="foobar",
            intent=MetricIntent.OBJECTIVE,
            metric_type=METRIC_REGISTRY[BraninMetric],
        )
        with self.assertRaises(ValueError):
            _add_in_new_session(metric)

        # Owned by an experiment only: accepted.
        metric.experiment_id = 0
        _add_in_new_session(metric)
        # Owned by an experiment and a generator run at once: rejected.
        with self.assertRaises(ValueError):
            metric.generator_run_id = 0
            _add_in_new_session(metric)

        # Owned by a generator run only: accepted; adding an experiment as
        # well: rejected.
        gr_metric = SQAMetric(
            name="foobar",
            intent=MetricIntent.OBJECTIVE,
            metric_type=METRIC_REGISTRY[BraninMetric],
            generator_run_id=0,
        )
        _add_in_new_session(gr_metric)
        with self.assertRaises(ValueError):
            gr_metric.experiment_id = 0
            _add_in_new_session(gr_metric)
コード例 #6
0
def _update_trials(experiment: Experiment, trials: List[BaseTrial],
                   encoder: Encoder) -> None:
    """Write updated versions of already-saved ``trials`` to the DB, together
    with any data the experiment holds for them.

    The stored trials are fetched in one session, updated in memory, and then
    written back (along with new data rows) in a second session. Afterwards,
    DB IDs are copied back onto the Ax objects via ``_set_db_ids``.

    Raises:
        ValueError: If one of ``trials`` has no stored counterpart for this
            experiment.
    """
    trial_sqa_class = encoder.config.class_to_sqa_class[Trial]
    wanted_indices = [t.index for t in trials]
    obj_to_sqa = []

    with session_scope() as session:
        experiment_id = _get_experiment_id(
            experiment=experiment, encoder=encoder, session=session
        )
        stored_query = session.query(trial_sqa_class).filter_by(
            experiment_id=experiment_id
        )
        stored_query = stored_query.filter(
            trial_sqa_class.index.in_(wanted_indices)  # pyre-ignore
        )
        stored_trials = stored_query.all()

    stored_by_index = {sqa_trial.index: sqa_trial for sqa_trial in stored_trials}

    updated_sqa_trials = []
    new_sqa_data = []
    for trial in trials:
        stored = stored_by_index.get(trial.index)
        if stored is None:
            raise ValueError(
                f"Trial {trial.index} is not attached to the experiment.")

        encoded_trial, trial_obj_to_sqa = encoder.trial_to_sqa(trial)
        obj_to_sqa.extend(trial_obj_to_sqa)
        stored.update(encoded_trial)
        updated_sqa_trials.append(stored)

        data, timestamp = experiment.lookup_data_for_trial(
            trial_index=trial.index)
        # A timestamp of -1 appears to mean "no data for this trial".
        if timestamp == -1:
            continue
        sqa_data = encoder.data_to_sqa(
            data=data, trial_index=trial.index, timestamp=timestamp
        )
        obj_to_sqa.append((data, sqa_data))
        sqa_data.experiment_id = experiment_id
        new_sqa_data.append(sqa_data)

    with session_scope() as session:
        session.add_all(updated_sqa_trials)
        session.add_all(new_sqa_data)
        session.flush()

    _set_db_ids(obj_to_sqa=obj_to_sqa)
コード例 #7
0
def _load_generation_strategy_by_experiment_name(
    experiment_name: str, decoder: Decoder
) -> GenerationStrategy:
    """Restore the generation strategy attached to the experiment named
    ``experiment_name``, decoding the stored SQA object with ``decoder``.

    Raises:
        ValueError: If the join query finds no generation strategy for that
            experiment name (including when no such experiment exists).
    """
    sqa_classes = decoder.config.class_to_sqa_class
    exp_sqa_class = sqa_classes[Experiment]
    gs_sqa_class = sqa_classes[GenerationStrategy]
    with session_scope() as session:
        query = session.query(gs_sqa_class).join(
            exp_sqa_class.generation_strategy  # pyre-ignore[16]
        )
        # pyre-fixme[16]: `SQABase` has no attribute `name`.
        query = query.filter(exp_sqa_class.name == experiment_name)
        gs_sqa = query.one_or_none()

    if gs_sqa is not None:
        return decoder.generation_strategy_from_sqa(gs_sqa=gs_sqa)
    raise ValueError(
        f"Experiment {experiment_name} does not have a generation strategy "
        "attached to it."
    )
コード例 #8
0
ファイル: load.py プロジェクト: isabella232/Ax
def _get_experiment_sqa_reduced_state(experiment_name: str,
                                      decoder: Decoder) -> SQAExperiment:
    """Load the SQLAlchemy experiment named ``experiment_name`` with a reduced
    footprint, for large experiments whose model-state history is not needed.

    The parameters, parameter constraints and metrics of each trial's
    generator runs, as well as trials' abandoned arms, are lazy-loaded. The
    ``model_kwargs``, ``bridge_kwargs``, ``model_state_after_gen`` and
    ``gen_metadata`` columns of generator runs are deferred.

    Raises:
        ValueError: If no experiment with that name exists.
    """
    exp_sqa_class = cast(Type[SQAExperiment],
                         decoder.config.class_to_sqa_class[Experiment])
    lazy_paths = (
        "trials.generator_runs.parameters",
        "trials.generator_runs.parameter_constraints",
        "trials.generator_runs.metrics",
        "trials.abandoned_arms",
    )
    deferred_gr_columns = (
        "model_kwargs",
        "bridge_kwargs",
        "model_state_after_gen",
        "gen_metadata",
    )
    loader_options = [lazyload(path) for path in lazy_paths]
    loader_options.extend(
        defaultload(exp_sqa_class.trials).defaultload("generator_runs").defer(
            column)
        for column in deferred_gr_columns
    )

    with session_scope() as session:
        sqa_experiment = (
            session.query(exp_sqa_class)
            .filter_by(name=experiment_name)
            .options(*loader_options)
            .one_or_none()
        )
        if sqa_experiment is None:
            raise ValueError(f"Experiment '{experiment_name}' not found.")
    return sqa_experiment
コード例 #9
0
def _load_generation_strategy_by_id(gs_id: int, decoder: Decoder) -> GenerationStrategy:
    """Finds a generation strategy stored by a given ID and restores it.

    The SQLAlchemy class queried is whatever the decoder's config maps
    ``GenerationStrategy`` to, so custom SQA configs are respected (previously
    ``SQAGenerationStrategy`` was hard-coded here).

    Args:
        gs_id: DB ID of the generation strategy to load.
        decoder: Decoder used to pick the SQA class and restore the object.

    Raises:
        ValueError: If no generation strategy with ``gs_id`` exists.
    """
    gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy]
    with session_scope() as session:
        gs_sqa = session.query(gs_sqa_class).filter_by(id=gs_id).one_or_none()
        if gs_sqa is None:
            raise ValueError(f"Generation strategy with ID #{gs_id} not found.")
    return decoder.generation_strategy_from_sqa(gs_sqa=gs_sqa)
コード例 #10
0
ファイル: save.py プロジェクト: Vilashcj/Ax
def _update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    encoder: Encoder,
) -> None:
    """Update generation strategy's current step and attach generator runs.

    Args:
        generation_strategy: Previously saved generation strategy to update.
        generator_runs: Generator runs to encode and attach to it.
        encoder: Encoder used to pick SQA classes and encode generator runs.

    Raises:
        ValueError: If the generation strategy has not been saved yet, or if
            no generation strategy row with its DB ID exists.
    """
    gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]

    gs_id = generation_strategy._db_id
    if gs_id is None:
        raise ValueError(
            "GenerationStrategy must be saved before being updated.")

    with session_scope() as session:
        experiment_id = _get_experiment_id(
            experiment=generation_strategy.experiment,
            encoder=encoder,
            session=session)
        gs_sqa = session.query(gs_sqa_class).get(gs_id)
        if gs_sqa is None:
            # `Query.get` returns None for a missing row; fail clearly instead
            # of with an AttributeError on the assignments below.
            raise ValueError(f"Generation strategy with ID #{gs_id} not found.")
        gs_sqa.curr_index = generation_strategy._curr.index  # pyre-fixme
        gs_sqa.experiment_id = experiment_id  # pyre-ignore

        session.add(gs_sqa)
        for generator_run in generator_runs:
            gr_sqa = encoder.generator_run_to_sqa(generator_run=generator_run)
            gr_sqa.generation_strategy_id = gs_id
            session.add(gr_sqa)
コード例 #11
0
ファイル: save.py プロジェクト: jshuadvd/Ax
def _save_generation_strategy(generation_strategy: GenerationStrategy,
                              encoder: Encoder) -> int:
    """Save ``generation_strategy`` to the DB and return its DB ID.

    A generation strategy without a DB ID is inserted, and the new ID is
    stored on it. One that already has a DB ID is loaded and updated in place
    within the same session.

    Raises:
        ValueError: If ``generation_strategy`` has a DB ID for which no row
            exists.
    """
    # If the generation strategy has not yet generated anything, there will be no
    # experiment set on it.
    if generation_strategy._experiment is None:
        experiment_id = None
    else:
        # Experiment was set on the generation strategy, so need to check whether
        # it has been saved and create a relationship b/w GS and experiment if so.
        experiment_id = _get_experiment_id(
            experiment=generation_strategy._experiment, encoder=encoder)

    gs_sqa = encoder.generation_strategy_to_sqa(
        generation_strategy=generation_strategy, experiment_id=experiment_id)

    with session_scope() as session:
        if generation_strategy._db_id is None:
            session.add(gs_sqa)
            session.flush()  # Ensures generation strategy id is set.
            generation_strategy._db_id = gs_sqa.id
        else:
            gs_sqa_class = encoder.config.class_to_sqa_class[
                GenerationStrategy]
            existing_gs_sqa = session.query(gs_sqa_class).get(
                generation_strategy._db_id)
            if existing_gs_sqa is None:
                # `Query.get` returns None for a missing row.
                raise ValueError(
                    "Generation strategy with ID "
                    f"#{generation_strategy._db_id} not found.")
            existing_gs_sqa.update(gs_sqa)
            # our update logic ignores foreign keys, i.e. fields ending in _id,
            # because we want SQLAlchemy to handle those relationships for us
            # however, generation_strategy.experiment_id is an exception, so we
            # need to update that manually
            existing_gs_sqa.experiment_id = gs_sqa.experiment_id

    return generation_strategy._db_id
コード例 #12
0
ファイル: load.py プロジェクト: facebook/Ax
def _get_experiment_sqa(
    experiment_name: str,
    exp_sqa_class: Type[SQAExperiment],
    trial_sqa_class: Type[SQATrial],
    trials_query_options: Optional[List[Any]] = None,
    load_trials_in_batches_of_size: Optional[int] = None,
) -> SQAExperiment:
    """Load the SQLAlchemy experiment called ``experiment_name``.

    The experiment row is fetched with its ``trials`` relationship set to
    ``noload``. Trials are then fetched separately via ``_get_trials_sqa``
    (optionally batched and with extra query options) and attached to the
    returned object.

    Raises:
        ValueError: If no experiment with that name exists.
    """
    with session_scope() as session:
        sqa_experiment = (
            session.query(exp_sqa_class)
            .filter_by(name=experiment_name)
            .options(noload("trials"))
            .one_or_none()
        )

    if sqa_experiment is None:
        raise ValueError(f"Experiment '{experiment_name}' not found.")

    sqa_experiment.trials = _get_trials_sqa(
        experiment_id=sqa_experiment.id,
        trial_sqa_class=trial_sqa_class,
        trials_query_options=trials_query_options,
        load_trials_in_batches_of_size=load_trials_in_batches_of_size,
    )
    return sqa_experiment
コード例 #13
0
def _save_experiment(
    experiment: Experiment,
    encoder: Encoder,
    return_sqa: bool = False,
    validation_kwargs: Optional[Dict[str, Any]] = None,
) -> Optional[SQABase]:
    """Persist ``experiment`` to the DB via the supplied ``Encoder``.

    Steps:
    1) Look up any experiment already stored under the same name.
    2) Validate the experiment's metadata against that stored record, passing
       through any ``validation_kwargs``.
    3) Encode the Ax experiment into its SQLAlchemy counterpart.
    4) If a stored record exists, merge the freshly encoded object into it so
       SQLAlchemy performs an update; otherwise the new object is inserted.
    5) Copy the resulting DB IDs back onto the Ax objects.

    Returns:
        The persisted SQA experiment if ``return_sqa`` is True, else None.
    """
    exp_sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        stored = (
            session.query(exp_sqa_class)
            .filter_by(name=experiment.name)
            .one_or_none()
        )

    extra_validation_kwargs = validation_kwargs or {}
    encoder.validate_experiment_metadata(
        experiment,
        # pyre-fixme[6]: Expected
        #  `Optional[ax.storage.sqa_store.sqa_classes.SQAExperiment]` for 2nd param but
        #  got `Optional[ax.storage.sqa_store.db.SQABase]`.
        existing_sqa_experiment=stored,
        **extra_validation_kwargs,
    )
    # Encoding happens outside of any session scope to avoid timeouts.
    to_persist, obj_to_sqa = encoder.experiment_to_sqa(experiment)

    if stored is not None:
        # `stored` is detached from its session but still carries a database
        # identity marker, so the `session.add` below updates the existing row
        # instead of inserting a new one. Merging outside of session scope
        # avoids timeouts.
        # pyre-fixme[6]: Expected `SQABase` for 1st param but got `SQAExperiment`.
        stored.update(to_persist)
        to_persist = stored

    with session_scope() as session:
        session.add(to_persist)
        session.flush()

    _set_db_ids(obj_to_sqa=obj_to_sqa)

    if not return_sqa:
        return None
    return checked_cast(SQABase, to_persist)
コード例 #14
0
def _get_experiment_sqa(experiment_name: str) -> SQAExperiment:
    """Fetch the ``SQAExperiment`` row named ``experiment_name``.

    Raises:
        ValueError: If there is no experiment with that name.
    """
    with session_scope() as session:
        query = session.query(SQAExperiment).filter_by(name=experiment_name)
        sqa_experiment = query.one_or_none()
        if sqa_experiment is None:
            raise ValueError(f"Experiment `{experiment_name}` not found.")
    return sqa_experiment
コード例 #15
0
ファイル: save.py プロジェクト: isabella232/Ax
def _bulk_merge_into_session(
    objs: Sequence[Base],
    encode_func: Callable,
    decode_func: Callable,
    encode_args_list: Optional[Union[List[None], List[Dict[str, Any]]]] = None,
    decode_args_list: Optional[Union[List[None], List[Dict[str, Any]]]] = None,
    modify_sqa: Optional[Callable] = None,
    batch_size: Optional[int] = None,
) -> List[SQABase]:
    """Bulk version of _merge_into_session.

    Takes in a list of objects to merge into the session together
    (i.e. within one session scope), along with corresponding (but optional)
    lists of encode and decode arguments.

    If batch_size is specified, the list of objects will be chunked
    accordingly, and multiple session scopes will be used to merge
    the objects in, one batch at a time.

    Args:
        objs: Ax objects to encode and merge.
        encode_func: Converts an Ax object (plus optional kwargs) to SQA.
        decode_func: Converts a merged SQA object (plus optional kwargs) back.
        encode_args_list: Per-object kwargs for ``encode_func``; if given, it
            must have one entry per object.
        decode_args_list: Per-object kwargs for ``decode_func``; if given, it
            must have one entry per object.
        modify_sqa: Optional hook called as ``modify_sqa(sqa=...)`` on each
            encoded object before merging.
        batch_size: Objects per session scope. ``None`` or 0 merges all
            objects in a single session.

    Returns:
        The SQA objects returned by ``session.merge``, in input order.

    Raises:
        ValueError: If ``batch_size`` is negative (which previously merged
            nothing and silently returned an empty list), or if a provided
            args list does not match ``objs`` in length (which previously was
            silently truncated by ``zip``).
    """
    if len(objs) == 0:
        return []

    if batch_size is not None and batch_size < 0:
        raise ValueError(f"`batch_size` must be non-negative, got {batch_size}.")
    for arg_name, args_list in (
        ("encode_args_list", encode_args_list),
        ("decode_args_list", decode_args_list),
    ):
        if args_list and len(args_list) != len(objs):
            raise ValueError(
                f"`{arg_name}` has {len(args_list)} entries but {len(objs)} "
                "objects were given."
            )

    encode_func = _standardize_encode_func(encode_func=encode_func)
    encode_args_list = encode_args_list or [None] * len(objs)
    decode_args_list = decode_args_list or [None] * len(objs)

    sqas = []
    for obj, encode_args in zip(objs, encode_args_list):
        sqa = encode_func(obj, **(encode_args or {}))
        if modify_sqa is not None:
            modify_sqa(sqa=sqa)
        sqas.append(sqa)

    # A missing or zero batch size means "merge everything in one session".
    batch_size = batch_size or len(sqas)
    new_sqas = []
    for start in range(0, len(sqas), batch_size):
        with session_scope() as session:
            for sqa in sqas[start:start + batch_size]:
                new_sqas.append(session.merge(sqa))
            session.flush()

    new_objs = [
        decode_func(new_sqa, **(decode_args or {}))
        for new_sqa, decode_args in zip(new_sqas, decode_args_list)
    ]

    for obj, new_obj in zip(objs, new_objs):
        _copy_db_ids_if_possible(obj=obj, new_obj=new_obj)

    return new_sqas
コード例 #16
0
ファイル: save.py プロジェクト: Vilashcj/Ax
def _save_generation_strategy(generation_strategy: GenerationStrategy,
                              encoder: Encoder) -> int:
    """Save ``generation_strategy`` to the DB and return its DB ID.

    A generation strategy without a DB ID is inserted as a new row. One that
    already has a DB ID is loaded, updated with the freshly encoded state
    outside of session scope, and then added back in a new session. Either
    way, the resulting row's ID is stored on ``generation_strategy``.

    Raises:
        ValueError: If ``generation_strategy`` has a DB ID for which no row
            exists.
    """
    # If the generation strategy has not yet generated anything, there will be no
    # experiment set on it.
    if generation_strategy._experiment is None:
        experiment_id = None
    else:
        # Experiment was set on the generation strategy, so need to check whether
        # it has been saved and create a relationship b/w GS and experiment if so.
        experiment_id = _get_experiment_id(
            # pyre-fixme[6]: Expected `Experiment` for 1st param but got
            #  `Optional[Experiment]`.
            experiment=generation_strategy._experiment,
            encoder=encoder,
        )

    gs_sqa = encoder.generation_strategy_to_sqa(
        generation_strategy=generation_strategy, experiment_id=experiment_id)

    if generation_strategy._db_id is not None:
        gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]
        with session_scope() as session:
            existing_gs_sqa = session.query(gs_sqa_class).get(
                generation_strategy._db_id)
        if existing_gs_sqa is None:
            # `Query.get` returns None for a missing row; fail clearly instead
            # of with an AttributeError on `.update` below.
            raise ValueError(
                "Generation strategy with ID "
                f"#{generation_strategy._db_id} not found.")

        existing_gs_sqa.update(gs_sqa)
        # our update logic ignores foreign keys, i.e. fields ending in _id,
        # because we want SQLAlchemy to handle those relationships for us
        # however, generation_strategy.experiment_id is an exception, so we
        # need to update that manually
        existing_gs_sqa.experiment_id = gs_sqa.experiment_id
        gs_sqa = existing_gs_sqa

    with session_scope() as session:
        session.add(gs_sqa)
        session.flush()  # Ensures generation strategy id is set.

    # pyre-fixme[16]: `None` has no attribute `id`.
    generation_strategy._db_id = gs_sqa.id
    # pyre-fixme[7]: Expected `int` but got `Optional[int]`.
    return generation_strategy._db_id
コード例 #17
0
ファイル: save.py プロジェクト: AdrianaMusic/Ax
def _update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    encoder: Encoder,
) -> None:
    """Update generation strategy's current step and attach generator runs.

    Args:
        generation_strategy: Previously saved generation strategy to update.
        generator_runs: New, not yet saved generator runs to attach to it.
        encoder: Encoder used to pick SQA classes and encode generator runs.

    Raises:
        ValueError: If the generation strategy or its experiment has not been
            saved, if any of ``generator_runs`` has already been saved, or if
            no generation strategy row with the stored DB ID exists.
    """
    gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]

    gs_id = generation_strategy.db_id
    if gs_id is None:
        raise ValueError(
            "GenerationStrategy must be saved before being updated.")

    # Compare against None: a (falsy) DB ID of 0 still means "already saved".
    if any(gr.db_id is not None for gr in generator_runs):
        raise ValueError("Can only save new GeneratorRuns.")

    experiment_id = generation_strategy.experiment.db_id
    if experiment_id is None:
        raise ValueError(  # pragma: no cover
            f"Experiment {generation_strategy.experiment.name} "
            "should be saved before generation strategy.")

    with session_scope() as session:
        num_matched_rows = (
            session.query(gs_sqa_class)
            .filter_by(id=gs_id)
            .update({
                "curr_index": generation_strategy._curr.index,
                "experiment_id": experiment_id,
            })
        )
        if num_matched_rows == 0:
            # Otherwise the generator runs below would be inserted pointing
            # at a generation strategy that does not exist.
            raise ValueError(f"Generation strategy with ID #{gs_id} not found.")

    obj_to_sqa = []
    generator_runs_sqa = []
    for generator_run in generator_runs:
        gr_sqa, gr_obj_to_sqa = encoder.generator_run_to_sqa(
            generator_run=generator_run)
        obj_to_sqa.extend(gr_obj_to_sqa)
        gr_sqa.generation_strategy_id = gs_id
        generator_runs_sqa.append(gr_sqa)

    with session_scope() as session:
        session.add_all(generator_runs_sqa)

    _set_db_ids(obj_to_sqa=obj_to_sqa)
コード例 #18
0
def _get_experiment_id(experiment: Experiment, encoder: Encoder) -> int:
    """Return the DB ID of the stored experiment with ``experiment``'s name.

    Raises:
        ValueError: If no experiment with that name has been saved.
    """
    exp_sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        sqa_experiment = (session.query(exp_sqa_class).filter_by(
            name=experiment.name).one_or_none())
    if sqa_experiment is None:  # pragma: no cover (this is technically unreachable)
        raise ValueError(
            "The underlying experiment must be saved before the generation strategy."
        )
    return sqa_experiment.id
コード例 #19
0
ファイル: save.py プロジェクト: stevemandala/Ax
def _save_experiment(
    experiment: Experiment, encoder: Encoder, overwrite: bool = False
) -> None:
    """Save experiment, using given Encoder instance.

    1) Convert Ax object to SQLAlchemy object.
    2) Determine if there is an existing experiment with that name in the DB.
    3) If not, create a new one.
    4) If so, update the old one.
        The update works by merging the new SQLAlchemy object into the
        existing SQLAlchemy object, and then letting SQLAlchemy handle the
        actual DB updates.

    Args:
        experiment: Experiment to save.
        encoder: Encoder used to convert the experiment to SQA.
        overwrite: Whether to replace a stored experiment of the same name
            whose ``time_created`` differs from this one's.

    Raises:
        ValueError: If an experiment with the same name but a different
            ``time_created`` is stored and ``overwrite`` is falsy. (This was
            previously a bare ``Exception``; ``ValueError`` subclasses it, so
            existing ``except Exception`` handlers still catch it.)
    """
    # Convert user-facing class to SQA outside of session scope to avoid timeouts
    new_sqa_experiment = encoder.experiment_to_sqa(experiment)
    exp_sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        existing_sqa_experiment = (
            session.query(exp_sqa_class).filter_by(name=experiment.name).one_or_none()
        )

    if existing_sqa_experiment is not None:
        is_same_experiment = datetime_equals(
            existing_sqa_experiment.time_created, new_sqa_experiment.time_created
        )
        if not is_same_experiment and not overwrite:
            raise ValueError(
                "An experiment already exists with the name "
                f"{new_sqa_experiment.name}. To overwrite, specify "
                "`overwrite = True` when calling `save_experiment`."
            )

        # Update the SQA object outside of session scope to avoid timeouts.
        # This object is detached from the session, but contains a database
        # identity marker, so when we do `session.add` below, SQA knows to
        # perform an update rather than an insert.
        existing_sqa_experiment.update(new_sqa_experiment)
        new_sqa_experiment = existing_sqa_experiment

    with session_scope() as session:
        session.add(new_sqa_experiment)
コード例 #20
0
ファイル: load.py プロジェクト: viotemp1/Ax
def _get_generation_strategy_sqa(
    gs_id: int, decoder: Decoder, reduced_state: bool = False
) -> SQAGenerationStrategy:
    """Obtains the SQLAlchemy generation strategy object with ID ``gs_id``.

    Args:
        gs_id: DB ID of the generation strategy.
        decoder: Decoder whose config supplies the SQA classes to query.
        reduced_state: If True, the parameters, parameter constraints and
            metrics of the generator runs are lazy-loaded, and their
            ``model_kwargs``, ``bridge_kwargs``, ``model_state_after_gen`` and
            ``gen_metadata`` columns are deferred. The last generator run is
            then re-fetched without those options, so that the state needed
            to restore the generation strategy is available.

    Returns:
        The SQLAlchemy generation strategy object.

    Raises:
        ValueError: If no generation strategy with ``gs_id`` exists.
    """
    gs_sqa_class = cast(
        Type[SQAGenerationStrategy],
        decoder.config.class_to_sqa_class[GenerationStrategy],
    )
    gr_sqa_class = cast(
        Type[SQAGeneratorRun],
        decoder.config.class_to_sqa_class[GeneratorRun],
    )
    with session_scope() as session:
        query = session.query(gs_sqa_class).filter_by(id=gs_id)
        if reduced_state:
            query = query.options(
                lazyload("generator_runs.parameters"),
                lazyload("generator_runs.parameter_constraints"),
                lazyload("generator_runs.metrics"),
                defaultload(gs_sqa_class.generator_runs).defer("model_kwargs"),
                defaultload(gs_sqa_class.generator_runs).defer("bridge_kwargs"),
                defaultload(gs_sqa_class.generator_runs).defer("model_state_after_gen"),
                defaultload(gs_sqa_class.generator_runs).defer("gen_metadata"),
            )
        gs_sqa = query.one_or_none()
    if gs_sqa is None:
        raise ValueError(f"Generation strategy with ID #{gs_id} not found.")

    # Load full last generator run (including model state), for generation
    # strategy restoration, if loading reduced state.
    if reduced_state and gs_sqa.generator_runs:
        last_generator_run_id = gs_sqa.generator_runs[-1].id
        with session_scope() as session:
            last_gr_sqa = (
                session.query(gr_sqa_class)
                .filter_by(id=last_generator_run_id)
                .one_or_none()
            )
        # Swap the reduced last generator run for the full one. If the row
        # vanished between the two queries, keep the reduced run rather than
        # putting None into the list of generator runs.
        if last_gr_sqa is not None:
            gs_sqa.generator_runs[-1] = last_gr_sqa

    return gs_sqa
コード例 #21
0
def delete_experiment(exp_name: str) -> None:
    """Delete experiment by name.

    Args:
        exp_name: Name of the experiment to delete.

    Raises:
        ValueError: If no experiment with that name exists.
    """
    with session_scope() as session:
        exp = session.query(SQAExperiment).filter_by(name=exp_name).one_or_none()
        if exp is None:
            # Passing None to `session.delete` would raise an unhelpful
            # SQLAlchemy error instead of saying what is missing.
            raise ValueError(f"Experiment `{exp_name}` not found.")
        session.delete(exp)
        session.flush()
コード例 #22
0
def _get_experiment_sqa(experiment_name: str, decoder: Decoder) -> SQAExperiment:
    """Fetch the SQLAlchemy experiment named ``experiment_name``.

    The SQA class queried is the one that ``decoder``'s config maps
    ``Experiment`` to.

    Raises:
        ValueError: If there is no experiment with that name.
    """
    sqa_class = decoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        query = session.query(sqa_class).filter_by(name=experiment_name)
        found = query.one_or_none()
        if found is None:
            raise ValueError(f"Experiment `{experiment_name}` not found.")
    return found  # pyre-ignore[7]
コード例 #23
0
def _save_generation_strategy(generation_strategy: GenerationStrategy,
                              encoder: Encoder) -> int:
    """Save ``generation_strategy`` to the DB and return its DB ID.

    A generation strategy without a DB ID is inserted as a new row. One that
    already has a DB ID is loaded, updated with the freshly encoded state
    outside of session scope, and then added back in a new session. DB IDs
    are then copied back onto the Ax objects via ``_set_db_ids``.

    Returns:
        ``generation_strategy.db_id`` after saving.

    Raises:
        ValueError: If the generation strategy's experiment has not been
            saved, or if the generation strategy has a DB ID for which no row
            exists.
    """
    # If the generation strategy has not yet generated anything, there will be no
    # experiment set on it.
    experiment = generation_strategy._experiment
    if experiment is None:
        experiment_id = None
    else:
        # Experiment was set on the generation strategy, so need to check whether
        # it has been saved and create a relationship b/w GS and experiment if so.
        experiment_id = experiment.db_id
        if experiment_id is None:
            raise ValueError(  # pragma: no cover
                f"Experiment {experiment.name} should be saved before "
                "generation strategy.")

    gs_sqa, obj_to_sqa = encoder.generation_strategy_to_sqa(
        generation_strategy=generation_strategy, experiment_id=experiment_id)

    if generation_strategy._db_id is not None:
        gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]
        with session_scope() as session:
            existing_gs_sqa = session.query(gs_sqa_class).get(
                generation_strategy._db_id)
        if existing_gs_sqa is None:
            # `Query.get` returns None for a missing row; fail clearly instead
            # of with an AttributeError on `.update` below.
            raise ValueError(
                "Generation strategy with ID "
                f"#{generation_strategy._db_id} not found.")

        existing_gs_sqa.update(gs_sqa)
        # our update logic ignores foreign keys, i.e. fields ending in _id,
        # because we want SQLAlchemy to handle those relationships for us
        # however, generation_strategy.experiment_id is an exception, so we
        # need to update that manually
        existing_gs_sqa.experiment_id = gs_sqa.experiment_id
        gs_sqa = existing_gs_sqa

    with session_scope() as session:
        session.add(gs_sqa)
        session.flush()  # Ensures generation strategy id is set.

    _set_db_ids(obj_to_sqa=obj_to_sqa)

    return not_none(generation_strategy.db_id)
コード例 #24
0
ファイル: save.py プロジェクト: stevemandala/Ax
def _save_new_trial(experiment: Experiment, trial: BaseTrial, encoder: Encoder) -> None:
    """Append a brand-new ``trial`` to the stored version of ``experiment``.

    Raises:
        ValueError: If the experiment has not been saved yet, or if a trial
            with the same index is already stored for it.
    """
    exp_sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        stored_experiment = (
            session.query(exp_sqa_class)
            .filter_by(name=experiment.name)
            .one_or_none()
        )

    if stored_experiment is None:
        raise ValueError("Must save experiment before adding a new trial.")

    # Named `sqa_trial` so it does not read as the `trial` argument.
    stored_indices = {sqa_trial.index for sqa_trial in stored_experiment.trials}
    if trial.index in stored_indices:
        raise ValueError(f"Trial {trial.index} already attached to experiment.")

    encoded_trial = encoder.trial_to_sqa(trial)

    with session_scope() as session:
        stored_experiment.trials.append(encoded_trial)
        session.add(stored_experiment)
コード例 #25
0
ファイル: save.py プロジェクト: linusec/Ax
def _update_trial(experiment: Experiment, trial: BaseTrial,
                  encoder: Encoder) -> None:
    """Update trial and attach data, using given Encoder instance.

    The experiment is looked up in the DB by name. The stored SQA trial with
    the same index is updated from a freshly encoded copy of ``trial``. If the
    experiment has data for this trial, that data is encoded and appended to
    the experiment's stored data.

    Args:
        experiment: Experiment the trial belongs to; must already be saved.
        trial: Trial whose stored record should be refreshed.
        encoder: Encoder used to convert the trial and its data to SQA objects.

    Raises:
        ValueError: If the experiment has not been saved, or if no trial with
            ``trial.index`` is attached to it in the DB.
    """
    exp_sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        existing_sqa_experiment = (session.query(exp_sqa_class).filter_by(
            name=experiment.name).one_or_none())

    if existing_sqa_experiment is None:
        raise ValueError("Must save experiment before updating a trial.")

    # Collect the stored trial(s) with this index in a single pass; the
    # distinct loop variable avoids shadowing the ``trial`` argument.
    existing_sqa_trials = [
        sqa_trial for sqa_trial in existing_sqa_experiment.trials
        if sqa_trial.index == trial.index
    ]
    if not existing_sqa_trials:
        raise ValueError(
            f"Trial {trial.index} is not attached to the experiment.")

    # There should only be one existing trial with the same index
    assert len(existing_sqa_trials) == 1
    existing_sqa_trial = existing_sqa_trials[0]

    # Merge the freshly encoded trial into the stored SQA trial via its
    # ``update`` method, then persist it.
    new_sqa_trial = encoder.trial_to_sqa(trial)
    existing_sqa_trial.update(new_sqa_trial)

    with session_scope() as session:
        session.add(existing_sqa_trial)

    # NOTE(review): ts == -1 appears to be lookup_data_for_trial's "no data
    # for this trial" sentinel; data is only attached otherwise — confirm.
    data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
    if ts != -1:
        sqa_data = encoder.data_to_sqa(data=data,
                                       trial_index=trial.index,
                                       timestamp=ts)
        with session_scope() as session:
            existing_sqa_experiment.data.append(sqa_data)
            session.add(existing_sqa_experiment)
コード例 #26
0
ファイル: decoder.py プロジェクト: Balandat/Ax
def _get_scalarized_objective_children_metrics(
        metric_id: int, decoder: Decoder) -> List[SQAMetric]:
    """Load the SQA metrics belonging to a scalarized objective.

    Args:
        metric_id: DB id of the parent metric. Children are the rows whose
            ``scalarized_objective_id`` equals this id.
        decoder: Decoder whose config maps ``Metric`` to its SQA class.

    Returns:
        The child SQA metric objects returned by the query.
    """
    sqa_metric_cls = cast(Type[SQAMetric],
                          decoder.config.class_to_sqa_class[Metric])
    with session_scope() as session:
        children = (
            session.query(sqa_metric_cls)
            .filter_by(scalarized_objective_id=metric_id)
            .all()
        )
    return children
コード例 #27
0
    def testRunnerValidation(self):
        def persist(obj):
            # Add ``obj`` inside its own session scope.
            with session_scope() as session:
                session.add(obj)

        # A runner with neither experiment_id nor trial_id is expected to
        # raise ValueError when persisted.
        runner = SQARunner(runner_type=RUNNER_REGISTRY[SyntheticRunner])
        with self.assertRaises(ValueError):
            persist(runner)

        # Only experiment_id set: persists fine; also setting trial_id is
        # expected to raise ValueError.
        runner.experiment_id = 0
        persist(runner)
        with self.assertRaises(ValueError):
            runner.trial_id = 0
            persist(runner)

        # Only trial_id set: persists fine; also setting experiment_id is
        # expected to raise ValueError.
        runner = SQARunner(runner_type=RUNNER_REGISTRY[SyntheticRunner], trial_id=0)
        persist(runner)
        with self.assertRaises(ValueError):
            runner.experiment_id = 0
            persist(runner)
コード例 #28
0
def _update_generation_strategy(
    generation_strategy: GenerationStrategy,
    generator_runs: List[GeneratorRun],
    encoder: Encoder,
    decoder: Decoder,
    batch_size: Optional[int] = None,
    reduce_state_generator_runs: bool = False,
) -> None:
    """Persist the generation strategy's current step and new generator runs.

    The stored generation strategy row gets its ``curr_index`` and
    ``experiment_id`` columns rewritten. Then each generator run is encoded,
    tagged with the strategy's DB id, and bulk-merged into the session.

    Args:
        generation_strategy: Already-saved strategy to update.
        generator_runs: Generator runs to attach to the strategy.
        encoder: Encoder for converting generator runs to SQA objects.
        decoder: Decoder passed through for converting them back.
        batch_size: Optional batch size forwarded to the bulk merge.
        reduce_state_generator_runs: Forwarded to the encoder as
            ``reduced_state``.

    Raises:
        ValueError: If the strategy or its experiment has no DB id yet.
    """
    gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]

    gs_id = generation_strategy.db_id
    if gs_id is None:
        raise ValueError(
            "GenerationStrategy must be saved before being updated.")

    experiment_id = generation_strategy.experiment.db_id
    if experiment_id is None:
        raise ValueError(  # pragma: no cover
            f"Experiment {generation_strategy.experiment.name} "
            "should be saved before generation strategy.")

    with session_scope() as session:
        gs_query = session.query(gs_sqa_class).filter_by(id=gs_id)
        gs_query.update({
            "curr_index": generation_strategy._curr.index,
            "experiment_id": experiment_id,
        })

    def tag_with_gs_id(sqa: SQAGeneratorRun):
        # Link each merged generator run back to this strategy.
        sqa.generation_strategy_id = gs_id

    def encode_generator_run(gr: GeneratorRun,
                             weight: Optional[float] = None):
        return encoder.generator_run_to_sqa(
            gr, weight=weight, reduced_state=reduce_state_generator_runs)

    # One independent kwargs dict per generator run.
    per_run_decode_args = [
        {
            "reduced_state": False,
            "immutable_search_space_and_opt_config": False,
        }
        for _ in generator_runs
    ]

    _bulk_merge_into_session(
        objs=generator_runs,
        encode_func=encode_generator_run,
        decode_func=decoder.generator_run_from_sqa,
        decode_args_list=per_run_decode_args,
        modify_sqa=tag_with_gs_id,
        batch_size=batch_size,
    )
コード例 #29
0
ファイル: load.py プロジェクト: zorrock/Ax
def _load_experiment(experiment_name: str, decoder: Decoder) -> Experiment:
    """Load the experiment called ``experiment_name`` via ``decoder``.

    The SQAExperiment row is fetched by name and decoded into an Ax object.
    Decoding happens while the session is still open.

    Raises:
        ValueError: If no experiment with that name is stored.
    """
    with session_scope() as session:
        query = session.query(SQAExperiment).filter_by(name=experiment_name)
        sqa_experiment = query.one_or_none()
        if sqa_experiment is not None:
            return decoder.experiment_from_sqa(sqa_experiment)
        raise ValueError(f"Experiment `{experiment_name}` not found.")
コード例 #30
0
ファイル: load.py プロジェクト: isabella232/Ax
def _get_experiment_id(experiment_name: str,
                       config: SQAConfig) -> Optional[int]:
    """Return the DB id of the experiment named ``experiment_name``.

    Only the id column is queried. Returns None when no experiment with that
    name is stored.
    """
    exp_sqa_class = config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        id_query = session.query(exp_sqa_class.id).filter_by(  # pyre-ignore
            name=experiment_name)
        row = id_query.one_or_none()
    return None if row is None else row[0]