Ejemplo n.º 1
0
def test_l1_reg_oracle():
    """Check L1RegOracle(1.0): value and proximal step of h(x) = ||x||_1.

    For h(x) = lam * ||x||_1 the proximal operator with step ``alpha`` is
    component-wise soft-thresholding:

        prox(x, alpha)_i = sign(x_i) * max(|x_i| - alpha * lam, 0)

    Every expected vector below is computed by hand from that formula.
    """
    # h(x) = 1.0 * \|x\|_1
    oracle = oracles.L1RegOracle(1.0)

    # Checks at point x = [0, 0, 0]: the origin is a fixed point of prox.
    x = np.zeros(3)
    assert_almost_equal(oracle.func(x), 0.0)
    ok_(np.allclose(oracle.prox(x, alpha=1.0), x))
    ok_(np.allclose(oracle.prox(x, alpha=2.0), x))
    ok_(isinstance(oracle.prox(x, alpha=1.0), np.ndarray))

    # Checks at point x = [-3]
    x = np.array([-3.0])
    assert_almost_equal(oracle.func(x), 3.0)
    ok_(np.allclose(oracle.prox(x, alpha=1.0), np.array([-2.0])))
    ok_(np.allclose(oracle.prox(x, alpha=2.0), np.array([-1.0])))
    ok_(isinstance(oracle.prox(x, alpha=1.0), np.ndarray))

    # Checks at point x = [-3, 3]
    x = np.array([-3.0, 3.0])
    assert_almost_equal(oracle.func(x), 6.0)
    ok_(np.allclose(oracle.prox(x, alpha=1.0), np.array([-2.0, 2.0])))
    ok_(np.allclose(oracle.prox(x, alpha=2.0), np.array([-1.0, 1.0])))
    ok_(isinstance(oracle.prox(x, alpha=1.0), np.ndarray))

    # Shrinkage must clamp at zero (never overshoot past the origin) once
    # alpha * lam >= |x_i|; alpha = 3 hits the threshold exactly.
    ok_(np.allclose(oracle.prox(x, alpha=3.0), np.zeros(2)))
    ok_(np.allclose(oracle.prox(x, alpha=4.0), np.zeros(2)))

    # Mixed magnitudes: only the component below the threshold is zeroed.
    x = np.array([-0.5, 3.0])
    ok_(np.allclose(oracle.prox(x, alpha=1.0), np.array([0.0, 2.0])))


def test_l1_reg_oracle_2():
    """Check L1RegOracle with a non-unit coefficient, h(x) = 2.0 * ||x||_1.

    The coefficient scales both the function value and the threshold used
    by ``prox``. With coefficient 2 and alpha = 1, each component of the
    test point moves 2 units toward zero.
    """
    # h(x) = 2.0 * \|x\|_1
    oracle = oracles.L1RegOracle(2.0)

    # Checks at point x = [-3, 3]
    x = np.array([-3.0, 3.0])
    # ||x||_1 = 6, scaled by the coefficient 2.0.
    assert_almost_equal(oracle.func(x), 6 * 2.0)
    # Threshold alpha * coefficient = 1.0 * 2.0 = 2: |-3| and |3| both shrink
    # to 1 while keeping their signs.
    ok_(np.allclose(oracle.prox(x, alpha=1.0), np.array([-1.0, 1.0])))