def test_multi_discrete_no_values():
    """Constructing a MultiDiscreteHyperParameter with bad arguments raises ValueError.

    Covered cases: a ``None`` name, ``None`` values, and an empty values list.
    """
    bad_arguments = (
        (None, [0, 1]),  # name is None
        ('h1', None),    # values is None
        ('h2', []),      # values is empty
    )
    for name, choices in bad_arguments:
        with pytest.raises(ValueError):
            hp.MultiDiscreteHyperParameter(name, choices)
def test_multi_discrete():
    """Check the basic attributes of a freshly built MultiDiscreteHyperParameter."""
    h1 = hp.MultiDiscreteHyperParameter('h1', [0, 1], sample_count=5)
    assert h1.name == 'h1'
    assert h1.num_choices == 2
    # The value 0 must be registered in both lookup directions.
    assert 0 in h1.id2param.values()
    assert h1.param2id[0] == 0
    assert repr(h1)
    # Pin the configured value; `> 0` would accept any positive count.
    assert h1.sample_count == 5
def test_multi_discrete_sample():
    """Sampling returns `sample_count` draws, all taken from the configured values."""
    values = [1, 2, 3, 4, 5]
    h1 = hp.MultiDiscreteHyperParameter('h1', values, sample_count=10)
    sample = h1.sample()
    assert len(sample) == 10
    # Check every element of the draw, not just the first one.
    assert all(s in values for s in sample)

    samples = np.array([h1.sample() for _ in range(10)])
    # np.unique only reports values that actually occur, so the original
    # `counts > 0` check could never fail. Check the set of drawn values instead.
    drawn = set(np.unique(samples).tolist())
    assert drawn <= set(values)
    # Coverage: 100 draws over 5 choices miss a value with probability ~1e-9,
    # assuming roughly uniform sampling (TODO confirm against the sampler).
    assert drawn == set(values)
def test_multi_discrete_encode_decode():
    """encode() maps a seeded sample to known indices; decode() maps them back."""
    values = [10, 11, 12, 13, 14]
    h1 = hp.MultiDiscreteHyperParameter('h1', values, sample_count=5, seed=0)
    sample = h1.sample()
    encoded = h1.encode(sample)
    assert encoded == [4, 0, 3, 3, 3]
    decoded = h1.decode(encoded)
    # Without this, a short or empty `decoded` would make the loop check nothing.
    assert len(decoded) == len(encoded)
    for index, value in zip(encoded, decoded):
        assert value == values[index]
    # decode must invert encode for the sampled values.
    assert list(decoded) == list(sample)

    # Test for None input
    values = [None, 1, 2, 3]
    # NOTE(review): this second parameter is also named 'h1'. It is kept as-is
    # because the expected encoding below is pinned to this exact construction.
    h2 = hp.MultiDiscreteHyperParameter('h1', values, sample_count=10, seed=0)
    sample = h2.sample()
    encoded = h2.encode(sample)
    assert encoded == [0, 3, 1, 0, 3, 3, 3, 3, 1, 3]
    decoded = h2.decode(encoded)
    assert len(decoded) == len(encoded)
    for index, value in zip(encoded, decoded):
        assert value == values[index]
    assert list(decoded) == list(sample)
def test_multi_discrete_serialization_deserialization():
    """Round-trip a MultiDiscreteHyperParameter through get_config/load_from_config."""

    def check_config(cfg):
        # Shape checks shared by the original and the reloaded config.
        assert 'name' in cfg
        assert 'values' in cfg
        assert 'sample_count' in cfg
        assert len(cfg['values']) == 3
        assert cfg['sample_count'] == 5

    h1 = hp.MultiDiscreteHyperParameter('h1', [0, 1, None], sample_count=5)
    config = h1.get_config()
    check_config(config)
    # Capture the identity fields before loading, in case load_from_config
    # consumes or mutates the dict it is given.
    expected_name = config['name']
    expected_values = list(config['values'])

    h2 = hp.MultiDiscreteHyperParameter.load_from_config(config)
    restored = h2.get_config()
    check_config(restored)
    # The original test only checked keys and lengths; also check the content survives.
    assert restored['name'] == expected_name == 'h1'
    assert list(restored['values']) == expected_values
def get_multi_parameter_list():
    """Build a fixed list of multi-sample hyperparameters for use in tests.

    Returns:
        A list ``[h1, h2, h3, h4]``:
        - ``h1``: discrete over ``[0, 1, 2]``, ``sample_count=2``.
        - ``h2``: discrete over ``[3, 4, 5, 6]``, ``sample_count=3``.
        - ``h3``: uniform continuous with bounds ``7`` and ``10``, ``sample_count=5``.
        - ``h4``: discrete over ``['v1', 'v2']``, ``sample_count=4``.
    """
    return [
        hp.MultiDiscreteHyperParameter('h1', [0, 1, 2], sample_count=2),
        hp.MultiDiscreteHyperParameter('h2', [3, 4, 5, 6], sample_count=3),
        hp.MultiUniformContinuousHyperParameter('h3', 7, 10, sample_count=5),
        hp.MultiDiscreteHyperParameter('h4', ['v1', 'v2'], sample_count=4),
    ]