def test_system_attrs_with_exception() -> None:
    """Check that setting a ``torch.Tensor`` system attribute raises ``TypeError``.

    Rank 0 creates a study on a ``"sqlite"`` storage and wraps a freshly asked
    trial; every other rank builds the wrapper from ``None``.
    """
    with StorageSupplier("sqlite") as storage:
        # Only rank 0 owns a real trial; the remaining ranks pass None.
        asked_trial = None
        if dist.get_rank() == 0:
            study = optuna.create_study(storage=storage)
            asked_trial = study.ask()
        trial = TorchDistributedTrial(asked_trial)

        # The tensor is built inside the ``raises`` block on purpose, so the
        # whole statement is covered by the expectation.
        with pytest.raises(TypeError):
            trial.set_system_attr("not serializable", torch.Tensor([1, 2]))
def test_system_attrs(storage_mode: str) -> None:
    """Check that system attributes set on the trial can be read back.

    Rank 0 creates a study on the storage selected by ``storage_mode`` and
    wraps a freshly asked trial; every other rank builds the wrapper from
    ``None``.

    Args:
        storage_mode: Storage identifier forwarded to ``StorageSupplier``.
    """
    expected_attrs = {"dataset": "mnist", "batch_size": 128}

    with StorageSupplier(storage_mode) as storage:
        # Only rank 0 owns a real trial; the remaining ranks pass None.
        asked_trial = None
        if dist.get_rank() == 0:
            study = optuna.create_study(storage=storage)
            asked_trial = study.ask()
        trial = TorchDistributedTrial(asked_trial)

        # Write every attribute first, then verify each one, keeping the
        # same call order as setting both and asserting both.
        for key, value in expected_attrs.items():
            trial.set_system_attr(key, value)

        for key, value in expected_attrs.items():
            assert trial.system_attrs[key] == value