def test_bayesopt_batch(parameters, results, transforms):
    """Older variant of the batch-generation test that threads ``transforms``.

    Builds a GPyOpt domain and training data using ``transforms``, asks for a
    batch of suggestions, and checks the batch is 10 rows by 5 columns.

    NOTE(review): a later ``def test_bayesopt_batch`` in this module rebinds
    this name, so only that later definition remains at module level (and is
    what a test runner will find); this body never runs.  It also calls the
    helpers with a different signature than the later tests do
    (``_prepare_data_for_bayes_opt`` with ``transforms`` returning two values,
    ``_generate_bayesopt_batch`` with ``domain`` first).  Presumably one of
    the two versions is stale -- consider deleting or renaming this one.
    """
    rows = 10
    optimizer = GPyOpt(max_concurrent=rows)
    search_domain = optimizer._initialize_domain(parameters, transforms)
    features, targets = GPyOpt._prepare_data_for_bayes_opt(
        parameters, results, transforms)
    batch = optimizer._generate_bayesopt_batch(
        search_domain, features, targets, lower_is_better=True)
    assert batch.shape == (rows, 5)
def test_reverse_format(parameters, results, transforms):
    """Older variant of the round-trip test that threads ``transforms``.

    Encodes the results into a design matrix, decodes it back with
    ``_reverse_to_sherpa_format`` and compares each row to the expected
    parameter dict (including ``batch_size``).

    NOTE(review): a later ``def test_reverse_format`` in this module rebinds
    this name, so this body never runs.  It expects a ``batch_size`` key and
    the ``transforms`` argument, which the later version does not; presumably
    one of the two is stale -- consider deleting or renaming this one.
    """
    expected_rows = [
        {'dropout': 0.1, 'lr': 1e-3, 'activation': 'tanh',
         'num_hidden': 111, 'batch_size': 10},
        {'dropout': 0.4, 'lr': 1e-5, 'activation': 'relu',
         'num_hidden': 222, 'batch_size': 100},
        {'dropout': 0.33, 'lr': 1e-2, 'activation': 'sigmoid',
         'num_hidden': 288, 'batch_size': 1000},
    ]
    features, _targets = GPyOpt._prepare_data_for_bayes_opt(
        parameters, results, transforms)
    decoded = GPyOpt._reverse_to_sherpa_format(features, transforms, parameters)
    for row_index, expected in enumerate(expected_rows):
        assert decoded[row_index] == expected
def test_prepare_data_for_bayes_opt(parameters, results):
    """Check the design matrix ``X`` and targets ``y`` built from ``results``.

    ``_prepare_data_for_bayes_opt`` returns a 3-tuple; only the first two
    elements are checked here.  Each row of ``X`` holds one trial's encoded
    parameters: ``lr`` appears to be stored on a log10 scale (1e-3 -> -3) and
    the categorical ``activation`` as an integer index.
    """
    X, y, y_var = GPyOpt._prepare_data_for_bayes_opt(parameters, results)
    expected_X = numpy.array([[0.1, -3., 1, 111],
                              [0.4, -5., 0, 222],
                              [0.33, -2., 2, 288]])
    expected_y = numpy.array([[0.1], [0.055], [0.15]])
    # X contains values produced by a floating-point transform, so compare
    # with a tolerance rather than exact equality.  assert_allclose also
    # verifies the shapes match and reports which entries differ on failure.
    numpy.testing.assert_allclose(X, expected_X)
    numpy.testing.assert_allclose(y, expected_y)
def test_bayesopt_batch(parameters, results):
    """A suggested batch has ``max_concurrent`` rows of encoded parameters.

    Configures a GPyOpt instance, builds training data from ``results`` and
    checks that ``_generate_bayesopt_batch`` returns a (10, 4) array: one row
    per concurrent slot, four encoded parameter columns.
    """
    concurrency = 10
    optimizer = GPyOpt(max_concurrent=concurrency)
    # State set on the instance in addition to the explicit keyword
    # arguments passed to _generate_bayesopt_batch below.
    optimizer.domain = optimizer._initialize_domain(parameters)
    optimizer.lower_is_better = True

    features, targets, _y_var = GPyOpt._prepare_data_for_bayes_opt(
        parameters, results)
    # NOTE(review): the domain is built a second time here instead of reusing
    # optimizer.domain; confirm whether both the attribute and the argument
    # are actually needed.
    search_domain = optimizer._initialize_domain(parameters)
    batch = optimizer._generate_bayesopt_batch(
        features, targets, lower_is_better=True, domain=search_domain)

    assert batch.shape == (concurrency, 4)
def test_reverse_format(parameters, results):
    """Encoded rows decode back to the original Sherpa parameter dicts.

    Runs ``results`` through ``_prepare_data_for_bayes_opt`` and then
    ``_reverse_to_sherpa_format``, and compares the first three decoded rows
    to the expected dicts in order.
    """
    expected_rows = [
        {'dropout': 0.1, 'lr': 1e-3, 'activation': 'tanh',
         'num_hidden': 111},
        {'dropout': 0.4, 'lr': 1e-5, 'activation': 'relu',
         'num_hidden': 222},
        {'dropout': 0.33, 'lr': 1e-2, 'activation': 'sigmoid',
         'num_hidden': 288},
    ]
    features, _targets, _y_var = GPyOpt._prepare_data_for_bayes_opt(
        parameters, results)
    decoded = GPyOpt._reverse_to_sherpa_format(features, parameters)
    for row_index, expected in enumerate(expected_rows):
        assert decoded[row_index] == expected