コード例 #1
0
 def state_dict(self):
     """Snapshot the RNG states of the optimizer and of its base estimator.

     Returns:
         dict: ``'optimizer_rng_state'`` and ``'estimator_rng_state'`` entries,
         each holding a ``get_state()`` result, usable to reset the algorithm.
     """
     optimizer = self.optimizer
     optimizer_rng_state = optimizer.rng.get_state()
     # NOTE(review): if ``random_state`` is an int seed, check_random_state
     # presumably builds a fresh generator, so this captures the seed's initial
     # state rather than a live generator's current state — confirm intended.
     estimator_rng = check_random_state(optimizer.base_estimator_.random_state)
     return {
         'optimizer_rng_state': optimizer_rng_state,
         'estimator_rng_state': estimator_rng.get_state(),
     }
コード例 #2
0
ファイル: test_space.py プロジェクト: breuleux/orion
    def test_sample(self):
        """Check that seeded ``Space.sample`` matches per-dimension sampling.

        The expected values are produced by sampling each registered dimension,
        in registration order, from one RNG built from the same seed. This is
        checked for a single point and for a batch of two points.
        """
        seed = 5
        names = ("yolo", "yolo2", "yolo3")
        categories = ("asdfa", 2, 3, 4)
        probs = (0.1, 0.2, 0.3, 0.4)
        dims = (
            Categorical(names[0], OrderedDict(zip(categories, probs)), shape=(2, 2)),
            Integer(names[1], "uniform", -3, 6),
            Real(names[2], "norm", 0.9),
        )
        space = Space()
        for dim in dims:
            space.register(dim)

        # Single point: one draw per dimension, all from a shared RNG, in order.
        point = space.sample(seed=seed)
        rng = check_random_state(seed)
        expected = {name: dim.sample(seed=rng)[0] for name, dim in zip(names, dims)}
        assert len(point) == 1
        params = point[0].params
        assert len(params) == len(expected) == 3
        # "yolo" has shape (2, 2), so compare element-wise.
        assert np.all(params["yolo"] == expected["yolo"])
        assert params["yolo2"] == expected["yolo2"]
        assert params["yolo3"] == expected["yolo3"]

        # Two points: each dimension is sampled twice, dimension by dimension,
        # from a freshly seeded shared RNG.
        points = space.sample(2, seed=seed)
        rng = check_random_state(seed)
        draws = [dim.sample(2, seed=rng) for dim in dims]
        assert len(points) == 2
        for idx, trial in enumerate(points):
            expected = dict(zip(names, (draw[idx] for draw in draws)))
            params = trial.params
            assert len(params) == len(expected) == 3
            assert np.all(params["yolo"] == expected["yolo"])
            assert params["yolo2"] == expected["yolo2"]
            assert params["yolo3"] == expected["yolo3"]
コード例 #3
0
    def state_dict(self):
        """Return a state dict that can be used to reset the state of the algorithm.

        Starts from the parent class's state dict. When ``self.optimizer`` has
        been created, adds the optimizer and base-estimator RNG states, the
        optimizer's ``Xi``/``yi`` and ``models``, and a few private optimizer
        attributes (``_n_initial_points``, ``gains_``, ``_next_x``).

        Returns:
            dict: the parent's state, extended with optimizer state when
            ``self.optimizer`` is not None.
        """
        state_dict = super(BayesianOptimizer, self).state_dict
        # The parent may expose ``state_dict`` as a property (already a dict) or
        # as a plain method (a bound method). A bound method has no ``update``,
        # so call it before extending. A dict is not callable, so the property
        # case is left untouched.
        if callable(state_dict):
            state_dict = state_dict()

        if self.optimizer is None:
            return state_dict

        # NOTE(review): Xi, yi and models are stored by reference, not copied,
        # so mutating the returned dict's lists would mutate the optimizer —
        # confirm whether a copy is intended.
        state_dict.update(
            {
                "optimizer_rng_state": self.optimizer.rng.get_state(),
                "estimator_rng_state": check_random_state(
                    self.optimizer.base_estimator_.random_state
                ).get_state(),
                "Xi": self.optimizer.Xi,
                "yi": self.optimizer.yi,
                # pylint: disable = protected-access
                "_n_initial_points": self.optimizer._n_initial_points,
                # gains_ / _next_x may be absent on the optimizer; store None then.
                "gains_": getattr(self.optimizer, "gains_", None),
                "models": self.optimizer.models,
                "_next_x": getattr(self.optimizer, "_next_x", None),
            }
        )

        return state_dict
コード例 #4
0
ファイル: test_space.py プロジェクト: 5l1v3r1/orion-1
    def test_rng_invalid_value(self):
        """Test that passing an unusable seed (a plain string) raises ValueError."""
        with pytest.raises(ValueError) as exc:
            check_random_state("oh_no_oh_no")

        # The error message must name the offending value.
        assert "'oh_no_oh_no' cannot be used to seed" in str(exc.value)
コード例 #5
0
ファイル: test_space.py プロジェクト: 5l1v3r1/orion-1
 def test_rng_tuple(self):
     """Test that passing a tuple seed returns a new, non-global RandomState."""
     rng = check_random_state((1, 12, 123))
     assert isinstance(rng, np.random.RandomState)
     # A seeded generator must not be numpy's global singleton.
     assert rng is not np.random.mtrand._rand
コード例 #6
0
ファイル: test_space.py プロジェクト: 5l1v3r1/orion-1
 def test_rng_random_state(self):
     """An existing RandomState instance is handed back unchanged (same object)."""
     seeded = np.random.RandomState(1)
     returned = check_random_state(seeded)
     assert returned is seeded
コード例 #7
0
ファイル: test_space.py プロジェクト: 5l1v3r1/orion-1
 def test_rng_null(self):
     """``None`` must map to numpy's global RandomState singleton."""
     result = check_random_state(None)
     assert result is np.random.mtrand._rand