def test_get_hpo_completed(client):
    """A completed HPO can be fetched back with its config and a usable remote call."""
    register_hpo(client, NAMESPACE, foo, CONFIG, {'e': 2})

    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    hpo, remote_call = get_hpo(client, NAMESPACE)
    assert len(hpo.trials) == 1

    # The uncompressed state must round-trip the registered configuration;
    # 'uid' is dropped from the space before comparing with CONFIG['space'].
    state = hpo.state_dict(compressed=False)
    assert state['seed'] == CONFIG['seed']
    assert state['fidelity'] == CONFIG['fidelity']
    state['space'].pop('uid')
    assert state['space'] == CONFIG['space']

    call_kwargs = remote_call['kwargs']
    # The default registered alongside the HPO must be forwarded to the call.
    assert call_kwargs['e'] == 2

    a, b, c, d, e = 1, 1, 1, 1, 2
    call_kwargs.update(dict(a=a, b=b, c=c, d=d, uid=0, client=client))

    # Executing the remote call must yield the expected objective value.
    assert exec_remote_call(remote_call) == a + 2 * b - c**2 + d + e
def test_get_hpo_non_completed(client):
    """get_hpo raises RuntimeError for an HPO that is registered but not run."""
    register_hpo(client, NAMESPACE, foo, CONFIG, DEFAULTS)

    # `match` applies re.search to the exception text, like ExceptionInfo.match.
    expected_message = f'No HPO for namespace {NAMESPACE} or HPO is not completed'
    with pytest.raises(RuntimeError, match=expected_message):
        get_hpo(client, NAMESPACE)
def test_is_hpo_completed(client):
    """is_hpo_completed is False before registration and before running, True after."""
    # Nothing registered yet.
    assert not is_hpo_completed(client, NAMESPACE)

    # Registered but no worker has processed it.
    register_hpo(client, NAMESPACE, foo, CONFIG, DEFAULTS)
    assert not is_hpo_completed(client, NAMESPACE)

    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    assert is_hpo_completed(client, NAMESPACE)
def test_plot(client):
    """Plotting the validation curves of a finished HPO runs end to end.

    The figure is written into a temporary directory, which is removed
    afterwards. Previously the test wrote ``test.png`` into the current
    working directory and left it behind.
    """
    import os
    import tempfile

    num_trials = 10
    config = copy.deepcopy(CONFIG)
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    worker.max_retry = 0
    worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    # Keep the output file out of the CWD; the directory is cleaned up on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        plot(config['space'], 'obj', data, os.path.join(tmp_dir, 'test.png'))
def test_save_load_results(client):
    """Results fetched from a finished HPO can be saved and then loaded back.

    Saving and loading go through a temporary directory. Previously the
    test used ``'.'``, which left result files in the current working
    directory.
    """
    import tempfile

    num_trials = 2
    config = copy.deepcopy(CONFIG)
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    worker.max_retry = 0
    worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    # Both calls must see the same directory, so load inside the context.
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_results(NAMESPACE, data, tmp_dir)
        assert load_results(NAMESPACE, tmp_dir)
def test_fetch_hpo_valid_results_no_epochs(client):
    """Curves fetched with a Fidelity(1, 1) config have the expected coords and shapes."""
    num_trials = 5
    config = copy.deepcopy(CONFIG)
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()
    config['count'] = num_trials

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    assert data.attrs['namespace'] == NAMESPACE

    # Coordinates.
    assert data.epoch.values.tolist() == [0, 1]
    assert data.order.values.tolist() == list(range(num_trials))
    assert data.seed.values.tolist() == [1]
    assert data.params.values.tolist() == list('abcd')
    assert data.noise.values.tolist() == ['e']

    # Both variables are laid out as (epoch, order, seed).
    expected_shape = (2, num_trials, 1)
    assert data.obj.shape == expected_shape
    assert data.valid.shape == expected_shape
def test_convert_xarray_to_scipy_results(client):
    """xarray_to_scipy_results reports the best point of the last epoch."""
    num_trials = 10
    config = copy.deepcopy(CONFIG)
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])
    scipy_results = xarray_to_scipy_results(config['space'], 'obj', data)

    # Objective at epoch index 1 for every trial (seed index 0).
    last_epoch_obj = data.obj.values[1, :, 0]
    best = numpy.argmin(last_epoch_obj)

    # Parameters a, b and c are compared as stored.
    for position, name in enumerate('abc'):
        assert scipy_results.x[position] == data[name].values[best, 0]
    # Parameter d is compared in log space.
    assert scipy_results.x[3] == numpy.log(data.d.values[best, 0])

    assert scipy_results.fun == last_epoch_obj[best]
    assert len(scipy_results.x_iters) == num_trials
def test_fetch_hpo_valid_results(client):
    """Curves fetched with the default fidelity span every epoch for every trial."""
    num_trials = 5
    config = copy.deepcopy(CONFIG)
    config['count'] = num_trials

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    # Epochs run from 0 through the fidelity's max, inclusive.
    n_epochs = config['fidelity']['max'] + 1

    assert data.attrs['namespace'] == NAMESPACE
    assert data.epoch.values.tolist() == list(range(n_epochs))
    assert data.order.values.tolist() == list(range(num_trials))
    assert data.seed.values.tolist() == [1]
    assert data.params.values.tolist() == list('abcd')
    assert data.noise.values.tolist() == ['e']
    assert data.obj.shape == (n_epochs, num_trials, 1)

    # Every trial's objective must differ by exactly 10 between epochs 0 and 10.
    at_epoch_10 = data.obj.loc[dict(epoch=10)]
    at_epoch_0 = data.obj.loc[dict(epoch=0)]
    expected_gap = numpy.ones((num_trials, 1)) * 10
    assert numpy.all((at_epoch_10 - at_epoch_0) == expected_gap)
def test_register_hpo_is_actionable(client):
    """A registered HPO yields work items that a worker can execute to completion."""
    namespace = 'test-hpo'
    space = {name: 'uniform(-1, 1)' for name in 'abcd'}
    config = {
        'name': 'random_search',
        'seed': 1,
        'count': 1,
        'fidelity': Fidelity(1, 10, name='d').to_dict(),
        'space': space,
    }

    register_hpo(client, namespace, foo, config, {})

    trial_worker = TrialWorker(URI, DATABASE, 0, namespace)
    trial_worker.max_retry = 0
    trial_worker.run()

    # After the run, the work queue has read one trial item and two HPO items.
    assert client.monitor().read_count(WORK_QUEUE, namespace, mtype=WORK_ITEM) == 1
    assert client.monitor().read_count(WORK_QUEUE, namespace, mtype=HPO_ITEM) == 2

    # The HPO result message carries the compressed HPO state.
    result_messages = client.monitor().unread_messages(
        RESULT_QUEUE, namespace, mtype=HPO_ITEM)
    compressed_state = result_messages[0].message.get('hpo_state')
    assert compressed_state is not None

    hpo_state = decompress_dict(compressed_state)
    trials = hpo_state['trials']
    assert len(trials) == 1
    assert trials[0][1]['objectives'] == [10.715799430116764]
def test_is_registered(client):
    """is_registered flips from False to True once the HPO is registered."""
    was_registered = is_registered(client, NAMESPACE)
    assert not was_registered

    register_hpo(client, NAMESPACE, foo, CONFIG, DEFAULTS)

    assert is_registered(client, NAMESPACE)