def test_l1_reg_oracle():
    """Check value and prox of the L1 regularizer h(x) = 1.0 * ||x||_1.

    The expected prox values below correspond to soft-thresholding: for
    ``prox(x, alpha)`` each coordinate is moved toward zero by
    ``1.0 * alpha``. Each point also checks that prox returns an ndarray.
    """
    oracle = oracles.L1RegOracle(1.0)

    # Checks at point x = [0, 0, 0]: zero is a fixed point of the prox.
    x = np.zeros(3)
    assert_almost_equal(oracle.func(x), 0.0)
    ok_(np.allclose(oracle.prox(x, alpha=1.0), x))
    ok_(np.allclose(oracle.prox(x, alpha=2.0), x))
    ok_(isinstance(oracle.prox(x, alpha=1.0), np.ndarray))

    # Checks at point x = [-3] (single coordinate, negative sign preserved).
    x = np.array([-3.0])
    assert_almost_equal(oracle.func(x), 3.0)
    ok_(np.allclose(oracle.prox(x, alpha=1.0), np.array([-2.0])))
    ok_(np.allclose(oracle.prox(x, alpha=2.0), np.array([-1.0])))
    ok_(isinstance(oracle.prox(x, alpha=1.0), np.ndarray))

    # Checks at point x = [-3, 3] (coordinates of both signs shrink symmetrically).
    x = np.array([-3.0, 3.0])
    assert_almost_equal(oracle.func(x), 6.0)
    ok_(np.allclose(oracle.prox(x, alpha=1.0), np.array([-2.0, 2.0])))
    ok_(np.allclose(oracle.prox(x, alpha=2.0), np.array([-1.0, 1.0])))
    ok_(isinstance(oracle.prox(x, alpha=1.0), np.ndarray))
def test_l1_reg_oracle_2():
    """Check value and prox of the scaled L1 regularizer h(x) = 2.0 * ||x||_1.

    With coefficient 2.0, ``prox(x, alpha)`` is expected to move each
    coordinate toward zero by ``2.0 * alpha``. Once that shrinkage exceeds
    a coordinate's magnitude, the coordinate should be clamped to exactly
    zero rather than cross to the other sign.
    """
    oracle = oracles.L1RegOracle(2.0)

    # Checks at point x = [-3, 3]
    x = np.array([-3.0, 3.0])
    assert_almost_equal(oracle.func(x), 6 * 2.0)
    # Shrinkage 2.0 * 1.0 = 2 < 3: coordinates move toward zero by 2.
    ok_(np.allclose(oracle.prox(x, alpha=1.0), np.array([-1.0, 1.0])))
    # Shrinkage 2.0 * 2.0 = 4 > 3: coordinates must clamp to zero, not flip sign.
    ok_(np.allclose(oracle.prox(x, alpha=2.0), np.array([0.0, 0.0])))
    ok_(isinstance(oracle.prox(x, alpha=1.0), np.ndarray))