Exemplo n.º 1
0
def _update_trials(experiment: Experiment, trials: List[BaseTrial],
                   encoder: Encoder) -> None:
    """Re-encode already-saved trials, persist them, and attach new data.

    Every trial in ``trials`` must already have a saved row for this
    experiment; that row is updated in place from a fresh encoding. When
    ``experiment.lookup_data_for_trial`` reports a timestamp other than -1,
    the returned data is encoded and saved as a new row as well. Afterwards
    ``_set_db_ids`` propagates the new ids back onto the in-memory objects.

    Args:
        experiment: Experiment owning the trials.
        trials: Trials to update; each must already be saved.
        encoder: Encoder used to convert trials and data to SQA objects.

    Raises:
        ValueError: if any trial has no saved row for this experiment.
    """
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    indices = [t.index for t in trials]
    obj_to_sqa = []

    # Read phase: resolve the experiment's id and load the matching rows.
    with session_scope() as session:
        experiment_id = _get_experiment_id(experiment=experiment,
                                           encoder=encoder,
                                           session=session)
        query = session.query(sqa_trial_cls).filter_by(
            experiment_id=experiment_id)
        query = query.filter(sqa_trial_cls.index.in_(indices))  # pyre-ignore
        saved_rows = query.all()

    saved_by_index = dict((row.index, row) for row in saved_rows)

    # Encode phase: done outside any session, purely in memory.
    trials_to_save = []
    data_to_save = []
    for t in trials:
        saved_row = saved_by_index.get(t.index)
        if saved_row is None:
            raise ValueError(
                f"Trial {t.index} is not attached to the experiment.")

        encoded_trial, trial_obj_to_sqa = encoder.trial_to_sqa(t)
        obj_to_sqa.extend(trial_obj_to_sqa)
        saved_row.update(encoded_trial)
        trials_to_save.append(saved_row)

        data, timestamp = experiment.lookup_data_for_trial(trial_index=t.index)
        if timestamp == -1:
            # A timestamp of -1 means there is nothing to attach
            # (presumably no data recorded for this trial).
            continue
        encoded_data = encoder.data_to_sqa(data=data,
                                           trial_index=t.index,
                                           timestamp=timestamp)
        obj_to_sqa.append((data, encoded_data))
        encoded_data.experiment_id = experiment_id
        data_to_save.append(encoded_data)

    # Write phase: persist updated trials and new data in one session scope.
    with session_scope() as session:
        session.add_all(trials_to_save)
        session.add_all(data_to_save)
        session.flush()

    _set_db_ids(obj_to_sqa=obj_to_sqa)
Exemplo n.º 2
0
Arquivo: save.py Projeto: linusec/Ax
def _update_trial(experiment: Experiment, trial: BaseTrial,
                  encoder: Encoder) -> None:
    """Update a saved trial and attach its latest data, using ``encoder``.

    Args:
        experiment: Experiment that owns ``trial``. It must already be saved;
            the saved row is looked up by ``experiment.name``.
        trial: Trial whose saved row is overwritten with a fresh encoding.
        encoder: Encoder used to convert ``trial`` and its data to SQA
            objects.

    Raises:
        ValueError: if the experiment has not been saved, if ``trial`` has no
            saved row on it, or if more than one saved row shares its index.
    """
    exp_sqa_class = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        existing_sqa_experiment = (session.query(exp_sqa_class).filter_by(
            name=experiment.name).one_or_none())

    if existing_sqa_experiment is None:
        raise ValueError("Must save experiment before updating a trial.")

    # One pass over the saved trials: collect every row with this index.
    existing_sqa_trials = [
        sqa_trial for sqa_trial in existing_sqa_experiment.trials
        if sqa_trial.index == trial.index
    ]
    if not existing_sqa_trials:
        raise ValueError(
            f"Trial {trial.index} is not attached to the experiment.")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(existing_sqa_trials) > 1:
        raise ValueError(
            f"Found {len(existing_sqa_trials)} saved trials with index "
            f"{trial.index}; expected exactly one.")
    existing_sqa_trial = existing_sqa_trials[0]

    new_sqa_trial = encoder.trial_to_sqa(trial)
    existing_sqa_trial.update(new_sqa_trial)

    # Encode any data before opening the write session so it stays short.
    data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
    sqa_data = None
    if ts != -1:
        sqa_data = encoder.data_to_sqa(data=data,
                                       trial_index=trial.index,
                                       timestamp=ts)

    # Persist the trial update and the new data in a single session scope so
    # they succeed or fail together (assumes session_scope wraps one
    # transaction -- confirm against its definition).
    with session_scope() as session:
        session.add(existing_sqa_trial)
        if sqa_data is not None:
            existing_sqa_experiment.data.append(sqa_data)
            session.add(existing_sqa_experiment)
Exemplo n.º 3
0
Arquivo: save.py Projeto: Vilashcj/Ax
def update_trials(experiment: Experiment, trials: List[BaseTrial],
                  encoder: Encoder) -> None:
    """Overwrite saved trials with fresh encodings and attach their data.

    All reads and writes happen inside a single ``session_scope``. Each trial
    in ``trials`` must already be saved for ``experiment``. Data is attached
    only when ``experiment.lookup_data_for_trial`` reports a timestamp other
    than -1.

    Raises:
        ValueError: if a trial in ``trials`` is not saved on the experiment.
    """
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    with session_scope() as session:
        experiment_id = _get_experiment_id(experiment=experiment,
                                           encoder=encoder,
                                           session=session)
        wanted_indices = [t.index for t in trials]
        rows = (session.query(sqa_trial_cls)
                .filter_by(experiment_id=experiment_id)
                .filter(sqa_trial_cls.index.in_(wanted_indices))  # pyre-ignore
                .all())
        row_for_index = dict((row.index, row) for row in rows)

        for t in trials:
            row = row_for_index.get(t.index)
            if row is None:
                raise ValueError(
                    f"Trial {t.index} is not attached to the experiment.")
            row.update(encoder.trial_to_sqa(t))
            session.add(row)

            data, timestamp = experiment.lookup_data_for_trial(
                trial_index=t.index)
            if timestamp == -1:
                # Nothing to attach for this trial.
                continue
            sqa_data = encoder.data_to_sqa(data=data,
                                           trial_index=t.index,
                                           timestamp=timestamp)
            sqa_data.experiment_id = experiment_id
            session.add(sqa_data)
Exemplo n.º 4
0
def _save_or_update_trials(experiment: Experiment, trials: List[BaseTrial],
                           encoder: Encoder) -> None:
    """Add new trials to the experiment, or update if they already exist.

    A trial counts as "existing" when its ``_db_id`` matches the id of a
    trial row already saved for this experiment. Those rows, and the
    experiment's saved data rows for the same trial indices, are loaded and
    updated in place. Every other trial, and every data object whose
    ``_db_id`` has no matching loaded row, is inserted as a new row. After
    the write, ``_set_db_ids`` copies the resulting ids back onto the
    in-memory objects.

    Args:
        experiment: Saved experiment the trials belong to.
        trials: Trials to insert or update.
        encoder: Encoder used to convert trials and data to SQA objects.

    Raises:
        ValueError: if the experiment itself has not been saved yet.
    """
    experiment_id = experiment._db_id
    if experiment_id is None:
        raise ValueError("Must save experiment first.")

    data_sqa_class = encoder.config.class_to_sqa_class[Data]
    trial_sqa_class = encoder.config.class_to_sqa_class[Trial]
    obj_to_sqa = []

    existing_trials, existing_data = [], []
    with session_scope() as session:
        # Fetch only the ids of all trials already saved to the experiment.
        id_rows = (
            session.query(trial_sqa_class.id)  # pyre-ignore
            .filter_by(experiment_id=experiment_id).all())
        saved_trial_ids = {row[0] for row in id_rows}

        update_trial_ids, update_trial_indices = set(), set()
        for trial in trials:
            if trial._db_id in saved_trial_ids:
                update_trial_ids.add(trial._db_id)
                update_trial_indices.add(trial.index)

        # Fetch the *whole* trial (and corresponding data) only for trials we
        # are updating, rather than for every trial on the experiment, which
        # could be costly when few or none are being updated. When nothing is
        # being updated, skip both queries instead of issuing an empty
        # IN (...) predicate.
        if update_trial_ids:
            existing_trials = (session.query(trial_sqa_class).filter(
                trial_sqa_class.id.in_(update_trial_ids)).all())
            existing_data = (
                session.query(data_sqa_class).filter_by(
                    experiment_id=experiment_id).filter(
                        data_sqa_class.trial_index.in_(
                            update_trial_indices))  # pyre-ignore
                .all())

    trial_id_to_existing_trial = {
        sqa_trial.id: sqa_trial for sqa_trial in existing_trials
    }
    data_id_to_existing_data = {
        sqa_data.id: sqa_data for sqa_data in existing_data
    }

    sqa_trials, sqa_datas = [], []
    for trial in trials:
        sqa_trial, _obj_to_sqa = encoder.trial_to_sqa(trial)
        obj_to_sqa.extend(_obj_to_sqa)

        existing_trial = trial_id_to_existing_trial.get(trial._db_id)
        if existing_trial is None:
            sqa_trial.experiment_id = experiment_id
            sqa_trials.append(sqa_trial)
        else:
            existing_trial.update(sqa_trial)
            sqa_trials.append(existing_trial)

        datas = experiment.data_by_trial.get(trial.index, {})
        for ts, data in datas.items():
            sqa_data = encoder.data_to_sqa(data=data,
                                           trial_index=trial.index,
                                           timestamp=ts)
            obj_to_sqa.append((data, sqa_data))

            existing_sqa_data = data_id_to_existing_data.get(data._db_id)
            if existing_sqa_data is None:
                sqa_data.experiment_id = experiment_id
                sqa_datas.append(sqa_data)
            else:
                existing_sqa_data.update(sqa_data)
                sqa_datas.append(existing_sqa_data)

    with session_scope() as session:
        session.add_all(sqa_trials)
        session.add_all(sqa_datas)
        session.flush()

    _set_db_ids(obj_to_sqa=obj_to_sqa)