def _get_prepared_new_trial(
    self, study_id: int, template_trial: Optional[FrozenTrial], session: orm.Session
) -> models.TrialModel:
    """Add a new trial row to ``session`` and return its ORM model.

    Without a template, the trial is created in the ``RUNNING`` state. With a
    template, the trial's value, start/complete datetimes, params, user attrs,
    system attrs and intermediate values are copied from ``template_trial``.
    The state is switched to ``template_trial.state`` only after everything
    else has been written.

    This method does not commit ``session``.

    Args:
        study_id:
            Id of the study the new trial belongs to.
        template_trial:
            Optional :class:`~optuna.trial.FrozenTrial` whose contents seed the
            new trial, or ``None`` for a fresh ``RUNNING`` trial.
        session:
            Session the new trial is added to and flushed through.

    Returns:
        The ``models.TrialModel`` added to ``session``. Its ``number`` is
        assigned from ``count_past_trials`` at the end.
    """
    # Only `RUNNING` trials can be updated, so the row always starts out as
    # `RUNNING`. For a template trial, the real state is applied at the end,
    # once all other fields have been stored.
    constructor_kwargs = {"study_id": study_id, "number": None, "state": TrialState.RUNNING}
    if template_trial is not None:
        constructor_kwargs.update(
            value=template_trial.value,
            datetime_start=template_trial.datetime_start,
            datetime_complete=template_trial.datetime_complete,
        )
    trial = models.TrialModel(**constructor_kwargs)

    session.add(trial)
    # Flush so the pending insert reaches the current RDB transaction. Without
    # it, the `_set_trial_*_without_commit` calls below would fail because the
    # target trial does not exist in the storage yet.
    session.flush()

    if template_trial is not None:
        for name, external_value in template_trial.params.items():
            dist = template_trial.distributions[name]
            internal_value = dist.to_internal_repr(external_value)
            self._set_trial_param_without_commit(
                session, trial.trial_id, name, internal_value, dist
            )

        for attr_key, attr_value in template_trial.user_attrs.items():
            self._set_trial_user_attr_without_commit(
                session, trial.trial_id, attr_key, attr_value
            )

        for attr_key, attr_value in template_trial.system_attrs.items():
            self._set_trial_system_attr_without_commit(
                session, trial.trial_id, attr_key, attr_value
            )

        for step, reported in template_trial.intermediate_values.items():
            self._set_trial_intermediate_value_without_commit(
                session, trial.trial_id, step, reported
            )

        # Every field is in place; now apply the template's real state.
        trial.state = template_trial.state

    trial.number = trial.count_past_trials(session)
    session.add(trial)

    return trial
def _create_new_trial(
    self, study_id: int, template_trial: Optional[FrozenTrial] = None
) -> FrozenTrial:
    """Create a new trial and return it as a :class:`~optuna.trial.FrozenTrial`.

    Args:
        study_id:
            Study id.
        template_trial:
            A :class:`~optuna.trial.FrozenTrial` with default values for trial attributes.

    Returns:
        A :class:`~optuna.trial.FrozenTrial` instance describing the stored trial.
    """
    session = self.scoped_session()

    # Ensure that the study exists.
    models.StudyModel.find_or_raise_by_id(study_id, session)

    # Insert the trial, plus anything copied from the template. The shared
    # helper is used instead of duplicating its insertion logic here.
    trial = self._get_prepared_new_trial(study_id, template_trial, session)

    self._commit(session)

    if template_trial is not None:
        # Start from a copy of the template and override the fields that the
        # storage assigned (number, start time, and trial id).
        frozen = copy.deepcopy(template_trial)
        frozen.number = trial.number
        frozen.datetime_start = trial.datetime_start
        frozen._trial_id = trial.trial_id
    else:
        frozen = FrozenTrial(
            number=trial.number,
            state=trial.state,
            value=None,
            datetime_start=trial.datetime_start,
            datetime_complete=None,
            params={},
            distributions={},
            user_attrs={},
            system_attrs={},
            intermediate_values={},
            trial_id=trial.trial_id,
        )

    return frozen