def sgd_parsed(L_grad, hypers, parser, callback=None, forward_pass_only=True):
    """Run SGD with per-parameter-group step sizes and momenta, recorded exactly.

    Parameters
    ----------
    L_grad : callable(x, meta, i) -> ndarray
        Gradient of the training loss at weights ``x``, hyper ``meta``, iter ``i``.
    hypers : tuple (x0, alphas, betas, meta)
        Initial weights, per-iteration/per-group learning rates ``alphas`` and
        momentum decays ``betas`` (both indexed ``[iteration, group]``), and
        extra differentiable hyperparameters ``meta``.
    parser : object with ``idxs_and_shapes``
        Maps each named parameter group to its (index-slice, shape) pair.
    callback : callable(x, v, g, i), optional
        Invoked each iteration with current weights, velocity, gradient, step.
    forward_pass_only : bool
        If True, return only the final weights.

    Returns
    -------
    x_final, or (x_final, [None, hypergrad]) where ``hypergrad(outgrad)``
    reverses the whole optimization exactly (via ExactRep) and returns
    (d_x0, d_alphas, d_betas, d_meta).
    """
    x0, alphas, betas, meta = hypers
    X, V = ExactRep(x0), ExactRep(np.zeros(x0.size))
    # Materialize as a list: Python 3's zip() is a one-shot iterator, but this
    # sequence is consumed by the forward loop AND sliced/reversed again inside
    # hypergrad(). (Behavior is unchanged on Python 2, where zip returns a list.)
    iters = list(zip(range(len(alphas)), alphas, betas))
    for i, alpha, beta in iters:
        g = L_grad(X.val, meta, i)
        if callback:
            callback(X.val, V.val, g, i)
        cur_alpha_vect = fill_parser(parser, alpha)
        cur_beta_vect = fill_parser(parser, beta)
        V.mul(cur_beta_vect).sub(g)          # v <- beta * v - g
        X.add(cur_alpha_vect * V.val)        # x <- x + alpha * v
    x_final = X.val
    if forward_pass_only:
        return x_final

    # Hypergradient calculation: reverse the SGD trajectory exactly and
    # accumulate gradients w.r.t. x0, alphas, betas and meta along the way.
    def hypergrad(outgrad):
        # Copy so the in-place "-=" updates below cannot mutate the caller's array.
        d_x = np.copy(outgrad)
        d_alphas, d_betas = np.zeros(alphas.shape), np.zeros(betas.shape)
        d_v, d_meta = np.zeros(d_x.shape), np.zeros(meta.shape)
        grad_proj = lambda x, meta, d, i: np.dot(L_grad(x, meta, i), d)
        L_hvp_x = grad(grad_proj, 0)     # Hessian-vector product w.r.t. weights
        L_hvp_meta = grad(grad_proj, 1)  # Hessian-vector product w.r.t. meta
        for i, alpha, beta in iters[::-1]:
            # build alpha and beta vector
            cur_alpha_vect = fill_parser(parser, alpha)
            cur_beta_vect = fill_parser(parser, beta)
            # .items() replaces Python-2-only .iteritems(); same iteration order.
            for j, (_, (ixs, _)) in enumerate(parser.idxs_and_shapes.items()):
                d_alphas[i, j] = np.dot(d_x[ixs], V.val[ixs])
            # Exactly reverse SGD
            X.sub(cur_alpha_vect * V.val)
            g = L_grad(X.val, meta, i)
            V.add(g).div(cur_beta_vect)
            d_v += d_x * cur_alpha_vect
            for j, (_, (ixs, _)) in enumerate(parser.idxs_and_shapes.items()):
                d_betas[i, j] = np.dot(d_v[ixs], V.val[ixs])
            d_x -= L_hvp_x(X.val, meta, d_v, i)
            d_meta -= L_hvp_meta(X.val, meta, d_v, i)
            d_v *= cur_beta_vect
        # Sanity check: the exact reversal must land back on the initial weights.
        assert np.all(ExactRep(x0).val == X.val)
        return d_x, d_alphas, d_betas, d_meta

    return x_final, [None, hypergrad]
def test_mul_div_with_vector():
    """Elementwise multiply then divide by a random vector must round-trip exactly."""
    data = npr.randn(100)
    scale_vec = npr.rand(100)
    rep = ExactRep(data)
    before = rep.val
    rep.mul(scale_vec)
    assert np.allclose(rep.val, data * scale_vec, rtol=1e-3, atol=1e-4)
    rep.div(scale_vec)
    assert all(rep.val == before)
def test_mul_div():
    """Scalar mul followed by div must recover the original bits exactly."""
    data = npr.randn(100)
    for factor in [0.95, 0.9, 0.5, 0.3, 1.01]:
        # Plain float arithmetic does NOT round-trip under the same operations.
        float_roundtrip = (((data * factor + data) - data) / factor)
        assert not all(float_roundtrip == data)
        rep = ExactRep(data)
        before = rep.val
        rep.mul(factor)
        assert np.allclose(rep.val, data * factor, rtol=1e-3, atol=1e-4)
        rep.div(factor)
        assert all(rep.val == before)
def test_add_sub():
    """add/sub must round-trip exactly in both orders, unlike float arithmetic."""
    data = npr.randn(100)
    offset = npr.randn(100) * 500
    # Sanity: plain floats lose information when adding/subtracting a big offset.
    assert np.mean((data + offset) - offset == data) < 0.5
    assert np.mean((data - offset) + offset == data) < 0.5
    rep = ExactRep(data)
    before = rep.val
    rep.add(offset)
    assert np.allclose(rep.val, data + offset)
    rep.sub(offset)
    assert all(rep.val == before)
    rep.sub(offset)
    assert np.allclose(rep.val, data - offset)
    rep.add(offset)
    assert all(rep.val == before)
def test_repeated_mul_div():
    """200 multiplies then 200 divides must restore the exact original value."""
    base = npr.randn(100)
    rep = ExactRep(base)
    before = rep.val
    factors = npr.rand(200)
    float_track = base
    for f in factors:
        float_track = float_track * f
        rep.mul(f)
    assert np.allclose(rep.val, float_track)
    for f in factors[::-1]:
        float_track = float_track / f
        rep.div(f)
    # Plain floats have drifted on most entries; the exact rep has not.
    assert np.mean(float_track == base) < 0.2
    assert all(rep.val == before)
def hypergrad(outgrad):
    # Reverse-mode hypergradient of an SGD run.
    # NOTE(review): this function reads many names from an enclosing scope
    # (alphas, betas, meta, L_grad, iters, X, V, parser, x0, fill_parser, grad)
    # — it appears to be the closure from sgd_parsed lifted out verbatim, and
    # is not callable standalone. Confirm it is meant to live at this level.
    d_x = outgrad
    d_alphas, d_betas = np.zeros(alphas.shape), np.zeros(betas.shape)
    d_v, d_meta = np.zeros(d_x.shape), np.zeros(meta.shape)
    # Projection of the loss gradient onto direction d; differentiating it
    # gives Hessian-vector products w.r.t. weights (arg 0) and meta (arg 1).
    grad_proj = lambda x, meta, d, i: np.dot(L_grad(x, meta, i), d)
    L_hvp_x = grad(grad_proj, 0)
    L_hvp_meta = grad(grad_proj, 1)
    # Walk the recorded iterations backwards, undoing each SGD step exactly.
    for i, alpha, beta in iters[::-1]:
        # build alpha and beta vector
        cur_alpha_vect = fill_parser(parser, alpha)
        cur_beta_vect = fill_parser(parser, beta)
        # Per-parameter-group learning-rate gradient. NOTE(review): .iteritems()
        # is Python-2-only.
        for j, (_, (ixs, _)) in enumerate(parser.idxs_and_shapes.iteritems()):
            d_alphas[i, j] = np.dot(d_x[ixs], V.val[ixs])
        # Exactly reverse SGD
        X.sub(cur_alpha_vect * V.val)
        g = L_grad(X.val, meta, i)
        V.add(g).div(cur_beta_vect)
        d_v += d_x * cur_alpha_vect
        # Per-parameter-group momentum-decay gradient.
        for j, (_, (ixs, _)) in enumerate(parser.idxs_and_shapes.iteritems()):
            d_betas[i, j] = np.dot(d_v[ixs], V.val[ixs])
        d_x -= L_hvp_x(X.val, meta, d_v, i)
        d_meta -= L_hvp_meta(X.val, meta, d_v, i)
        d_v *= cur_beta_vect
    # Sanity check: exact reversal must land back on the initial weights.
    assert np.all(ExactRep(x0).val == X.val)
    return d_x, d_alphas, d_betas, d_meta
def test_mul_div_with_vector():
    """Test if an exact rep can be multiplied and divided elementwise with a vector."""
    values, divisors = npr.randn(100), npr.rand(100)
    wrapped = ExactRep(values)
    snapshot = wrapped.val
    wrapped.mul(divisors)
    expected_product = values * divisors
    assert np.allclose(wrapped.val, expected_product, rtol=1e-3, atol=1e-4)
    wrapped.div(divisors)
    assert all(wrapped.val == snapshot)