def time_vspace_flatten():
    """Flatten a mixed dict of arrays, scalars and lists with ``vspace_flatten``.

    The ``time_`` prefix suggests a benchmark harness collects this function
    (assumption — confirm against the benchmark runner).

    NOTE(review): a second ``def time_vspace_flatten`` with the same body
    appears later in this file and replaces this one at import time, so this
    definition is effectively dead code.
    """
    # Draw the two random blocks first, in the same order as the dict literal.
    first_block = npr.random((4, 4))
    second_block = npr.random((3, 3))
    container = dict(
        k=first_block,
        k2=second_block,
        k3=3.0,
        k4=[1.0, 4.0, 7.0, 9.0],
        k5=np.array([4., 5., 6.]),
        k6=np.array([[7., 8.], [9., 10.]]),
    )
    vspace_flatten(container)
def check_equivalent(A, B, rtol=RTOL, atol=ATOL):
    """Assert that ``A`` and ``B`` share a vspace and are numerically close.

    Args:
        A: The "analytic" value (so labelled in the failure messages).
        B: The "numeric" value (so labelled in the failure messages).
        rtol: Relative tolerance passed to ``np.allclose``.
        atol: Absolute tolerance passed to ``np.allclose``.

    Raises:
        AssertionError: If the vspaces differ, or if the flattened values are
            not close within ``rtol``/``atol``.
    """
    A_vspace = vspace(A)
    B_vspace = vspace(B)
    # Flatten each value exactly once. The same arrays are compared and shown
    # in the failure message.
    A_flat = vspace_flatten(A)
    B_flat = vspace_flatten(B)
    assert A_vspace == B_vspace, \
        "VSpace mismatch:\nanalytic: {}\nnumeric: {}".format(A_vspace, B_vspace)
    assert np.allclose(A_flat, B_flat, rtol=rtol, atol=atol), \
        "Diffs are:\n{}.\nanalytic is:\n{}.\nnumeric is:\n{}.".format(
            A_flat - B_flat, A_flat, B_flat)
def time_vspace_flatten():
    """Run ``vspace_flatten`` over a dict mixing arrays, a float and a list.

    Judging by the ``time_`` prefix, this is presumably collected by a
    benchmark harness (confirm against the benchmark runner).
    """
    payload = {}
    payload['k'] = npr.random((4, 4))
    payload['k2'] = npr.random((3, 3))
    payload['k3'] = 3.0
    payload['k4'] = [1.0, 4.0, 7.0, 9.0]
    payload['k5'] = np.array([4., 5., 6.])
    payload['k6'] = np.array([[7., 8.], [9., 10.]])
    vspace_flatten(payload)
def check_args(fun, argnum, args, kwargs):
    """Check ``fun``'s VJP for argument ``argnum`` against a numerical Jacobian.

    For every example output gradient in the output vspace:
    1. Call ``fun.vjps[argnum]`` on that gradient.
    2. Check that the result lies in the same vspace as ``args[argnum]``.
    3. Check that its flattened values match ``outgrad @ jac``, where ``jac``
       is the numerical Jacobian.

    Args:
        fun: A primitive exposing a ``vjps`` mapping indexed by argument number.
        argnum: Index into ``args`` of the argument being differentiated.
        args: Positional arguments for ``fun``.
        kwargs: Keyword arguments for ``fun``. These are forwarded to the
            forward call, the numerical Jacobian and the VJP alike.

    Raises:
        AssertionError: On a vspace mismatch or a numerical disagreement. The
            message is built by ``report_mismatch`` / ``report_nd_failure``.
    """
    # Evaluate the forward pass with the same kwargs that the VJP and the
    # numerical Jacobian receive. Otherwise `ans` would not correspond to the
    # call being differentiated.
    ans = fun(*args, **kwargs)
    in_vspace = vspace(args[argnum])
    ans_vspace = vspace(ans)
    jac = numerical_jacobian(fun, argnum, args, kwargs)
    for outgrad in ans_vspace.examples():
        result = fun.vjps[argnum](
            outgrad, ans, in_vspace, ans_vspace, *args, **kwargs)
        result_vspace = vspace(result)
        # NOTE(review): the second positional `True` presumably asks for a
        # real-valued flattening (the names say "reals") — confirm.
        result_reals = vspace_flatten(result, True)
        nd_result_reals = np.dot(vspace_flatten(outgrad, True), jac)
        assert result_vspace == in_vspace, \
            report_mismatch(fun, argnum, args, kwargs, outgrad,
                            in_vspace, result_vspace)
        assert np.allclose(result_reals, nd_result_reals), \
            report_nd_failure(fun, argnum, args, kwargs, outgrad,
                              nd_result_reals, result_reals)
def check_vjp(fun, arg):
    """Assert that ``fun``'s VJP at ``arg`` agrees with a numerical derivative.

    Both sides are converted to dense matrices with ``linear_fun_to_matrix``
    and compared with ``np.allclose``. The matrix from the VJP side is
    transposed before the comparison.
    """
    vs_in = vspace(arg)
    vs_out = vspace(fun(arg))

    # Matrix built from the autograd vector-Jacobian product. The transpose
    # presumably aligns its layout with the numerical Jacobian — confirm
    # against linear_fun_to_matrix.
    vjp_fn = make_vjp(fun)(arg)[0]
    vjp_matrix = linear_fun_to_matrix(flatten_fun(vjp_fn, vs_out), vs_out)
    autograd_jac = vjp_matrix.T

    # Matrix from a numerical derivative of `fun`, taken in flattened form.
    flat_fun = flatten_fun(fun, vs_in)
    flat_arg = vspace_flatten(arg)
    numerical_jac = linear_fun_to_matrix(
        numerical_deriv(flat_fun, flat_arg), vs_in)

    assert np.allclose(autograd_jac, numerical_jac)
def flatten_fun(fun, vs):
    """Wrap ``fun`` so that it takes a flat input and returns a flat output.

    The returned callable does three things:
    1. Unflattens its argument with ``vs.unflatten``.
    2. Applies ``fun``.
    3. Flattens the result with ``vspace_flatten``.
    """
    def flat_fun(x):
        structured = vs.unflatten(x)
        return vspace_flatten(fun(structured))
    return flat_fun