def solvefcn(amat, bmat):
    # is_hermitian=hermit is required to force the hermitian status in numerical gradient
    alinop = LinearOperator.m(amat, is_hermitian=hermit)
    x = solve(A=alinop, B=bmat,
              fwd_options=fwd_options,
              bck_options=bck_options)
    return x
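# A hypothetical usage sketch (not part of the original test): helpers like solvefcn
# above are typically passed to torch.autograd.gradcheck with float64 leaf tensors,
# while hermit, fwd_options, and bck_options are assumed to come from the enclosing
# test. The names amat_/bmat_ below are made up for this illustration only.
#
#     amat_ = torch.rand((3, 3), dtype=torch.float64)
#     amat_ = (amat_ + amat_.transpose(-2, -1)).requires_grad_()  # hermitian, for hermit=True
#     bmat_ = torch.rand((3, 2), dtype=torch.float64).requires_grad_()
#     torch.autograd.gradcheck(solvefcn, (amat_, bmat_))
#     torch.autograd.gradgradcheck(solvefcn, (amat_, bmat_))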
def solvefcn(amat, bmat, emat):
    alinop = LinearOperator.m(amat)
    x = solve(A=alinop, B=bmat, E=emat,
              **fwd_options,
              bck_options=bck_options)
    return x
def test_solve_nonsquare_err(dtype, device):
    torch.manual_seed(seed)
    mat = torch.rand((3, 2), dtype=dtype, device=device)
    mat2 = torch.rand((3, 3), dtype=dtype, device=device)
    linop = LinearOperator.m(mat)
    linop2 = LinearOperator.m(mat2)
    B = torch.rand((3, 1), dtype=dtype, device=device)

    try:
        res = solve(linop, B)
        assert False, "A RuntimeError must be raised if the A linear operator in solve is not square"
    except RuntimeError:
        pass

    try:
        res = solve(linop2, B, M=linop)
        assert False, "A RuntimeError must be raised if the M linear operator in solve is not square"
    except RuntimeError:
        pass
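# Aside (not from the original test): if pytest is the test runner, the
# try / assert False / except pattern above can equivalently be written as
#
#     with pytest.raises(RuntimeError):
#         solve(linop, B)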
def solvefcn(amat, mmat, bmat, emat):
    mmat = (mmat + mmat.transpose(-2, -1)) * 0.5
    alinop = LinearOperator.m(amat)
    mlinop = LinearOperator.m(mmat)
    x = solve(A=alinop, B=bmat, E=emat, M=mlinop,
              **fwd_options,
              bck_options=bck_options)
    return x
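# Added note (an inference from how solve is used here and in the backward passes
# below, not a statement from the original source): passing E and M to solve
# corresponds to the column-wise generalized problem (A - E_i * M) x_i = b_i,
# i.e. A X - M X diag(E) = B, with the plain A X = B case recovered when E and M
# are omitted.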
def backward(ctx, grad_yout):
    param_sep = ctx.param_sep
    yout = ctx.saved_tensors[0]
    nparams = ctx.nparams
    fcn = ctx.fcn

    # merge the tensor and nontensor parameters
    tensor_params = ctx.saved_tensors[1:]
    allparams = param_sep.reconstruct_params(tensor_params)
    params = allparams[:nparams]
    objparams = allparams[nparams:]

    # dL/df
    with ctx.fcn.useobjparams(objparams):
        jac_dfdy = jac(fcn, params=(yout, *params), idxs=[0])[0]
        gyfcn = solve(A=jac_dfdy.H, B=-grad_yout.reshape(-1, 1),
                      bck_options=ctx.bck_options,
                      **ctx.bck_options)
        gyfcn = gyfcn.reshape(grad_yout.shape)

    # get the grad for the params
    with torch.enable_grad():
        tensor_params_copy = [p.clone().requires_grad_() for p in tensor_params]
        allparams_copy = param_sep.reconstruct_params(tensor_params_copy)
        params_copy = allparams_copy[:nparams]
        objparams_copy = allparams_copy[nparams:]
        with ctx.fcn.useobjparams(objparams_copy):
            yfcn = fcn(yout, *params_copy)

    grad_tensor_params = torch.autograd.grad(
        yfcn, tensor_params_copy,
        grad_outputs=gyfcn,
        create_graph=torch.is_grad_enabled(),
        allow_unused=True)
    grad_nontensor_params = [None for _ in range(param_sep.nnontensors())]
    grad_params = param_sep.reconstruct_params(grad_tensor_params,
                                               grad_nontensor_params)

    return (None, None, None, None, None, None, None, *grad_params)
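# Reading of the math above (added note; notation is informal): with yout = y* a root of
# f(y, *params) = 0, the implicit function theorem gives dy*/dp = -(df/dy)^{-1} (df/dp).
# The vector-Jacobian product is therefore
#     grad_p = (dy*/dp)^T grad_yout = (df/dp)^T [ -(df/dy)^{-T} grad_yout ],
# which matches the code: gyfcn solves the transposed system (jac_dfdy.H) against
# -grad_yout, and torch.autograd.grad on yfcn = fcn(yout, *params_copy) with
# grad_outputs=gyfcn supplies the (df/dp)^T factor.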
def test_solve_mismatch_err(dtype, device):
    torch.manual_seed(seed)
    shapes = [
        #  A       B       M
        ([(3, 3), (2, 1), (3, 3)], "the B shape does not match with A"),
        ([(3, 3), (3, 2), (2, 2)], "the M shape does not match with A"),
    ]
    for (ashape, bshape, mshape), msg in shapes:
        amat = torch.rand(ashape, dtype=dtype, device=device)
        bmat = torch.rand(bshape, dtype=dtype, device=device)
        mmat = torch.rand(mshape, dtype=dtype, device=device) + \
            torch.eye(mshape[-1], dtype=dtype, device=device)
        amat = amat + amat.transpose(-2, -1)
        mmat = mmat + mmat.transpose(-2, -1)

        alinop = LinearOperator.m(amat)
        mlinop = LinearOperator.m(mmat)
        try:
            res = solve(alinop, B=bmat, M=mlinop)
            assert False, "A RuntimeError must be raised if %s" % msg
        except RuntimeError:
            pass
def backward(ctx, grad_evals, grad_evecs):
    # grad_evals: (*BAM, neig)
    # grad_evecs: (*BAM, na, neig)

    # get the variables from ctx
    evals = ctx.evals
    evecs = ctx.evecs
    M = ctx.M
    A = ctx.A

    # the loss function where the gradient will be retrieved
    # warning: if not all params have a connection to the output of A,
    # it could cause an infinite loop because pytorch will keep looking
    # for the *params node and propagate further backward via the `evecs`
    # path. So make sure all the *params are connected in the graph.
    with torch.enable_grad():
        params = [p.clone().requires_grad_() for p in ctx.params]
        with A.uselinopparams(*params):
            loss = A.mm(evecs)  # (*BAM, na, neig)

    # calculate the contributions from the eigenvalues
    gevalsA = grad_evals.unsqueeze(-2) * evecs  # (*BAM, na, neig)

    # calculate the contributions from the eigenvectors
    with M.uselinopparams(*ctx.mparams) if M is not None else dummy_context_manager():
        # orthogonalize the grad_evecs with evecs
        B = ortho(grad_evecs, evecs, dim=-2, M=M, mright=False)
        with A.uselinopparams(*ctx.params):
            gevecs = solve(A, -B, evals, M,
                           fwd_options=ctx.bck_config,
                           bck_options=ctx.bck_config)
        # orthogonalize gevecs w.r.t. evecs
        gevecsA = ortho(gevecs, evecs, dim=-2, M=M, mright=True)

    # accumulate the gradient contributions
    gaccumA = gevalsA + gevecsA
    grad_params = torch.autograd.grad(
        outputs=(loss,),
        inputs=params,
        grad_outputs=(gaccumA,),
        create_graph=torch.is_grad_enabled(),
    )

    grad_mparams = []
    if ctx.M is not None:
        with torch.enable_grad():
            mparams = [p.clone().requires_grad_() for p in ctx.mparams]
            with M.uselinopparams(*mparams):
                mloss = M.mm(evecs)  # (*BAM, na, neig)
        gevalsM = -gevalsA * evals.unsqueeze(-2)
        gevecsM = -gevecsA * evals.unsqueeze(-2)

        # the contribution from the parallel elements
        gevecsM_par = (-0.5 * torch.einsum("...ae,...ae->...e", grad_evecs, evecs)
                       ).unsqueeze(-2) * evecs  # (*BAM, na, neig)

        gaccumM = gevalsM + gevecsM + gevecsM_par
        grad_mparams = torch.autograd.grad(
            outputs=(mloss,),
            inputs=mparams,
            grad_outputs=(gaccumM,),
            create_graph=torch.is_grad_enabled(),
        )

    return (None, None, None, None, None, None, None,
            *grad_params, *grad_mparams)
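# Reading of the math above (added note; notation is informal): for the generalized
# symmetric eigenproblem A x_i = evals_i M x_i with M-orthonormal eigenvectors, the
# eigenvalue part of the VJP enters through gevalsA (effectively contracting
# sum_i grad_evals_i x_i x_i^T against dA via loss = A.mm(evecs)), while the
# eigenvector part requires solving
#     (A - evals_i M) g_i = -B_i,   B_i = grad_evecs_i projected orthogonal to evecs,
# which is the solve(A, -B, evals, M, ...) call; g_i is then re-orthogonalized
# against evecs before being accumulated into gaccumA.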
def solvefcn(amat, bmat, emat, mmat):
    alinop = LinearOperator.m(amat)
    mlinop = LinearOperator.m(mmat)
    x = solve(A=alinop, B=bmat, E=emat, M=mlinop,
              **fwd_options)
    return x
def solvefcn(amat, bmat):
    alinop = LinearOperator.m(amat)
    x = solve(A=alinop, B=bmat, fwd_options=fwd_options)
    return x
def backward(ctx, grad_evals, grad_evecs):
    # grad_evals: (*BAM, neig)
    # grad_evecs: (*BAM, na, neig)

    # get the variables from ctx
    evals, evecs = ctx.saved_tensors[:2]
    na = ctx.na
    amparams = ctx.saved_tensors[2:]
    params = amparams[:na]
    mparams = amparams[na:]

    M = ctx.M
    A = ctx.A

    degen_atol: Optional[float] = ctx.bck_alg_config["degen_atol"]
    degen_rtol: Optional[float] = ctx.bck_alg_config["degen_rtol"]

    # set the default values of degen_*tol
    dtype = evals.dtype
    if degen_atol is None:
        degen_atol = torch.finfo(dtype).eps ** 0.6
    if degen_rtol is None:
        degen_rtol = torch.finfo(dtype).eps ** 0.4

    # check the degeneracy
    if degen_atol > 0 or degen_rtol > 0:
        # idx_degen: (*BAM, neig, neig)
        idx_degen, isdegenerate = _check_degen(evals, degen_atol, degen_rtol)
    else:
        isdegenerate = False
    if not isdegenerate:
        idx_degen = None

    # the loss function where the gradient will be retrieved
    # warning: if not all params have a connection to the output of A,
    # it could cause an infinite loop because pytorch will keep looking
    # for the *params node and propagate further backward via the `evecs`
    # path. So make sure all the *params are connected in the graph.
    with torch.enable_grad():
        params = [p.clone().requires_grad_() for p in params]
        with A.uselinopparams(*params):
            loss = A.mm(evecs)  # (*BAM, na, neig)

    # if degenerate, check the conditions for a finite derivative
    if is_debug_enabled() and isdegenerate:
        xtg = torch.matmul(evecs.transpose(-2, -1), grad_evecs)
        req1 = idx_degen * (xtg - xtg.transpose(-2, -1))
        reqtol = xtg.abs().max() * grad_evecs.shape[-2] * torch.finfo(grad_evecs.dtype).eps

        if not torch.all(torch.abs(req1) <= reqtol):
            # if the requirements are not satisfied, raise a warning
            msg = ("Degeneracy appears but the loss function seems to depend "
                   "strongly on the eigenvectors. The gradient might be incorrect.\n")
            msg += "Eigenvalues:\n%s\n" % str(evals)
            msg += "Degenerate map:\n%s\n" % str(idx_degen)
            msg += "Requirements (should be all 0s):\n%s" % str(req1)
            warnings.warn(MathWarning(msg))

    # calculate the contributions from the eigenvalues
    gevalsA = grad_evals.unsqueeze(-2) * evecs  # (*BAM, na, neig)

    # calculate the contributions from the eigenvectors
    with M.uselinopparams(*mparams) if M is not None else dummy_context_manager():
        # orthogonalize the grad_evecs with evecs
        B = _ortho(grad_evecs, evecs, D=idx_degen, M=M, mright=False)
        with A.uselinopparams(*params):
            gevecs = solve(A, -B, evals, M,
                           bck_options=ctx.bck_config,
                           **ctx.bck_config)  # (*BAM, na, neig)
        # orthogonalize gevecs w.r.t. evecs
        gevecsA = _ortho(gevecs, evecs, D=None, M=M, mright=True)

    # accumulate the gradient contributions
    gaccumA = gevalsA + gevecsA
    grad_params = torch.autograd.grad(
        outputs=(loss,),
        inputs=params,
        grad_outputs=(gaccumA,),
        create_graph=torch.is_grad_enabled(),
    )

    grad_mparams = []
    if ctx.M is not None:
        with torch.enable_grad():
            mparams = [p.clone().requires_grad_() for p in mparams]
            with M.uselinopparams(*mparams):
                mloss = M.mm(evecs)  # (*BAM, na, neig)
        gevalsM = -gevalsA * evals.unsqueeze(-2)
        gevecsM = -gevecsA * evals.unsqueeze(-2)

        # the contribution from the parallel elements
        gevecsM_par = (-0.5 * torch.einsum("...ae,...ae->...e", grad_evecs, evecs)
                       ).unsqueeze(-2) * evecs  # (*BAM, na, neig)

        gaccumM = gevalsM + gevecsM + gevecsM_par
        grad_mparams = torch.autograd.grad(
            outputs=(mloss,),
            inputs=mparams,
            grad_outputs=(gaccumM,),
            create_graph=torch.is_grad_enabled(),
        )

    return (None, None, None, None, None, None, None,
            *grad_params, *grad_mparams)
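# Added note on the degeneracy handling above: _check_degen appears to mark groups of
# eigenvalues that coincide within degen_atol/degen_rtol (defaulting to eps**0.6 and
# eps**0.4 of the working dtype), with idx_degen masking those blocks. Within a
# degenerate block the individual eigenvectors are only defined up to a rotation, so
# the debug check requires evecs^T @ grad_evecs to be symmetric on the block
# (req1 ~ 0); otherwise the loss depends on that arbitrary choice and the MathWarning
# above is raised.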