def test_cascade_delete_on_trial(session: Session) -> None:
    """Deleting a trial removes its heartbeat row as well.

    Creates a trial with an attached heartbeat, checks the heartbeat is found,
    deletes the trial, and checks ``TrialHeartbeatModel.where_trial_id`` no
    longer finds a heartbeat for that trial ID.
    """
    trial = TestTrialHeartbeatModel._create_model(session)
    session.commit()
    # Capture the primary key while the instance is still persistent. After the
    # deletion is committed the instance is detached, and reading its attributes
    # would only work if they happened to be loaded already.
    trial_id = trial.trial_id
    assert TrialHeartbeatModel.where_trial_id(trial_id, session) is not None

    session.delete(trial)
    session.commit()
    assert TrialHeartbeatModel.where_trial_id(trial_id, session) is None
def test_record_heartbeat() -> None:
    """Heartbeats are recorded for every trial run by ``study.optimize``.

    Runs a short optimization on an SQLite-backed ``RDBStorage`` with a
    heartbeat interval configured. It then checks that every trial has a
    heartbeat row. It also bounds the gap between the heartbeats of consecutive
    trials relative to the time each objective call sleeps.
    """
    sleep_sec = 2
    n_trials = 2
    heartbeat_interval = 1

    def objective(trial: Trial) -> float:
        time.sleep(sleep_sec)
        return 1.0

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage.heartbeat_interval = heartbeat_interval
        study = create_study(storage=storage)

        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.study._optimize.Thread", _TestableThread):
            study.optimize(objective, n_trials=n_trials)

        recorded = []
        with _create_scoped_session(storage.scoped_session) as session:
            for frozen in study.trials:
                model = TrialHeartbeatModel.where_trial_id(frozen._trial_id, session)
                assert model is not None
                recorded.append(model.heartbeat)

        assert len(recorded) == n_trials

        # `.seconds` is only the seconds component of the timedelta: the
        # fractional part and any whole days are not included.
        for earlier, later in zip(recorded, recorded[1:]):
            assert (later - earlier).seconds - sleep_sec <= 1
def test_where_trial_id(session: Session) -> None:
    """``where_trial_id`` finds the heartbeat created for a trial, with a datetime value."""
    created = TestTrialHeartbeatModel._create_model(session)
    found = TrialHeartbeatModel.where_trial_id(created.trial_id, session)

    assert found is not None
    assert isinstance(found.heartbeat, datetime)
def _create_model(session: Session) -> TrialModel:
    """Insert a study, one completed trial and that trial's heartbeat, then commit.

    The study has ID 1, the name "test-study" and a single MINIMIZE direction
    (objective index 0). The trial has ID 1 and state COMPLETE.

    Returns:
        The ``TrialModel`` instance that was added and committed.
    """
    study = StudyModel(
        study_id=1,
        study_name="test-study",
        directions=[StudyDirectionModel(direction=StudyDirection.MINIMIZE, objective=0)],
    )
    trial = TrialModel(trial_id=1, study_id=study.study_id, state=TrialState.COMPLETE)
    heartbeat = TrialHeartbeatModel(trial_id=trial.trial_id)

    # Added in the same order as before: study, then trial, then heartbeat.
    session.add_all([study, trial, heartbeat])
    session.commit()
    return trial