def hessian_vector_product(fun, argnum=0):
    """Builds a function that returns the exact Hessian-vector product.
    The returned function has arguments (*args, vector, **kwargs), and takes
    roughly 4x as long to evaluate as the original function."""
    fun_grad = grad(fun, argnum)
    def vector_dot_grad(*args, **kwargs):
        args, vector = args[:-1], args[-1]
        return np.dot(vector, fun_grad(*args, **kwargs))
    return grad(vector_dot_grad, argnum)  # Grad wrt original input.
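# Minimal usage sketch (not part of the library; assumes `grad` and the
# wrapped `np` from above are in scope). For f(x) = sum(x**3) the Hessian is
# diag(6*x), so the product should equal 6 * x * v elementwise.
#
#     f   = lambda x: np.sum(x**3)
#     hvp = hessian_vector_product(f)
#     x   = np.array([1.0, 2.0, 3.0])
#     v   = np.array([0.5, -1.0, 2.0])
#     hvp(x, v)        # expected: 6 * x * v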
def grad_and_aux(fun, argnum=0):
    # Differentiates the first output of `fun` and passes its second
    # (auxiliary) output through unchanged.
    def grad_and_aux_fun(*args, **kwargs):
        saved = lambda: None
        def return_val_save_aux(*args, **kwargs):
            val, saved.aux = fun(*args, **kwargs)
            return val
        gradval = grad(return_val_save_aux, argnum)(*args, **kwargs)
        return gradval, saved.aux
    return grad_and_aux_fun
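# Usage sketch (hypothetical example function, not from the library):
# differentiate a function that also returns bookkeeping information,
# without differentiating through that second output.
#
#     def loss_and_aux(w):
#         loss = np.sum(w**2)
#         return loss, 'quadratic'          # aux output, passed through untouched
#
#     g, aux = grad_and_aux(loss_and_aux)(np.array([1.0, -2.0]))
#     # g == 2*w, aux == 'quadratic'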
def elementwise_grad(fun, argnum=0):
    """Like `jacobian`, but produces a function which computes just the
    diagonal of the Jacobian, and does the computation in one pass rather
    than in a loop. Note: this is only valid if the Jacobian is diagonal.
    Only arrays are currently supported. Can be used for broadcasting."""
    def sum_output(*args, **kwargs):
        return np.sum(fun(*args, **kwargs))
    return grad(sum_output, argnum=argnum)
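# Usage sketch: for an elementwise function the Jacobian is diagonal, so a
# single reverse pass gives the per-element derivative.
#
#     d_tanh = elementwise_grad(np.tanh)
#     x = np.linspace(-3.0, 3.0, 5)
#     d_tanh(x)        # expected: 1 - np.tanh(x)**2, elementwise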
def multigrad(fun, argnums=[0]):
    """Takes gradients wrt multiple arguments simultaneously."""
    def combined_arg_fun(multi_arg, *args, **kwargs):
        extra_args_list = list(args)
        for argnum_ix, arg_ix in enumerate(argnums):
            extra_args_list[arg_ix] = multi_arg[argnum_ix]
        return fun(*extra_args_list, **kwargs)
    gradfun = grad(combined_arg_fun, argnum=0)
    def gradfun_rearranged(*args, **kwargs):
        multi_arg = tuple([args[i] for i in argnums])
        return gradfun(multi_arg, *args, **kwargs)
    return gradfun_rearranged
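# Usage sketch (hypothetical example function): take the gradient with
# respect to both arguments in one call; the result is a tuple ordered
# like `argnums`.
#
#     def f(x, y):
#         return np.sum(x * y + x**2)
#
#     gx, gy = multigrad(f, argnums=[0, 1])(np.array([1.0, 2.0]),
#                                           np.array([3.0, 4.0]))
#     # gx == y + 2*x, gy == x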
def gradfun(*args, **kwargs):
    # Inner closure of a wrapper (not shown here) that inspects `fun`'s
    # signature: `sig`, `var_pos`, `var_kwd`, `argnames`, `todict` and
    # `apply_defaults` come from that enclosing scope.
    bindings = sig.bind(*args, **kwargs)

    args   = lambda dct: tuple(dct[var_pos[0]]) if var_pos else ()
    kwargs = lambda dct: todict(dct[var_kwd[0]]) if var_kwd else {}
    others = lambda dct: tuple(dct[argname] for argname in argnames
                               if argname not in var_kwd + var_pos)
    newfun = lambda dct: fun(*(others(dct) + args(dct)), **kwargs(dct))

    argdict = apply_defaults(bindings.arguments)
    grad_dict = grad(newfun)(dict(argdict))
    return OrderedDict((argname, grad_dict[argname]) for argname in argdict)
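# Sketch of the mechanism the closure above relies on (hypothetical example
# function): `grad` can differentiate with respect to a dict argument and
# returns a dict of gradients with the same keys, which is what lets the
# result be repackaged into an OrderedDict keyed by argument name.
#
#     def loss_from_dict(params):
#         return np.sum(params['w']**2) + 3.0 * params['b']
#
#     grads = grad(loss_from_dict)(dict(w=np.array([1.0, 2.0]), b=0.5))
#     # grads['w'] == 2*w, grads['b'] == 3.0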
def quick_grad_check(fun, arg0, extra_args=(), kwargs={}, verbose=True,
                     eps=EPS, rtol=RTOL, atol=ATOL, rs=None):
    """Checks the gradient of a function (w.r.t. its first arg) in a random direction."""
    if verbose:
        print("Checking gradient of {0} at {1}".format(fun, arg0))

    if rs is None:
        rs = np.random.RandomState()
    random_dir = rs.standard_normal(np.shape(arg0))
    random_dir = random_dir / np.sqrt(np.sum(random_dir * random_dir))

    # EPS, RTOL, ATOL and unary_nd come from the surrounding numerical-check utilities.
    unary_fun = lambda x: fun(arg0 + x * random_dir, *extra_args, **kwargs)
    numeric_grad = unary_nd(unary_fun, 0.0, eps=eps)
    analytic_grad = np.sum(grad(fun)(arg0, *extra_args, **kwargs) * random_dir)

    assert np.allclose(numeric_grad, analytic_grad, rtol=rtol, atol=atol), \
        "Check failed! nd={0}, ad={1}".format(numeric_grad, analytic_grad)

    if verbose:
        print("Gradient projection OK (numeric grad: {0}, analytic grad: {1})".format(
            numeric_grad, analytic_grad))
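# Usage sketch (assuming numpy.random is imported as `npr`, as in the
# logsumexp example below): check the gradient of a simple scalar function
# at a random point along a random direction.
#
#     quick_grad_check(lambda x: np.sum(np.sin(x)), npr.randn(5))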
# The reason for the closure is so that the gradient can depend
# on both the input to the original function (x), and the output of the
# original function (ans).
def make_grad_logsumexp(ans, x):
    # If you want to be able to take higher-order derivatives, then all the
    # code inside this function must be itself differentiable by autogradwithbay.
    def gradient_product(g):
        # This closure multiplies g with the Jacobian of logsumexp (d_ans/d_x).
        # Because autogradwithbay uses reverse-mode differentiation, g contains
        # the gradient of the objective w.r.t. ans, the output of logsumexp.
        return np.full(x.shape, g) * np.exp(x - np.full(x.shape, ans))
    return gradient_product

# Now we tell autogradwithbay that logsumexp has a gradient-making function.
logsumexp.defgrad(make_grad_logsumexp)

if __name__ == '__main__':
    # Now we can use logsumexp() inside a larger function that we want
    # to differentiate.
    def example_func(y):
        z = y**2
        lse = logsumexp(z)
        return np.sum(lse)

    grad_of_example = grad(example_func)
    print("Gradient: ", grad_of_example(npr.randn(10)))

    # Check the gradients numerically, just to be safe.
    quick_grad_check(example_func, npr.randn(10))
def check_grads(fun, *args):
    if not args:
        raise Exception("No args given")
    exact = tuple([grad(fun, i)(*args) for i in range(len(args))])
    numeric = nd(fun, *args)
    check_equivalent(exact, numeric)
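# Usage sketch (assuming `nd` and `check_equivalent` from the surrounding
# test utilities, and numpy.random imported as `npr`): compare the exact
# gradients of a two-argument function against numerical differences in
# every argument.
#
#     check_grads(lambda x, y: np.sum(x * np.sin(y)), npr.randn(4), npr.randn(4))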