import torch
from torch.autograd import gradcheck, gradgradcheck

# NOTE: the import paths below are assumed from xitorch's layout and may need
# adjusting; getfnparams, getnnparams, func1, func2, and dtype are helpers
# defined elsewhere in this test module.
from xitorch.grad.jachess import jac
from xitorch.linalg import solve
from xitorch._core.pure_function import get_pure_function


def test_jac_method_grad():
    na = 3
    params = getfnparams(na)
    nnparams = getnnparams(na)
    num_nnparams = len(nnparams)
    jacs = jac(func2(*nnparams), params)
    nout = jacs[0].shape[-2]

    def fcnr(i, v, *allparams):
        nnparams = allparams[:num_nnparams]
        params = allparams[num_nnparams:]
        jacs = jac(func2(*nnparams), params)
        return jacs[i].rmv(v.view(-1))

    def fcnl(i, v, *allparams):
        nnparams = allparams[:num_nnparams]
        params = allparams[num_nnparams:]
        jacs = jac(func2(*nnparams), params)
        return jacs[i].mv(v.view(-1))

    v = torch.rand((na,), dtype=dtype, requires_grad=True)
    w = [torch.rand_like(p).requires_grad_() for p in params]
    for i in range(len(jacs)):
        gradcheck(fcnr, (i, v, *nnparams, *params))
        gradgradcheck(fcnr, (i, v, *nnparams, *params))
        gradcheck(fcnl, (i, w[i], *nnparams, *params))
        gradgradcheck(fcnl, (i, w[i], *nnparams, *params))

def test_jac_func():
    na = 3
    params = getfnparams(na)
    nnparams = getnnparams(na)
    nparams = len(params)
    all_idxs = [None, (0,), (1,), (0, 1), (0, 1, 2)]
    funcs = [func1, func2(*nnparams)]
    for func in funcs:
        for idxs in all_idxs:
            if idxs is None:
                gradparams = params
            else:
                gradparams = [params[i] for i in idxs]
            jacs = jac(func, params, idxs=idxs)
            assert len(jacs) == len(gradparams)

            y = func(*params)
            nout = torch.numel(y)
            nins = [torch.numel(p) for p in gradparams]
            v = torch.rand_like(y).requires_grad_()
            for i in range(len(jacs)):
                assert list(jacs[i].shape) == [nout, nins[i]]

            # get rmv (the vector-Jacobian product J^T v)
            jacs_rmv = torch.autograd.grad(y, gradparams, grad_outputs=v,
                                           create_graph=True)
            # the jac LinearOperator has shape (nout, nin), so we need to flatten v
            jacs_rmv0 = [jc.rmv(v.view(-1)) for jc in jacs]

            # calculate the mv (the Jacobian-vector product J w, obtained by
            # differentiating the rmv result with respect to v)
            w = [torch.rand_like(p) for p in gradparams]
            jacs_lmv = [
                torch.autograd.grad(jacs_rmv[i], (v,), grad_outputs=w[i],
                                    retain_graph=True)[0]
                for i in range(len(jacs))
            ]
            jacs_lmv0 = [jacs[i].mv(w[i].view(-1)) for i in range(len(jacs))]
            for i in range(len(jacs)):
                assert torch.allclose(jacs_rmv[i].view(-1), jacs_rmv0[i].view(-1))
                assert torch.allclose(jacs_lmv[i].view(-1), jacs_lmv0[i].view(-1))

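
# A minimal standalone sketch (plain torch, illustrative names only; not part
# of the original suite) of the identities test_jac_func checks: one backward
# pass gives the vector-Jacobian product J^T v, which is what rmv() returns,
# and differentiating that result with respect to v gives the Jacobian-vector
# product J w, which is what mv() returns.
def _vjp_jvp_sketch():
    A = torch.randn(2, 3)
    x = torch.randn(3, requires_grad=True)
    v = torch.randn(2, requires_grad=True)
    w = torch.randn(3)
    y = A @ x                                # linear map, so J = A
    vjp, = torch.autograd.grad(y, x, grad_outputs=v, create_graph=True)
    assert torch.allclose(vjp, A.T @ v)      # rmv: J^T v
    jvp, = torch.autograd.grad(vjp, v, grad_outputs=w)
    assert torch.allclose(jvp, A @ w)        # mv: J w
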
# backward() of an implicit-function autograd.Function (e.g. a root finder;
# its enclosing class is not shown here). It pulls the upstream gradient
# through a linear solve against (df/dy)^H, then through a vector-Jacobian
# product of f with respect to the parameters.
def backward(ctx, grad_yout):
    param_sep = ctx.param_sep
    yout = ctx.saved_tensors[0]
    nparams = ctx.nparams
    fcn = ctx.fcn

    # merge the tensor and non-tensor parameters
    tensor_params = ctx.saved_tensors[1:]
    allparams = param_sep.reconstruct_params(tensor_params)
    params = allparams[:nparams]
    objparams = allparams[nparams:]

    # dL/df
    with ctx.fcn.useobjparams(objparams):
        jac_dfdy = jac(fcn, params=(yout, *params), idxs=[0])[0]
        # ctx.bck_options is passed twice on purpose: once unpacked as the
        # forward options of this solve, and once as the options of the
        # solve's own backward pass
        gyfcn = solve(A=jac_dfdy.H, B=-grad_yout.reshape(-1, 1),
                      bck_options=ctx.bck_options, **ctx.bck_options)
        gyfcn = gyfcn.reshape(grad_yout.shape)

        # get the grad for the params
        with torch.enable_grad():
            tensor_params_copy = [p.clone().requires_grad_()
                                  for p in tensor_params]
            allparams_copy = param_sep.reconstruct_params(tensor_params_copy)
            params_copy = allparams_copy[:nparams]
            objparams_copy = allparams_copy[nparams:]
            with ctx.fcn.useobjparams(objparams_copy):
                yfcn = fcn(yout, *params_copy)

        grad_tensor_params = torch.autograd.grad(
            yfcn, tensor_params_copy, grad_outputs=gyfcn,
            create_graph=torch.is_grad_enabled(), allow_unused=True)
        grad_nontensor_params = [None for _ in range(param_sep.nnontensors())]
        grad_params = param_sep.reconstruct_params(grad_tensor_params,
                                                   grad_nontensor_params)

    return (None, None, None, None, None, None, None, *grad_params)

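
# A minimal sketch (plain torch, illustrative names; not part of the library)
# of the implicit-function rule that backward() above implements. At a root
# f(y*, theta) = 0 the chain rule gives dL/dtheta = gy^T (df/dtheta) with
# gy = -(df/dy)^{-H} dL/dy*, which is exactly the solve() followed by
# torch.autograd.grad() in the code above; here the solve collapses to a
# scalar division.
def _implicit_grad_sketch():
    theta = torch.tensor(2.0, requires_grad=True)
    y = torch.sqrt(theta).detach().requires_grad_()  # root of f(y, theta) = y**2 - theta
    f = y ** 2 - theta
    dfdy, = torch.autograd.grad(f, y, create_graph=True)
    grad_yout = torch.tensor(1.0)                    # pretend the upstream dL/dy* is 1
    gy = -grad_yout / dfdy                           # scalar stand-in for solve(A=J^H, B=-grad)
    dtheta, = torch.autograd.grad(f, theta, grad_outputs=gy)
    # matches the analytic derivative d sqrt(theta) / d theta = 1 / (2 sqrt(theta))
    assert torch.allclose(dtheta, 1 / (2 * torch.sqrt(theta)))
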
def test_jac_grad():
    na = 3
    params = getfnparams(na)
    params2 = [torch.rand(1, dtype=dtype).requires_grad_() for p in params]
    jacs = jac(func1, params)
    nout = jacs[0].shape[-2]

    def fcnr(i, v, *params):
        jacs = jac(func1, params)
        return jacs[i].rmv(v.view(-1))

    def fcnl(i, w, *params):
        jacs = jac(func1, params)
        return jacs[i].mv(w.view(-1))

    # fcnr2/fcnl2 check the gradient flowing through the tensors captured by
    # the operator itself: get_pure_function exposes them via objparams(),
    # and scaling them by params2 makes them depend on differentiable leaves
    def fcnr2(i, v, *params2):
        fmv = get_pure_function(jacs[i].rmv)
        params0 = v.view(-1)
        params1 = fmv.objparams()
        params12 = [p1 * p2 for p1, p2 in zip(params1, params2)]
        with fmv.useobjparams(params12):
            return fmv(params0)

    def fcnl2(i, w, *params2):
        fmv = get_pure_function(jacs[i].mv)
        params0 = w.view(-1)
        params1 = fmv.objparams()
        params12 = [p1 * p2 for (p1, p2) in zip(params1, params2)]
        with fmv.useobjparams(params12):
            return fmv(params0)

    v = torch.rand((na,), dtype=dtype, requires_grad=True)
    w = [torch.rand_like(p).requires_grad_() for p in params]
    for i in range(len(jacs)):
        gradcheck(fcnr, (i, v, *params))
        gradgradcheck(fcnr, (i, v, *params))
        gradcheck(fcnl, (i, w[i], *params))
        gradgradcheck(fcnl, (i, w[i], *params))
        gradcheck(fcnr2, (i, v, *params2))
        gradgradcheck(fcnr2, (i, v, *params2))
        gradcheck(fcnl2, (i, w[i], *params2))
        gradgradcheck(fcnl2, (i, w[i], *params2))