Beispiel #1
0
    def test_cascade_delete_on_trial(session: Session) -> None:
        """Deleting a trial must also remove its heartbeat record.

        The heartbeat looked up by the trial's id has to exist once the trial
        is committed, and must be gone after the trial is deleted and the
        deletion is committed.
        """
        trial = TestTrialHeartbeatModel._create_model(session)
        session.commit()
        trial_id = trial.trial_id

        heartbeat_before = TrialHeartbeatModel.where_trial_id(trial_id, session)
        assert heartbeat_before is not None

        session.delete(trial)
        session.commit()

        heartbeat_after = TrialHeartbeatModel.where_trial_id(trial_id, session)
        assert heartbeat_after is None
Beispiel #2
0
def test_record_heartbeat() -> None:
    """Check that a heartbeat is stored for every trial of a study.

    Each trial sleeps for ``sleep_sec`` seconds while the storage's
    ``heartbeat_interval`` is set to ``heartbeat_interval``. The study is
    optimized without ``n_jobs``, so trials are expected to run one after
    another. Each trial's stored heartbeat should therefore not precede the
    previous trial's and should trail it by roughly ``sleep_sec`` seconds.
    """
    heartbeat_interval = 1
    n_trials = 2
    sleep_sec = 2

    def objective(trial: Trial) -> float:
        time.sleep(sleep_sec)
        return 1.0

    with StorageSupplier("sqlite") as storage:
        assert isinstance(storage, RDBStorage)
        storage.heartbeat_interval = heartbeat_interval
        study = create_study(storage=storage)
        # Exceptions raised in spawned threads are caught by `_TestableThread`.
        with patch("optuna.study._optimize.Thread", _TestableThread):
            study.optimize(objective, n_trials=n_trials)

        trial_heartbeats = []
        with _create_scoped_session(storage.scoped_session) as session:
            for frozen_trial in study.trials:
                heartbeat_model = TrialHeartbeatModel.where_trial_id(
                    frozen_trial._trial_id, session)
                assert heartbeat_model is not None
                trial_heartbeats.append(heartbeat_model.heartbeat)

        assert len(trial_heartbeats) == n_trials
        for earlier, later in zip(trial_heartbeats, trial_heartbeats[1:]):
            # Use ``total_seconds()`` rather than ``.seconds``. ``.seconds``
            # drops the ``days`` component, so a gap of one day plus one
            # second reads as 1. It also wraps negative gaps to large positive
            # values, which makes the check both lenient and confusing.
            elapsed = (later - earlier).total_seconds()
            # Heartbeats must not go backwards. Allow under two seconds of
            # slack beyond the sleep (at most one extra whole second).
            assert 0 <= elapsed < sleep_sec + 2
Beispiel #3
0
    def test_where_trial_id(session: Session) -> None:
        """Look up a heartbeat by trial id and check its timestamp type.

        A heartbeat must be found for the freshly created trial, and its
        ``heartbeat`` attribute must be a ``datetime``.
        """
        created_trial = TestTrialHeartbeatModel._create_model(session)
        found = TrialHeartbeatModel.where_trial_id(created_trial.trial_id, session)

        assert found is not None
        heartbeat_value = found.heartbeat
        assert isinstance(heartbeat_value, datetime)