def test_pruning(
    storage_mode: str, pruner_init_func: Callable[[], BasePruner], comm: CommunicatorBase
) -> None:
    """Check that every trial whose objective raises ``TrialPruned`` is stored as PRUNED.

    Args:
        storage_mode: Storage backend identifier handed to ``MultiNodeStorageSupplier``.
        pruner_init_func: Zero-argument factory that builds the pruner under test.
        comm: ChainerMN communicator shared by all participating nodes.
    """
    with MultiNodeStorageSupplier(storage_mode, comm) as storage:
        pruner = pruner_init_func()
        study = TestChainerMNStudy._create_shared_study(storage, comm, pruner=pruner)
        mn_study = ChainerMNStudy(study, comm)

        # The second argument is the communicator forwarded by ``ChainerMNStudy.optimize``,
        # so it is annotated as ``CommunicatorBase`` (it was previously mis-annotated as ``bool``).
        def objective(_trial: ChainerMNTrial, _comm: CommunicatorBase) -> float:
            raise TrialPruned  # Always be pruned.

        # Invoke optimize.
        n_trials = 20
        mn_study.optimize(objective, n_trials=n_trials)

        # Assert trial count.
        assert len(mn_study.trials) == n_trials

        # Assert pruned trial count.
        pruned_trials = [t for t in mn_study.trials if t.state == TrialState.PRUNED]
        assert len(pruned_trials) == n_trials
def test_relative_sampling(storage_mode: str, comm: CommunicatorBase) -> None:
    """Check that parameters dictated by a deterministic relative sampler reach all nodes.

    Args:
        storage_mode: Storage backend identifier handed to ``MultiNodeStorageSupplier``.
        comm: ChainerMN communicator shared by all participating nodes.
    """
    search_space = {
        "x": distributions.UniformDistribution(low=-10, high=10),
        "y": distributions.LogUniformDistribution(low=20, high=30),
        "z": distributions.CategoricalDistribution(choices=(-1.0, 1.0)),
    }
    expected_params = {"x": 1.0, "y": 25.0, "z": -1.0}
    sampler = DeterministicRelativeSampler(search_space, expected_params)  # type: ignore

    with MultiNodeStorageSupplier(storage_mode, comm) as storage:
        shared_study = TestChainerMNStudy._create_shared_study(storage, comm, sampler=sampler)
        mn_study = ChainerMNStudy(shared_study, comm)

        trial_budget = 20
        mn_study.optimize(Func(), n_trials=trial_budget)

        recorded = mn_study.trials
        assert len(recorded) == trial_budget

        # Every trial must carry exactly the parameters the sampler handed out.
        assert all(record.params == expected_params for record in recorded)
def test_init(storage_mode: str, comm: CommunicatorBase) -> None:
    """Check that a ``ChainerMNStudy`` reports the same name as the study it wraps.

    Args:
        storage_mode: Storage backend identifier handed to ``MultiNodeStorageSupplier``.
        comm: ChainerMN communicator shared by all participating nodes.
    """
    with MultiNodeStorageSupplier(storage_mode, comm) as storage:
        base_study = TestChainerMNStudy._create_shared_study(storage, comm)
        wrapped = ChainerMNStudy(base_study, comm)
        assert wrapped.study_name == base_study.study_name
def test_init_with_incompatible_storage(comm: CommunicatorBase) -> None:
    """Check that wrapping a study backed by ``InMemoryStorage`` raises ``ValueError``.

    Args:
        comm: ChainerMN communicator shared by all participating nodes.
    """
    memory_storage = InMemoryStorage()
    local_study = create_study(study_name="in-memory-study", storage=memory_storage)

    with pytest.raises(ValueError):
        ChainerMNStudy(local_study, comm)
def test_optimize(storage_mode: str, comm: CommunicatorBase) -> None:
    """Check that optimization runs the requested trials with consistent parameters.

    The parameters stored for each trial must match what the objective recorded as
    suggested for that trial number.

    Args:
        storage_mode: Storage backend identifier handed to ``MultiNodeStorageSupplier``.
        comm: ChainerMN communicator shared by all participating nodes.
    """
    with MultiNodeStorageSupplier(storage_mode, comm) as storage:
        shared_study = TestChainerMNStudy._create_shared_study(storage, comm)
        mn_study = ChainerMNStudy(shared_study, comm)

        tracked_objective = Func()
        trial_budget = 20
        mn_study.optimize(tracked_objective, n_trials=trial_budget)

        recorded = mn_study.trials
        assert len(recorded) == trial_budget

        # Stored params must equal what the objective saw for that trial number.
        for record in recorded:
            assert record.params == tracked_objective.suggested_values[record.number]
def test_init_with_multiple_study_names(storage_mode: str, comm: CommunicatorBase) -> None:
    """Check that ``ChainerMNStudy`` raises ``ValueError`` when each rank uses its own study.

    Args:
        storage_mode: Storage backend identifier handed to ``MultiNodeStorageSupplier``.
        comm: ChainerMN communicator shared by all participating nodes.
    """
    # NOTE(review): presumably skips the test when fewer than two nodes are present.
    TestChainerMNStudy._check_multi_node(comm)

    with MultiNodeStorageSupplier(storage_mode, comm) as storage:
        # Every rank creates its own study here, so the study names are not shared.
        per_rank_name = create_study(storage).study_name
        per_rank_study = Study(per_rank_name, storage)

        with pytest.raises(ValueError):
            ChainerMNStudy(per_rank_study, comm)
def test_failure(storage_mode: str, comm: CommunicatorBase) -> None:
    """Check that failing objectives are stored as FAIL, whether the error is caught or not.

    With ``catch=(ValueError,)`` all requested trials run and fail. With ``catch=()``
    the error propagates out of ``optimize``, and exactly one additional failed trial
    is expected.

    Args:
        storage_mode: Storage backend identifier handed to ``MultiNodeStorageSupplier``.
        comm: ChainerMN communicator shared by all participating nodes.
    """
    with MultiNodeStorageSupplier(storage_mode, comm) as storage:
        study = TestChainerMNStudy._create_shared_study(storage, comm)
        mn_study = ChainerMNStudy(study, comm)

        # The second argument is the communicator forwarded by ``ChainerMNStudy.optimize``,
        # so it is annotated as ``CommunicatorBase`` (it was previously mis-annotated as ``bool``).
        def objective(_trial: ChainerMNTrial, _comm: CommunicatorBase) -> float:
            raise ValueError  # Always fails.

        # Invoke optimize in which `ValueError` is accepted.
        n_trials = 20
        mn_study.optimize(objective, n_trials=n_trials, catch=(ValueError,))

        # Assert trial count.
        assert len(mn_study.trials) == n_trials

        # Assert failed trial count.
        failed_trials = [t for t in mn_study.trials if t.state == TrialState.FAIL]
        assert len(failed_trials) == n_trials

        # Synchronize nodes before executing the next optimization.
        comm.mpi_comm.barrier()

        # Invoke optimize in which no exceptions are accepted: the first failing trial
        # aborts the run, and the assertions below expect it to still be stored as FAIL.
        with pytest.raises(ValueError):
            mn_study.optimize(objective, n_trials=n_trials, catch=())

        # Assert trial count.
        assert len(mn_study.trials) == n_trials + 1

        # Assert failed trial count.
        failed_trials = [t for t in mn_study.trials if t.state == TrialState.FAIL]
        assert len(failed_trials) == n_trials + 1