def test_plot(client):
    """Run a 10-trial HPO end to end and render its validation curves with ``plot``.

    The figure is written inside a temporary directory, so the test no longer
    leaves a ``test.png`` behind in the current working directory.

    Args:
        client: pytest fixture; the client handed to ``register_hpo`` and
            ``fetch_hpo_valid_curves``.
    """
    import os
    import tempfile

    config = copy.deepcopy(CONFIG)
    num_trials = 10
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    # Presumably disables retries so a failing trial surfaces immediately.
    worker.max_retry = 0
    worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    # Keep the rendered artifact out of the working directory.
    with tempfile.TemporaryDirectory() as output_dir:
        plot(config['space'], 'obj', data, os.path.join(output_dir, 'test.png'))
# Example #2
# 0
def test_save_load_results(client):
    """Round-trip fetched HPO curves through ``save_results`` / ``load_results``.

    The results are saved into a fresh temporary directory. ``load_results``
    can therefore only succeed on what this run just saved. With ``'.'``, a
    file left over from a previous run could satisfy the assertion even if
    saving failed, and the working directory was polluted.

    Args:
        client: pytest fixture; the client handed to ``register_hpo`` and
            ``fetch_hpo_valid_curves``.
    """
    import tempfile

    config = copy.deepcopy(CONFIG)
    num_trials = 2
    config['count'] = num_trials
    config['fidelity'] = Fidelity(1, 1, name='epoch').to_dict()

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    # Presumably disables retries so a failing trial surfaces immediately.
    worker.max_retry = 0
    worker.run()

    data = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    with tempfile.TemporaryDirectory() as results_dir:
        save_results(NAMESPACE, data, results_dir)
        assert load_results(NAMESPACE, results_dir)
# Example #3
# 0
def test_fetch_hpo_valid_results_no_epochs(client):
    """Fetch the curves of a 5-trial HPO registered with ``Fidelity(1, 1)``.

    Verifies the ``namespace`` attribute, every coordinate's values, and that
    ``obj`` and ``valid`` both have shape ``(2, num_trials, 1)`` (presumably
    epoch x order x seed, matching the coordinate lengths checked above).

    Args:
        client: pytest fixture; the client handed to ``register_hpo`` and
            ``fetch_hpo_valid_curves``.
    """
    num_trials = 5
    config = copy.deepcopy(CONFIG)
    config.update(
        count=num_trials,
        fidelity=Fidelity(1, 1, name='epoch').to_dict(),
    )

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    curves = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])

    assert curves.attrs['namespace'] == NAMESPACE

    # Checked in this order: epoch, order, seed, params, noise.
    expected_coords = {
        'epoch': [0, 1],
        'order': list(range(num_trials)),
        'seed': [1],
        'params': list('abcd'),
        'noise': ['e'],
    }
    for coord_name, expected_values in expected_coords.items():
        assert curves[coord_name].values.tolist() == expected_values

    for variable_name in ('obj', 'valid'):
        assert curves[variable_name].shape == (2, num_trials, 1)
def test_convert_xarray_to_scipy_results(client):
    """Check ``xarray_to_scipy_results`` against the best trial of a 10-trial HPO.

    The best trial is the argmin of ``obj`` at epoch index 1 for the first
    seed. The scipy-style result must expose that trial's parameters as ``x``,
    its objective as ``fun``, and one ``x_iters`` entry per trial.

    Args:
        client: pytest fixture; the client handed to ``register_hpo`` and
            ``fetch_hpo_valid_curves``.
    """
    num_trials = 10
    config = copy.deepcopy(CONFIG)
    config.update(
        count=num_trials,
        fidelity=Fidelity(1, 1, name='epoch').to_dict(),
    )

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    curves = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])
    results = xarray_to_scipy_results(config['space'], 'obj', curves)

    final_objectives = curves.obj.values[1, :, 0]
    best = numpy.argmin(final_objectives)

    for position, param_name in enumerate('abc'):
        assert results.x[position] == curves[param_name].values[best, 0]
    # 'd' is compared in log space -- presumably because it is sampled on a
    # log scale in CONFIG['space']; confirm against the space definition.
    assert results.x[3] == numpy.log(curves['d'].values[best, 0])
    assert results.fun == final_objectives[best]
    assert len(results.x_iters) == num_trials
# Example #5
# 0
def test_fetch_hpo_valid_results(client):
    """Fetch the curves of a 5-trial HPO using CONFIG's default fidelity.

    Verifies the ``namespace`` attribute, the coordinates, and the shape of
    ``obj``. It also checks that ``obj`` at epoch 10 exceeds ``obj`` at
    epoch 0 by exactly 10 for every trial.

    Args:
        client: pytest fixture; the client handed to ``register_hpo`` and
            ``fetch_hpo_valid_curves``.
    """
    num_trials = 5
    config = copy.deepcopy(CONFIG)
    config['count'] = num_trials

    register_hpo(client, NAMESPACE, foo, config, {'e': 2})
    trial_worker = TrialWorker(URI, DATABASE, 0, NAMESPACE)
    trial_worker.max_retry = 0
    trial_worker.run()

    curves = fetch_hpo_valid_curves(client, NAMESPACE, ['e'])
    n_epochs = config['fidelity']['max'] + 1

    assert curves.attrs['namespace'] == NAMESPACE
    assert curves['epoch'].values.tolist() == list(range(n_epochs))
    assert curves['order'].values.tolist() == list(range(num_trials))
    assert curves['seed'].values.tolist() == [1]
    assert curves['params'].values.tolist() == list('abcd')
    assert curves['noise'].values.tolist() == ['e']
    assert curves['obj'].shape == (n_epochs, num_trials, 1)

    # NOTE(review): epoch=10 and the expected gap of 10 are hard-coded; they
    # presumably assume CONFIG's fidelity max is 10 -- confirm.
    epoch_gap = curves.obj.loc[dict(epoch=10)] - curves.obj.loc[dict(epoch=0)]
    assert numpy.all(epoch_gap == numpy.full((num_trials, 1), 10.0))
# Example #6
# 0
def fetch_hpos_valid_curves(client, namespaces, variables, data, partial=False):
    """Fetch validation curves for the HPO namespaces that are ready.

    Default mode:
        At most one completed namespace is fetched per call. Every namespace
        that is not fetched is reported in ``remainings``.

    Partial mode (``partial=True``):
        Every namespace is fetched regardless of completion. Namespaces that
        yield no data are only reported with a message; they appear in
        neither returned mapping.

    Args:
        client: Handle forwarded to ``is_hpo_completed`` and
            ``fetch_hpo_valid_curves``.
        namespaces: Mapping from HPO name to an iterable of HPO namespaces.
        variables: Forwarded unchanged to ``fetch_hpo_valid_curves``.
        data: Nested mapping filled in place as
            ``data[hpo][hpo_namespace] = fetched_data``. ``data[hpo]`` is
            created as a dict when missing.
        partial: Fetch every namespace, including ones not yet completed.

    Returns:
        tuple: ``(hpos_ready, remainings)``, two ``defaultdict(list)``
        mapping each HPO name to the namespaces that were fetched and to
        the ones that were not.

    Raises:
        RuntimeError: If a completed namespace yields no data while
            ``partial`` is False.
    """
    hpos_ready = defaultdict(list)
    remainings = defaultdict(list)
    fetched_one = False
    for hpo, hpo_namespaces in namespaces.items():
        for hpo_namespace in hpo_namespaces:
            # Test the cheap flags first. The completion check takes the
            # client (presumably a backend query), so it is skipped entirely
            # in partial mode and once a namespace was fetched in this call.
            should_fetch = partial or (
                not fetched_one and is_hpo_completed(client, hpo_namespace))
            if not should_fetch:
                remainings[hpo].append(hpo_namespace)
                continue

            print(f'Fetching results of {hpo_namespace}')
            hpo_data = fetch_hpo_valid_curves(
                client, hpo_namespace, variables, partial=partial)
            fetched_one = True

            # Truthiness test: None (and presumably an empty result) means
            # no metrics are available yet.
            if hpo_data:
                try:
                    hpo_results = data[hpo]
                except KeyError:
                    # Plain-dict callers may not pre-create the per-HPO level.
                    hpo_results = data[hpo] = {}
                hpo_results[hpo_namespace] = hpo_data
                hpos_ready[hpo].append(hpo_namespace)
            elif partial:
                print(f'No metrics available for {hpo_namespace}')
            else:
                raise RuntimeError(
                    f'{hpo_namespace} is completed but no metrics are available!?')

    return hpos_ready, remainings