def test_generate_with_references():
    """generate() with add_reference=True yields one 10-doc list per variable
    in 'abc' plus a 'reference' list, each doc carrying a uid computed over
    the doc plus transient bookkeeping keys."""
    defaults = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    configs = generate(range(10), 'abc', defaults=defaults, add_reference=True)

    expected_keys = list('abc') + ['reference']
    assert list(configs.keys()) == expected_keys
    assert all(len(configs[key]) == 10 for key in expected_keys)

    def expected_doc(variable, seed):
        # Rebuild the doc the way generate() is expected to: the uid is
        # computed with '_variable' (and '_repetition' for the reference)
        # present, then those transient keys are stripped.
        doc = dict(defaults)
        if variable == 'reference':
            doc['_repetition'] = seed
        else:
            doc[variable] = seed
        doc['_variable'] = variable
        doc['uid'] = compute_identity(doc, 16)
        if variable == 'reference':
            del doc['_repetition']
        del doc['_variable']
        return doc

    # Only the 'abc' lists are compared doc-by-doc here.
    for seed in range(10):
        for variable in 'abc':
            assert configs[variable][seed] == expected_doc(variable, seed)
# Example #2
def test_generate():
    """generate() with explicit counts produces num_experiments * num_repro
    docs per variable, with 1-indexed experiment/repetition values baked into
    each doc's uid."""
    defaults = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    num_experiments = 2
    num_repro = 2
    objective = 'obj'
    variables = list('abc')
    resumable = False
    configs = generate(num_experiments, num_repro, objective, variables,
                       defaults, resumable)

    assert list(configs.keys()) == variables
    assert all(
        len(configs[variable]) == num_experiments * num_repro
        for variable in 'abc')

    def expected_doc(variable, experiment, repetition):
        # uid is computed while 'variable' and 'repetition' are present,
        # then both transient keys are removed.
        doc = dict(defaults)
        doc[variable] = int(experiment)
        doc['variable'] = variable
        doc['repetition'] = repetition
        doc['uid'] = compute_identity(doc, 16)
        del doc['repetition']
        del doc['variable']
        return doc

    for variable in 'abc':
        for experiment in range(num_experiments):
            for repetition in range(num_repro):
                index = experiment * num_repro + repetition
                assert configs[variable][index] == expected_doc(
                    variable, experiment + 1, repetition + 1)

    # The flat index must have walked the whole list.
    assert index == (num_experiments * num_repro) - 1
def test_register_resume(client):
    """Registering a second, larger batch of configs only enqueues the
    configs not already registered (resume semantics)."""
    defaults = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    namespace = 'test'
    configs = generate(range(3), 'abc', defaults=defaults)
    all_variables = list('abc') + ['reference']
    namespaces = [env(namespace, variable) for variable in all_variables]

    assert fetch_registered(client, namespaces) == set()
    first_batch = register(client, foo, namespace, configs)
    # 4 variables (a, b, c, reference) x 3 seeds each.
    assert len(fetch_registered(client, namespaces)) == 4 * 3
    assert fetch_registered(client, namespaces) == first_batch

    # Resume with 10 seeds per configs this time.
    configs = generate(range(10), 'abc', defaults=defaults)
    second_batch = register(client, foo, namespace, configs)
    # Only the 7 new seeds per variable get registered on resume.
    assert len(fetch_registered(client, namespaces)) == 4 * 3 + 4 * 7
    assert fetch_registered(client, namespaces) != second_batch
    assert len(second_batch) == 4 * 7
def test_register_uniques(client):
    """A single register() call enqueues every generated config exactly once
    (4 variables including the reference, 3 seeds each)."""
    defaults = {'a': 1000, 'b': 1001, 'c': 1002, 'd': 3}
    namespace = 'test'
    configs = generate(range(3), 'abc', defaults=defaults)
    all_variables = list('abc') + ['reference']
    namespaces = [env(namespace, variable) for variable in all_variables]

    assert fetch_registered(client, namespaces) == set()
    register(client, foo, namespace, configs)
    assert len(fetch_registered(client, namespaces)) == 4 * 3
def test_fetch_results_non_completed(client):
    """fetch_results() must raise RuntimeError while registered trials are
    still pending (none have been actioned)."""
    defaults = {'a': 0, 'b': 1}
    params = {'c': 2, 'd': 3}
    # NOTE(review): this rebinding discards the {'a': 0, 'b': 1} dict above —
    # looks unintended; confirm whether the two dicts should be merged (cf.
    # the all-completed test, which calls defaults.update(params)).
    defaults = {'epoch': 0}
    medians = ['a']
    # NOTE(review): params is passed as the third positional argument here,
    # whereas sibling tests pass defaults= by keyword — verify against
    # generate()'s signature.
    configs = generate(range(2), 'ab', params)
    namespace = 'test'
    register(client, foo, namespace, configs)

    # No work items were dequeued/actioned, so results cannot be fetched.
    with pytest.raises(RuntimeError) as exc:
        fetch_results(client, namespace, configs, medians, params, defaults)

    assert exc.match('Not all trials are completed')
def test_fetch_results_corrupt_completed(client):
    """fetch_results() must raise RuntimeError when trials are marked
    completed but produced no entries in the result queue."""
    defaults = {'a': 0, 'b': 1}
    params = {'c': 2, 'd': 3}
    # NOTE(review): this rebinding discards the {'a': 0, 'b': 1} dict above —
    # looks unintended; confirm whether the two dicts should be merged (cf.
    # the all-completed test, which calls defaults.update(params)).
    defaults = {'epoch': 0}
    medians = ['a']
    num_items = 2
    configs = generate(range(num_items), 'ab', defaults=defaults)
    namespace = 'test'
    register(client, foo, namespace, configs)

    # Mark every work item as actioned WITHOUT running it, so the trials
    # look completed but no results were ever written.
    for variable in configs.keys():
        for i in range(num_items):
            workitem = client.dequeue(WORK_QUEUE, env(namespace, variable))
            client.mark_actioned(WORK_QUEUE, workitem)

    with pytest.raises(RuntimeError) as exc:
        fetch_results(client, namespace, configs, medians, params, defaults)

    assert exc.match('Nothing found in result queue for trial')
# Example #7
def test_generate_with_interupts():
    """With resumable=True, each (experiment, repetition) pair yields two
    docs: a plain one immediately followed by one flagged '_interrupt'."""
    defaults = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    num_experiments = 10
    num_repro = 10
    objective = 'obj'
    variables = list('abc')
    resumable = True
    configs = generate(num_experiments, num_repro, objective, variables,
                       defaults, resumable)

    assert list(configs.keys()) == variables
    assert all(
        len(configs[variable]) == num_experiments * num_repro * 2
        for variable in 'abc')

    def expected_doc(variable, experiment, repetition, interrupted):
        # uid is computed with 'variable'/'repetition' (and '_interrupt'
        # when set) present; the transient keys are stripped afterwards.
        doc = dict(defaults)
        doc[variable] = experiment
        if interrupted:
            doc['_interrupt'] = True
        doc['variable'] = variable
        doc['repetition'] = repetition
        doc.pop('uid', None)  # defensive: drop any stale uid before hashing
        doc['uid'] = compute_identity(doc, 16)
        del doc['repetition']
        del doc['variable']
        return doc

    for variable in 'abc':
        for experiment in range(num_experiments):
            for repetition in range(num_repro):
                index = (experiment * num_repro + repetition) * 2
                assert configs[variable][index] == expected_doc(
                    variable, experiment + 1, repetition + 1,
                    interrupted=False)
                assert configs[variable][index + 1] == expected_doc(
                    variable, experiment + 1, repetition + 1,
                    interrupted=True)

    # The flat index must have walked the whole (doubled) list.
    assert index == (num_experiments * num_repro * 2) - 2
def build_data(size):
    """Build a combined xarray dataset of synthetic training curves for
    `size` seeds across variables 'a', 'b' and 'c' (no reference runs)."""

    # TODO: variables is needed to stack the data for diff variables. For the
    #       get_median tests and others.....

    epochs = 5
    defaults = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    params = {'c': 2, 'd': 3, 'epoch': epochs}
    variables = 'abc'
    configs = generate(range(size),
                       variables,
                       defaults=defaults,
                       add_reference=False)

    num_variables = len(variables)
    num_seeds = size

    # Deterministically shuffled objective values, one per
    # (epoch, seed, variable) cell.
    objectives = numpy.arange(num_variables * num_seeds * (epochs + 1))
    numpy.random.RandomState(0).shuffle(objectives)
    objectives = objectives.reshape((epochs + 1, num_seeds, num_variables))

    # Map each config uid to its per-epoch metric records.
    metrics = {}
    for var_index, variable_configs in enumerate(configs.values()):
        for seed_index, config in enumerate(variable_configs):
            metrics[config['uid']] = [{
                'epoch': epoch,
                'objective': objectives[epoch, seed_index, var_index]
            } for epoch in range(epochs + 1)]

    param_names = sorted(params.keys())
    data = [
        create_valid_curves_xarray(
            create_trials(configs[variable], params, metrics),
            metrics,
            list(variables),
            epochs,
            param_names,
            seed=variable) for variable in configs.keys()
    ]

    return xarray.combine_by_coords(data)
def test_generate():
    """generate() with add_reference=False yields one 10-doc list per
    variable in 'abc' and no 'reference' entry."""
    defaults = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    configs = generate(range(10),
                       'abc',
                       defaults=defaults,
                       add_reference=False)

    assert list(configs.keys()) == list('abc')
    assert all(len(configs[variable]) == 10 for variable in 'abc')

    def expected_doc(variable, seed):
        # uid is computed with the transient '_variable' key present,
        # which is then stripped.
        doc = dict(defaults)
        doc[variable] = seed
        doc['_variable'] = variable
        doc['uid'] = compute_identity(doc, 16)
        del doc['_variable']
        return doc

    for seed in range(10):
        for variable in 'abc':
            assert configs[variable][seed] == expected_doc(variable, seed)
# Example #10
def test_fetch_results_all_completed(client):
    """End-to-end: register configs, run them all through a TrialWorker,
    then check the xarray dataset returned by fetch_results()."""
    defaults = {'a': 1000, 'b': 1001}
    params = {'c': 2, 'd': 3, 'epoch': 5}
    # defaults ends up containing noise dims (a, b) plus all params.
    defaults.update(params)
    medians = ['a']
    num_items = 2
    configs = generate(range(num_items), 'ab', defaults=defaults)
    namespace = 'test'
    register(client, foo, namespace, configs)

    print(configs)

    # Execute every registered trial to completion; no retries, short
    # timeout so the worker drains the queue and exits quickly.
    worker = TrialWorker(URI, DATABASE, 0, None)
    worker.max_retry = 0
    worker.timeout = 1
    worker.run()

    print(fetch_vars_stats(client, namespace))

    data = fetch_results(client, namespace, configs, medians, params, defaults)

    # Coordinate/label checks on the resulting dataset.
    assert data.medians == ['a']
    assert data.noise.values.tolist() == ['a', 'b']
    assert data.params.values.tolist() == ['c', 'd']
    assert data.order.values.tolist() == [0, 1]
    assert data.epoch.values.tolist() == list(range(params['epoch'] + 1))
    # 3 seed rows (a, b, reference) x 2 orders.
    assert data.uid.shape == (3, 2)
    assert data.seed.values.tolist(
    ) == data.noise.values.tolist() + ['reference']
    # For each varied dim, the varied seed uses its seed value while the
    # other rows keep the default (1000/1001); c and d stay fixed.
    assert data.a.values.tolist() == [[0, 1000, 1000], [1, 1000, 1000]]
    assert data.b.values.tolist() == [[1001, 0, 1001], [1001, 1, 1001]]
    assert data.c.values.tolist() == [[2, 2, 2], [2, 2, 2]]
    assert data.d.values.tolist() == [[3, 3, 3], [3, 3, 3]]

    # NOTE(review): assumes foo's objective is a + b + c + d - 1 + epoch
    # (0 + 1001 + 2 + 3 ... ≈ 2002 at epoch 0) — confirm against foo.
    assert (data.obj.loc[dict(order=0, seed='a')].values.tolist() == list(
        range(2002, 2002 + params['epoch'] + 1)))
# Example #11
def test_remaining(client):
    """remaining() reports True while any of the queried variables still has
    an unfinished work item, and False once all are actioned."""
    defaults = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
    configs = generate(range(1), 'ab', defaults=defaults)
    namespace = 'test'
    register(client, foo, namespace, configs)

    def get_stats(variables):
        all_stats = fetch_vars_stats(client, namespace)
        print(all_stats)
        return {
            variable: all_stats[f'test-{variable}']
            for variable in variables
        }

    assert remaining(get_stats('ab'))

    # Drain variable 'a': dequeuing alone is not enough, the item must be
    # actioned before 'a' counts as done.
    assert remaining(get_stats('a'))
    workitem = client.dequeue(WORK_QUEUE, env(namespace, 'a'))
    assert remaining(get_stats('a'))
    client.mark_actioned(WORK_QUEUE, workitem)
    assert not remaining(get_stats('a'))

    # 'b' still has pending work until it too is dequeued and actioned.
    assert remaining(get_stats('ab'))
    workitem = client.dequeue(WORK_QUEUE, env(namespace, 'b'))
    assert remaining(get_stats('ab'))
    client.mark_actioned(WORK_QUEUE, workitem)
    assert not remaining(get_stats('ab'))