Exemplo n.º 1
0
def _update_trials(experiment: Experiment, trials: List[BaseTrial],
                   encoder: Encoder) -> None:
    """Refresh the stored rows for ``trials`` and save their looked-up data.

    Stored trial rows are fetched by experiment id and trial index, then each
    one is updated from a freshly encoded copy of the matching trial.  For
    every trial whose ``experiment.lookup_data_for_trial`` timestamp is not
    -1, the returned data is encoded and queued for saving too.

    Args:
        experiment: Experiment the trials belong to; also the source of the
            per-trial data lookup.
        trials: Trials whose stored rows should be updated.
        encoder: Supplies the SQA class mapping and the ``trial_to_sqa`` /
            ``data_to_sqa`` conversions.

    Raises:
        ValueError: If no stored row matches the index of one of ``trials``.
    """
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    wanted_indices = [t.index for t in trials]

    # First session: obtain the experiment id and fetch the stored rows.
    with session_scope() as session:
        experiment_id = _get_experiment_id(experiment=experiment,
                                           encoder=encoder,
                                           session=session)
        query = session.query(sqa_trial_cls).filter_by(
            experiment_id=experiment_id)
        query = query.filter(
            sqa_trial_cls.index.in_(wanted_indices))  # pyre-ignore
        stored_rows = query.all()

    rows_by_index = {row.index: row for row in stored_rows}

    refreshed_rows = []
    data_rows = []
    obj_to_sqa = []
    for trial in trials:
        row = rows_by_index.get(trial.index)
        if row is None:
            raise ValueError(
                f"Trial {trial.index} is not attached to the experiment.")

        encoded_trial, trial_pairs = encoder.trial_to_sqa(trial)
        obj_to_sqa.extend(trial_pairs)
        row.update(encoded_trial)
        refreshed_rows.append(row)

        data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
        if ts == -1:
            # The lookup returned ts == -1, so no data row is created.
            continue
        encoded_data = encoder.data_to_sqa(data=data,
                                           trial_index=trial.index,
                                           timestamp=ts)
        encoded_data.experiment_id = experiment_id
        obj_to_sqa.append((data, encoded_data))
        data_rows.append(encoded_data)

    # Second session: add the updated trial rows and the new data rows.
    with session_scope() as session:
        session.add_all(refreshed_rows)
        session.add_all(data_rows)
        session.flush()

    # Runs after the session block has exited, on everything collected above.
    _set_db_ids(obj_to_sqa=obj_to_sqa)
Exemplo n.º 2
0
Arquivo: save.py Projeto: linusec/Ax
def _update_trial(experiment: Experiment, trial: BaseTrial,
                  encoder: Encoder) -> None:
    """Update the stored row of one trial and attach its looked-up data.

    The stored experiment is found by ``experiment.name``; the trial row is
    then picked out of its ``trials`` by index and updated from a freshly
    encoded copy of ``trial``.  If ``experiment.lookup_data_for_trial``
    returns a timestamp other than -1, the encoded data is appended to the
    stored experiment's ``data``.  The lookup, the trial update and the data
    attachment each use their own ``session_scope()`` block.

    Args:
        experiment: Experiment that owns the trial.
        trial: Trial whose stored row should be updated.
        encoder: Encoder instance used for the SQA class mapping and the
            ``trial_to_sqa`` / ``data_to_sqa`` conversions.

    Raises:
        ValueError: If no stored experiment has this name, or if none of its
            stored trials has ``trial.index``.
    """
    sqa_experiment_cls = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        by_name = session.query(sqa_experiment_cls).filter_by(
            name=experiment.name)
        stored_experiment = by_name.one_or_none()

    if stored_experiment is None:
        raise ValueError("Must save experiment before updating a trial.")

    matches = [
        row for row in stored_experiment.trials if row.index == trial.index
    ]
    if not matches:
        raise ValueError(
            f"Trial {trial.index} is not attached to the experiment.")

    # Exactly one stored trial is expected to carry a given index.
    assert len(matches) == 1
    stored_trial = matches[0]

    stored_trial.update(encoder.trial_to_sqa(trial))
    with session_scope() as session:
        session.add(stored_trial)

    data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
    if ts == -1:
        # The lookup returned ts == -1, so nothing is attached.
        return

    encoded_data = encoder.data_to_sqa(data=data,
                                       trial_index=trial.index,
                                       timestamp=ts)
    with session_scope() as session:
        stored_experiment.data.append(encoded_data)
        session.add(stored_experiment)
Exemplo n.º 3
0
Arquivo: save.py Projeto: Vilashcj/Ax
def update_trials(experiment: Experiment, trials: List[BaseTrial],
                  encoder: Encoder) -> None:
    """Refresh the stored rows for ``trials`` and save their looked-up data.

    Everything runs inside a single ``session_scope()`` block: the stored
    trial rows are fetched by experiment id and index, each is updated from a
    freshly encoded copy of its trial and added to the session, and any data
    whose lookup timestamp is not -1 is encoded and added as well.

    Args:
        experiment: Experiment the trials belong to; also the source of the
            per-trial data lookup.
        trials: Trials whose stored rows should be updated.
        encoder: Supplies the SQA class mapping and the ``trial_to_sqa`` /
            ``data_to_sqa`` conversions.

    Raises:
        ValueError: If no stored row matches the index of one of ``trials``.
    """
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    with session_scope() as session:
        experiment_id = _get_experiment_id(experiment=experiment,
                                           encoder=encoder,
                                           session=session)
        wanted_indices = [t.index for t in trials]
        matching = session.query(sqa_trial_cls).filter_by(
            experiment_id=experiment_id)
        matching = matching.filter(
            sqa_trial_cls.index.in_(wanted_indices))  # pyre-ignore
        rows_by_index = {row.index: row for row in matching.all()}

        for trial in trials:
            row = rows_by_index.get(trial.index)
            if row is None:
                raise ValueError(
                    f"Trial {trial.index} is not attached to the experiment.")

            row.update(encoder.trial_to_sqa(trial))
            session.add(row)

            data, ts = experiment.lookup_data_for_trial(
                trial_index=trial.index)
            if ts == -1:
                # The lookup returned ts == -1, so no data row is created.
                continue
            encoded_data = encoder.data_to_sqa(data=data,
                                               trial_index=trial.index,
                                               timestamp=ts)
            encoded_data.experiment_id = experiment_id
            session.add(encoded_data)