def _update_trials(
    experiment: Experiment, trials: List[BaseTrial], encoder: Encoder
) -> None:
    """Overwrite already-saved trials with their current state and save new data.

    Every trial in ``trials`` must already be stored for this experiment
    (matched by trial index). For each trial, the freshly encoded SQA trial is
    merged into the stored row via ``update``. If
    ``experiment.lookup_data_for_trial`` yields a timestamp other than -1, that
    data is encoded and queued as a new data row. All writes happen in one
    session, after which database ids are copied back onto the in-memory
    objects via ``_set_db_ids``.

    Args:
        experiment: Experiment the trials belong to.
        trials: Trials whose stored rows should be overwritten.
        encoder: Encoder used to convert Ax objects to SQA objects.

    Raises:
        ValueError: If any trial in ``trials`` has no stored counterpart.
    """
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    wanted_indices = [t.index for t in trials]
    obj_to_sqa = []

    # Read phase: resolve the experiment's db id and load the stored trials.
    with session_scope() as session:
        experiment_id = _get_experiment_id(
            experiment=experiment, encoder=encoder, session=session
        )
        stored_rows = (
            session.query(sqa_trial_cls)
            .filter_by(experiment_id=experiment_id)
            .filter(sqa_trial_cls.index.in_(wanted_indices))  # pyre-ignore
            .all()
        )
    stored_by_index = {row.index: row for row in stored_rows}

    trials_to_write, data_to_write = [], []
    for trial in trials:
        stored_trial = stored_by_index.get(trial.index)
        if stored_trial is None:
            raise ValueError(
                f"Trial {trial.index} is not attached to the experiment.")
        encoded_trial, trial_obj_to_sqa = encoder.trial_to_sqa(trial)
        obj_to_sqa.extend(trial_obj_to_sqa)
        stored_trial.update(encoded_trial)
        trials_to_write.append(stored_trial)

        data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
        # A -1 timestamp is treated as "no data to attach".
        if ts == -1:
            continue
        sqa_data = encoder.data_to_sqa(
            data=data, trial_index=trial.index, timestamp=ts)
        obj_to_sqa.append((data, sqa_data))
        sqa_data.experiment_id = experiment_id
        data_to_write.append(sqa_data)

    # Write phase: persist everything, flushing so primary keys are assigned
    # before they are copied back to the Ax objects.
    with session_scope() as session:
        session.add_all(trials_to_write)
        session.add_all(data_to_write)
        session.flush()
    _set_db_ids(obj_to_sqa=obj_to_sqa)
def _update_trial(experiment: Experiment, trial: BaseTrial, encoder: Encoder) -> None:
    """Overwrite one saved trial with its current state and attach its data.

    The stored experiment is looked up by ``experiment.name``. The stored trial
    is the one on that experiment whose ``index`` equals ``trial.index``. The
    freshly encoded trial is merged into it via ``update`` and saved. If
    ``experiment.lookup_data_for_trial`` returns a timestamp other than -1, the
    encoded data is appended to the stored experiment's ``data`` and saved in a
    separate session.

    Args:
        experiment: Experiment the trial belongs to; must already be saved.
        trial: Trial whose stored row should be overwritten.
        encoder: Encoder used to convert Ax objects to SQA objects.

    Raises:
        ValueError: If the experiment is not saved, or has no trial with
            ``trial.index``.
        AssertionError: If more than one stored trial shares ``trial.index``
            (only when assertions are enabled).
    """
    sqa_experiment_cls = encoder.config.class_to_sqa_class[Experiment]
    with session_scope() as session:
        sqa_experiment = (
            session.query(sqa_experiment_cls)
            .filter_by(name=experiment.name)
            .one_or_none()
        )
    if sqa_experiment is None:
        raise ValueError("Must save experiment before updating a trial.")

    # Collect every stored trial with this index in a single pass; exactly one
    # is expected.
    matching = [
        stored for stored in sqa_experiment.trials if stored.index == trial.index
    ]
    if not matching:
        raise ValueError(
            f"Trial {trial.index} is not attached to the experiment.")
    assert len(matching) == 1
    sqa_trial = matching[0]

    # NOTE(review): `_update_trials` and `_save_or_update_trials` in this file
    # unpack `encoder.trial_to_sqa(...)` as a (sqa_trial, obj_to_sqa) pair,
    # whereas this call passes the return value straight to `update` — confirm
    # which contract the current Encoder follows.
    sqa_trial.update(encoder.trial_to_sqa(trial))
    with session_scope() as session:
        session.add(sqa_trial)

    data, ts = experiment.lookup_data_for_trial(trial_index=trial.index)
    # A -1 timestamp is treated as "no data to attach".
    if ts == -1:
        return
    sqa_data = encoder.data_to_sqa(data=data, trial_index=trial.index, timestamp=ts)
    with session_scope() as session:
        sqa_experiment.data.append(sqa_data)
        session.add(sqa_experiment)
def update_trials(
    experiment: Experiment, trials: List[BaseTrial], encoder: Encoder
) -> None:
    """Overwrite saved trials with their current state and attach new data.

    All reads and writes happen inside one ``session_scope``. Each trial in
    ``trials`` must already be stored for this experiment (matched by trial
    index). Its stored row is updated from the freshly encoded trial. If
    ``experiment.lookup_data_for_trial`` yields a timestamp other than -1, a
    new data row is added for it.

    Args:
        experiment: Experiment the trials belong to.
        trials: Trials whose stored rows should be overwritten.
        encoder: Encoder used to convert Ax objects to SQA objects.

    Raises:
        ValueError: If any trial in ``trials`` has no stored counterpart.
    """
    # NOTE(review): this largely duplicates `_update_trials`, but unlike it
    # does not call `_set_db_ids`, so ids of newly written rows are not copied
    # back onto the Ax objects — confirm whether that is intended.
    sqa_trial_cls = encoder.config.class_to_sqa_class[Trial]
    with session_scope() as session:
        experiment_id = _get_experiment_id(
            experiment=experiment, encoder=encoder, session=session
        )
        requested_indices = [t.index for t in trials]
        stored_rows = (
            session.query(sqa_trial_cls)
            .filter_by(experiment_id=experiment_id)
            .filter(sqa_trial_cls.index.in_(requested_indices))  # pyre-ignore
            .all()
        )
        stored_by_index = {row.index: row for row in stored_rows}

        for trial in trials:
            stored = stored_by_index.get(trial.index)
            if stored is None:
                raise ValueError(
                    f"Trial {trial.index} is not attached to the experiment.")
            # NOTE(review): elsewhere in this file `encoder.trial_to_sqa(...)`
            # is unpacked as a (sqa_trial, obj_to_sqa) pair; here the return
            # value is passed straight to `update` — confirm the contract.
            stored.update(encoder.trial_to_sqa(trial))
            session.add(stored)

            data, ts = experiment.lookup_data_for_trial(
                trial_index=trial.index)
            # A -1 timestamp is treated as "no data to attach".
            if ts == -1:
                continue
            sqa_data = encoder.data_to_sqa(
                data=data, trial_index=trial.index, timestamp=ts)
            sqa_data.experiment_id = experiment_id
            session.add(sqa_data)
def _save_or_update_trials(
    experiment: Experiment, trials: List[BaseTrial], encoder: Encoder
) -> None:
    """Add new trials to the experiment, or update them if they already exist.

    A trial counts as existing when its ``_db_id`` matches a trial row already
    stored for this experiment. Those stored trial rows, and the stored data
    rows for the same trial indices, are loaded and updated in place via
    ``update``. Everything else is inserted as a new row tied to
    ``experiment._db_id``. Every entry of ``experiment.data_by_trial`` for each
    trial is saved. After the write, database ids are copied back onto the
    in-memory objects via ``_set_db_ids``.

    Args:
        experiment: Experiment the trials belong to; must already be saved.
        trials: Trials to insert or update.
        encoder: Encoder used to convert Ax objects to SQA objects.

    Raises:
        ValueError: If the experiment has not been saved yet.
    """
    experiment_id = experiment._db_id
    if experiment_id is None:
        raise ValueError("Must save experiment first.")

    data_sqa_class = encoder.config.class_to_sqa_class[Data]
    trial_sqa_class = encoder.config.class_to_sqa_class[Trial]
    obj_to_sqa = []

    # Fetch only the ids of all trials already saved to the experiment; full
    # rows are loaded below just for the trials that actually need updating,
    # which avoids pulling every trial when few (or none) are being updated.
    with session_scope() as session:
        saved_trial_id_rows = (
            session.query(trial_sqa_class.id)  # pyre-ignore
            .filter_by(experiment_id=experiment_id)
            .all()
        )
    saved_trial_ids = {row[0] for row in saved_trial_id_rows}

    update_trial_ids = set()
    update_trial_indices = set()
    for trial in trials:
        if trial._db_id in saved_trial_ids:
            update_trial_ids.add(trial._db_id)
            update_trial_indices.add(trial.index)

    existing_trials, existing_datas = [], []
    # When every trial is new there is nothing to load: an empty IN clause
    # cannot match any row, and older SQLAlchemy versions warn about it.
    if update_trial_ids:
        with session_scope() as session:
            existing_trials = (
                session.query(trial_sqa_class)
                .filter(trial_sqa_class.id.in_(update_trial_ids))
                .all()
            )
            existing_datas = (
                session.query(data_sqa_class)
                .filter_by(experiment_id=experiment_id)
                .filter(
                    data_sqa_class.trial_index.in_(  # pyre-ignore
                        update_trial_indices
                    )
                )
                .all()
            )

    trial_id_to_existing_trial = {t.id: t for t in existing_trials}
    data_id_to_existing_data = {d.id: d for d in existing_datas}

    sqa_trials, sqa_datas = [], []
    for trial in trials:
        sqa_trial, trial_obj_to_sqa = encoder.trial_to_sqa(trial)
        obj_to_sqa.extend(trial_obj_to_sqa)
        existing_trial = trial_id_to_existing_trial.get(trial._db_id)
        if existing_trial is None:
            sqa_trial.experiment_id = experiment_id
            sqa_trials.append(sqa_trial)
        else:
            existing_trial.update(sqa_trial)
            sqa_trials.append(existing_trial)

        for ts, data in experiment.data_by_trial.get(trial.index, {}).items():
            sqa_data = encoder.data_to_sqa(
                data=data, trial_index=trial.index, timestamp=ts
            )
            obj_to_sqa.append((data, sqa_data))
            matched_data = data_id_to_existing_data.get(data._db_id)
            if matched_data is None:
                sqa_data.experiment_id = experiment_id
                sqa_datas.append(sqa_data)
            else:
                matched_data.update(sqa_data)
                sqa_datas.append(matched_data)

    # Flush so primary keys are assigned before they are copied back.
    with session_scope() as session:
        session.add_all(sqa_trials)
        session.add_all(sqa_datas)
        session.flush()
    _set_db_ids(obj_to_sqa=obj_to_sqa)