예제 #1
0
def test_get_hpo_completed(client):
    """Run one registered trial, then check ``get_hpo``'s result.

    The returned HPO must mirror the registered configuration, and the
    returned remote call must forward the registered default ``e`` and
    still be executable.
    """
    register_hpo(client, NAMESPACE, foo, CONFIG, {'e': 2})

    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    hpo, remote_call = get_hpo(client, NAMESPACE)
    assert len(hpo.trials) == 1

    state = hpo.state_dict(compressed=False)
    assert state['seed'] == CONFIG['seed']
    assert state['fidelity'] == CONFIG['fidelity']
    # The stored space carries an extra 'uid' entry that CONFIG does not have.
    state['space'].pop('uid')
    assert state['space'] == CONFIG['space']

    # The default registered for 'e' must have been forwarded to the call.
    call_kwargs = remote_call['kwargs']
    assert call_kwargs['e'] == 2

    # Supply the remaining arguments and check the call evaluates foo.
    params = dict(a=1, b=1, c=1, d=1)
    call_kwargs.update(params, uid=0, client=client)
    expected = (params['a'] + 2 * params['b'] - params['c']**2
                + params['d'] + 2)
    assert exec_remote_call(remote_call) == expected
예제 #2
0
def test_get_hpo_non_completed(client):
    """``get_hpo`` must raise for an HPO that is registered but never run."""
    import re

    register_hpo(client, NAMESPACE, foo, CONFIG, DEFAULTS)

    # ``match`` is applied with ``re.search``. Escape the message so any
    # regex metacharacters in NAMESPACE are matched literally.
    expected_message = re.escape(
        f'No HPO for namespace {NAMESPACE} or HPO is not completed')
    with pytest.raises(RuntimeError, match=expected_message):
        get_hpo(client, NAMESPACE)
예제 #3
0
def test_is_hpo_completed(client):
    """``is_hpo_completed`` becomes True only after a worker runs the HPO."""
    # Nothing has been registered under the namespace yet.
    assert not is_hpo_completed(client, NAMESPACE)

    register_hpo(client, NAMESPACE, foo, CONFIG, DEFAULTS)
    # Registered, but no worker has processed it so far.
    assert not is_hpo_completed(client, NAMESPACE)

    runner = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    runner.max_retry = 0
    runner.run()

    assert is_hpo_completed(client, NAMESPACE)
예제 #4
0
def test_plot(client):
    """Smoke-test ``plot`` on the validation curves of a finished HPO.

    The figure is written inside a temporary directory, so running the test
    does not leave ``test.png`` behind in the current working directory.
    """
    import os
    import tempfile

    config = copy.deepcopy(CONFIG)
    num_trials = 10
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    worker.max_retry = 0
    worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    # NOTE(review): assumes ``plot`` accepts a full output path, not just a
    # bare file name. Confirm against its implementation.
    with tempfile.TemporaryDirectory() as tmp_dir:
        plot(config['space'], 'obj', data, os.path.join(tmp_dir, 'test.png'))
예제 #5
0
def test_save_load_results(client):
    """Results saved with ``save_results`` can be read back by ``load_results``.

    Both calls use a temporary directory instead of the current working
    directory, so the test leaves no files behind.
    """
    import tempfile

    config = copy.deepcopy(CONFIG)
    num_trials = 2
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    worker.max_retry = 0
    worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_results(NAMESPACE, data, tmp_dir)
        assert load_results(NAMESPACE, tmp_dir)
예제 #6
0
def test_fetch_hpo_valid_results_no_epochs(client):
    """With a ``Fidelity(1, 1)`` budget the curves hold only epochs 0 and 1."""
    num_trials = 5
    config = copy.deepcopy(CONFIG)
    config.update(count=num_trials,
                  fidelity=Fidelity(1, 1, name='epoch').to_dict())

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    assert data.attrs['namespace'] == NAMESPACE

    # Expected values of every coordinate, checked in this order.
    expected_coords = {
        'epoch': [0, 1],
        'order': list(range(num_trials)),
        'seed': [1],
        'params': list('abcd'),
        'noise': ['e'],
    }
    for coord_name, expected in expected_coords.items():
        assert getattr(data, coord_name).values.tolist() == expected

    # Both curves are laid out as (epoch, trial, seed).
    expected_shape = (2, num_trials, 1)
    assert data.obj.shape == expected_shape
    assert data.valid.shape == expected_shape
예제 #7
0
def test_convert_xarray_to_scipy_results(client):
    """``xarray_to_scipy_results`` reports the best trial of the final epoch."""
    num_trials = 10
    config = copy.deepcopy(CONFIG)
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])
    scipy_results = xarray_to_scipy_results(config['space'], 'obj', data)

    # Fidelity(1, 1) yields epochs [0, 1], so index 1 is the final epoch.
    final_obj = data.obj.values[1, :, 0]
    best = numpy.argmin(final_obj)

    expected_x = [
        data.a.values[best, 0],
        data.b.values[best, 0],
        data.c.values[best, 0],
        # 'd' is compared in log space.
        numpy.log(data.d.values[best, 0]),
    ]
    for position, expected in enumerate(expected_x):
        assert scipy_results.x[position] == expected

    assert scipy_results.fun == final_obj[best]
    assert len(scipy_results.x_iters) == num_trials
예제 #8
0
def test_fetch_hpo_valid_results(client):
    """Fetched validation curves span every epoch of the configured fidelity.

    The final check is driven by ``config['fidelity']['max']``, like the rest
    of the test, instead of a hard-coded epoch 10.
    """
    config = copy.deepcopy(CONFIG)
    num_trials = 5
    config['count'] = num_trials
    max_epoch = config['fidelity']['max']

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    worker.max_retry = 0
    worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    assert data.attrs['namespace'] == NAMESPACE
    assert data.epoch.values.tolist() == list(range(max_epoch + 1))
    assert data.order.values.tolist() == list(range(num_trials))
    assert data.seed.values.tolist() == [1]
    assert data.params.values.tolist() == list('abcd')
    assert data.noise.values.tolist() == ['e']
    assert data.obj.shape == (max_epoch + 1, num_trials, 1)

    # NOTE(review): assumes the objective grows by one per epoch, as the
    # ``+ d`` term of foo suggests when the fidelity drives ``d``. Under that
    # assumption the gain from epoch 0 to the last epoch equals max_epoch.
    gain = (data.obj.loc[dict(epoch=max_epoch)] -
            data.obj.loc[dict(epoch=0)])
    assert numpy.all(gain == numpy.full((num_trials, 1), max_epoch))
예제 #9
0
def test_register_hpo_is_actionable(client):
    """Test that the registered HPO has valid work items and can be executed."""
    namespace = 'test-hpo'
    config = {
        'name': 'random_search',
        'seed': 1,
        'count': 1,
        'fidelity': Fidelity(1, 10, name='d').to_dict(),
        'space': {
            'a': 'uniform(-1, 1)',
            'b': 'uniform(-1, 1)',
            'c': 'uniform(-1, 1)',
            'd': 'uniform(-1, 1)'
        }
    }

    defaults = {}
    register_hpo(client, namespace, foo, config, defaults)
    worker = TrialWorker(URI, DATABASE, 0, namespace)
    worker.max_retry = 0
    worker.run()

    assert client.monitor().read_count(WORK_QUEUE, namespace,
                                       mtype=WORK_ITEM) == 1
    assert client.monitor().read_count(WORK_QUEUE, namespace,
                                       mtype=HPO_ITEM) == 2

    messages = client.monitor().unread_messages(RESULT_QUEUE,
                                                namespace,
                                                mtype=HPO_ITEM)
    # Fail with a clear assertion rather than an IndexError on messages[0].
    assert messages, 'expected an unread HPO_ITEM message in RESULT_QUEUE'

    compressed_state = messages[0].message.get('hpo_state')
    assert compressed_state is not None
    state = decompress_dict(compressed_state)

    assert len(state['trials']) == 1
    # Exact value produced by the seeded random search (seed=1).
    assert state['trials'][0][1]['objectives'] == [10.715799430116764]
예제 #10
0
def test_is_registered(client):
    """``is_registered`` reports a namespace only after ``register_hpo``."""
    was_registered = is_registered(client, NAMESPACE)
    assert not was_registered

    register_hpo(client, NAMESPACE, foo, CONFIG, DEFAULTS)

    assert is_registered(client, NAMESPACE)