コード例 #1
0
def test_choice():
    """Check that ``choice`` evaluates to the branch whose key matches.

    Two keys are exercised: one computed from a ``partial`` node, and one
    supplied through a named ``variable`` bound as a keyword to ``evaluate``.
    """
    # The test expects int(15.5) == 15 to select the (15, 4) branch.
    keyed_by_partial = choice(partial(int, 15.5), (15, 4), (3, 2))
    assert evaluate(keyed_by_partial) == 4

    letter = variable('x', value_type=['a', 'b', 'c'])
    keyed_by_variable = choice(letter, ('a', 'b'), ('b', 'c'))
    for bound, expected in (('a', 'b'), ('b', 'c')):
        assert evaluate(keyed_by_variable, x=bound) == expected
コード例 #2
0
def test_choice_raises():
    """Check that ``choice`` rejects malformed branch arguments.

    The test expects ``ValueError`` both for a branch that is a 3-element
    tuple and for a branch given as a list rather than a tuple.
    """
    malformed = (
        ((5, 4), (3, 2, 1)),  # one branch tuple has three elements
        ((5, 4), [3, 4]),     # one branch is a list, not a tuple
    )
    for branches in malformed:
        try:
            choice(5, *branches)
        except ValueError:
            continue
        # Reached only when choice() accepted the malformed branches.
        assert False
コード例 #3
0
ファイル: test_pyll.py プロジェクト: memkite/searchspaces
def test_uniform_choice():
    """Check the pyll graph built for a uniform ``choice`` over a variable.

    ``choice`` over ``variable('foo', value_type=[7, 9, 11])`` is expected to
    lower to ``switch(hyperopt_param('foo', randint(upper=3)), ...)``.
    Repeatedly evaluating the graph must only yield one of the branch values.
    """
    p = as_pyll(choice(variable('foo', value_type=[7, 9, 11]),
                       (7, 'rst'),
                       (9, 'uvw'),
                       (11, 'xyz')))
    assert p.name == 'switch'
    assert p.pos_args[0].name == 'hyperopt_param'
    assert p.pos_args[0].pos_args[0].obj == 'foo'
    assert p.pos_args[0].pos_args[1].name == 'randint'
    # Presumably one randint outcome per candidate value (three of them).
    assert p.pos_args[0].pos_args[1].arg['upper'].obj == 3
    # Make sure this executes and yields a value in the right domain.
    # Any exception from rec_eval is left to propagate so the failure shows
    # the real traceback (an `assert False` in an except clause would hide it,
    # and is stripped entirely under `python -O`).
    recursive_set_rng_kwarg(p, np.random)
    values = [rec_eval(p) for _ in range(10)]
    assert all(v in ['rst', 'uvw', 'xyz'] for v in values)
コード例 #4
0
ファイル: test_pyll.py プロジェクト: memkite/searchspaces
def test_nonuniform_choice():
    """Check the pyll graph built for a categorical ``choice`` over a variable.

    ``variable('blu', ..., distribution='categorical', p=[0.2, 0.7, 0.1])`` is
    expected to lower to
    ``switch(hyperopt_param('blu', categorical(p=pos_args(0.2, 0.7, 0.1))), ...)``.
    Repeatedly evaluating the graph must only yield one of the branch values.
    """
    var = variable('blu', value_type=[2, 4, 8], distribution='categorical',
                   p=[0.2, 0.7, 0.1])
    p = as_pyll(choice(var,
                       (2, 'abc'),
                       (4, 'def'),
                       (8, 'ghi')))
    assert p.name == 'switch'
    assert p.pos_args[0].name == 'hyperopt_param'
    assert p.pos_args[0].pos_args[0].obj == 'blu'
    assert p.pos_args[0].pos_args[1].name == 'categorical'
    # The probability list is expected as a pos_args node holding each weight.
    assert p.pos_args[0].pos_args[1].arg['p'].name == 'pos_args'
    assert p.pos_args[0].pos_args[1].arg['p'].pos_args[0].obj == 0.2
    assert p.pos_args[0].pos_args[1].arg['p'].pos_args[1].obj == 0.7
    assert p.pos_args[0].pos_args[1].arg['p'].pos_args[2].obj == 0.1
    # Make sure this executes and yields a value in the right domain.
    # Any exception from rec_eval is left to propagate so the failure shows
    # the real traceback (an `assert False` in an except clause would hide it,
    # and is stripped entirely under `python -O`).
    recursive_set_rng_kwarg(p, np.random)
    values = [rec_eval(p) for _ in range(10)]
    assert all(v in ['abc', 'def', 'ghi'] for v in values)