def test_should_prune(storage_mode: str, is_pruning: bool) -> None:
    """Check that ``should_prune`` agrees with ``is_pruning`` on every rank.

    Rank 0 creates a study whose pruner is ``DeterministicPruner(is_pruning)``
    and wraps a freshly asked trial. Every other rank wraps ``None``; those
    ranks presumably obtain the trial state from rank 0 inside
    ``TorchDistributedTrial`` (TODO confirm). All ranks then report one
    intermediate value and query ``should_prune``.
    """
    with StorageSupplier(storage_mode) as storage:
        is_leader = dist.get_rank() == 0  # type: ignore
        # Only the leader owns a real optuna trial; the others pass None.
        local_trial = (
            optuna.create_study(
                storage=storage, pruner=DeterministicPruner(is_pruning)
            ).ask()
            if is_leader
            else None
        )
        trial = TorchDistributedTrial(local_trial)

        trial.report(1, 0)
        assert trial.should_prune() == is_pruning
def test_report(storage_mode: str) -> None:
    """Check that a value reported through ``TorchDistributedTrial`` is stored.

    Every rank calls ``trial.report(1, 0)``. Afterwards rank 0, which owns the
    study, verifies that the first trial's intermediate value at step 0 is 1.
    """
    with StorageSupplier(storage_mode) as storage:
        # Only rank 0 creates a study; other ranks keep this as None.
        study: Optional[optuna.study.Study] = None
        if dist.get_rank() == 0:
            study = optuna.create_study(storage=storage)
            trial = TorchDistributedTrial(study.ask())
        else:
            trial = TorchDistributedTrial(None)
        trial.report(1, 0)

        if dist.get_rank() == 0:
            assert study is not None
            # This comparison must be asserted: a bare `==` expression is
            # evaluated and discarded, so a wrong stored value would go unnoticed.
            assert study.trials[0].intermediate_values[0] == 1
def test_report_nan(storage_mode: str) -> None:
    """Check that reporting a non-numeric value raises and stores nothing.

    Despite the name, the value reported is the string ``"abc"``. Every rank
    expects ``trial.report`` to raise ``TypeError``. Rank 0 then confirms that
    no intermediate value was recorded on the first trial.
    """
    with StorageSupplier(storage_mode) as storage:
        study: Optional[optuna.study.Study] = None
        if dist.get_rank() != 0:
            # Non-leader ranks have no local study and wrap None.
            trial = TorchDistributedTrial(None)
        else:
            study = optuna.create_study(storage=storage)
            trial = TorchDistributedTrial(study.ask())

        with pytest.raises(TypeError):
            trial.report("abc", 0)  # type: ignore

        if dist.get_rank() == 0:
            assert study is not None
            # The failed report must leave the intermediate values empty.
            assert not study.trials[0].intermediate_values