Example #1
0
def test_split_by_type():
    """Check that `Domain.split_by_type` partitions a mixed domain by type.

    The domain under test has three entries: a list-valued `x`, an
    integer-set `y` and a mixed-type set `z`. The split is expected to
    return, in this order, the discrete part (`y`), the categorical part
    (`z`) and the continuous part (`x`), and summing the parts must give
    back the original domain.
    """
    full = Domain({"x": [1, 2], "y": {-3, 2, 5}, "z": {"small", 1, 0.1}})
    discrete_part, categorical_part, continuous_part = full.split_by_type()

    # Adding the parts together, starting from an empty domain, has to
    # reproduce the domain that was split.
    recombined = sum(full.split_by_type(), Domain({}))
    assert recombined == full

    expected_discrete = Domain({"y": {-3, 2, 5}})
    assert discrete_part == expected_discrete

    expected_categorical = Domain({"z": {"small", 1, 0.1}})
    assert categorical_part == expected_categorical

    expected_continuous = Domain({"x": [1, 2]})
    assert continuous_part == expected_continuous
Example #2
0
    def __init__(self,
                 domain: Domain,
                 sample_continuous: bool = False,
                 seed: int = None):
        """Set up a :class:`GridSearch` optimiser over `domain`.

        The domain is split by type. Its discrete and categorical parts are
        merged into one domain that is iterated over, while the continuous
        part is stored separately. A domain reporting `is_continuous` is
        rejected unless `sample_continuous` is enabled.

        Args:
            domain: :class:`Domain`. The domain to iterate over.
            sample_continuous: (optional) :obj:`bool`. Whether the continuous
                subspaces of the domain may be sampled. Defaults to `False`.
            seed: (optional) :obj:`int`. If given, the continuous subspace is
                rebuilt as a new :class:`Domain` created with this seed.

        Raises:
            DomainNotIterableError: if `domain.is_continuous` is true and
                `sample_continuous` is false.
        """
        if domain.is_continuous and not sample_continuous:
            not_iterable_msg = (
                "Cannot perform grid search on (partially) continuous domain. "
                "To enable grid search in this case, set the argument "
                "'sample_continuous' to True.")
            raise DomainNotIterableError(not_iterable_msg)
        super(GridSearch, self).__init__(domain)

        discrete, categorical, continuous = domain.split_by_type()
        # Discrete and categorical subspaces can both be iterated,
        # so they are handled as a single domain.
        self.discrete_domain = discrete + categorical
        if seed is None:
            self.continuous_domain = continuous
        else:
            # Recreate the continuous part so that it carries the given seed.
            self.continuous_domain = Domain(continuous.as_dict(), seed=seed)

        self._discrete_domain_iter = iter(self.discrete_domain)
        # With nothing to iterate over, the search is exhausted from the start.
        self._is_exhausted = len(self.discrete_domain) == 0
        # The error instance is created once here and stored on the optimiser.
        exhausted_msg = (
            "The domain has been exhausted. Reset the optimiser to start again."
        )
        self.__exhausted_err = ExhaustedSearchSpaceError(exhausted_msg)