def _solve_ABE(A: torch.Tensor, B: torch.Tensor, E: torch.Tensor):
    """
    Densely solve ``(A - E[..., j] * I) x_j = B[..., :, j]`` for every column j.

    Arguments
    ---------
    A: torch.Tensor
        Square matrices, shape ``(*BA, na, na)``.
    B: torch.Tensor
        Right-hand sides, shape ``(*BB, na, ncols)``.
    E: torch.Tensor
        Per-column diagonal shifts, shape ``(*BE, ncols)``.

    Returns
    -------
    torch.Tensor
        The solutions, shape ``(*BAEM, na, ncols)``, where the batch
        dimensions are the broadcast of the batch dimensions of the inputs.
    """
    # A: (*BA, na, na) matrix
    # B: (*BB, na, ncols) matrix
    # E: (*BE, ncols) matrix
    na = A.shape[-1]
    # NOTE(review): normalize_bcast_dims is assumed to pad the batch shapes
    # with leading ones so they all have the same length -- confirm against
    # its definition.
    BA, BB, BE = normalize_bcast_dims(A.shape[:-2], B.shape[:-2], E.shape[:-1])

    # Move the column axis to the front so that each column is paired with its
    # own shift. ``reshape`` (instead of ``view``) also accepts non-contiguous
    # inputs.
    E = E.reshape(1, *BE, E.shape[-1]).transpose(0, -1)  # (ncols, *BE, 1)
    B = B.reshape(1, *BB, *B.shape[-2:]).transpose(0, -1)  # (ncols, *BB, na, 1)

    # NOTE: The line below is very inefficient for large na and ncols
    AE = A - torch.diag_embed(E.repeat_interleave(repeats=na, dim=-1),
                              dim1=-2, dim2=-1)  # (ncols, *BAE, na, na)

    # ``torch.solve(B, A)`` is deprecated and absent from recent PyTorch
    # releases; ``torch.linalg.solve(A, B)`` is its replacement (note the
    # swapped argument order). Fall back to the legacy call only when the
    # replacement is not available.
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve"):
        r = linalg.solve(AE, B)  # (ncols, *BAEM, na, 1)
    else:
        r, _ = torch.solve(B, AE)  # (ncols, *BAEM, na, 1)

    # move the column axis back to the last position
    r = r.transpose(0, -1).squeeze(0)  # (*BAEM, na, ncols)
    return r
def _setup_linear_problem(A: LinearOperator, B: torch.Tensor, E: Optional[torch.Tensor],
                          M: Optional[LinearOperator], batchdims: Sequence[int],
                          posdef: Optional[bool], need_hermit: bool) -> \
        Tuple[Callable[[torch.Tensor], torch.Tensor],
              Callable[[torch.Tensor], torch.Tensor],
              torch.Tensor, bool]:
    """
    Prepare the operator callables and right-hand side of a linear problem.

    Without ``E`` the operator is ``x -> A.mm(x)``. With ``E`` the operator is
    ``x -> A.mm(x) - M.mm(x) * E`` (``M`` is skipped when it is None) and the
    columns of ``B`` are moved to a new leading axis.

    Arguments
    ---------
    A: LinearOperator
        The main linear operator; its ``mm``/``rmm`` methods are used.
    B: torch.Tensor
        The right-hand side, shape ``(*BB, nr, ncols)``.
    E: torch.Tensor or None
        Per-column shifts, shape ``(*BE, ncols)``.
    M: LinearOperator or None
        Operator multiplying the ``E`` term.
    batchdims: Sequence[int]
        Batch dimensions of the random vector used for the posdef estimate.
    posdef: bool or None
        Whether the operator is positive definite. If None, it is estimated.
    need_hermit: bool
        If True and ``A`` or ``M`` is not hermitian, ``posdef`` is forced to
        False so that the ``AT.A`` form is returned.

    Returns
    -------
    tuple
        ``(A_fcn, AT_fcn, B_new, col_swapped)``: the forward callable, the
        transposed callable, the (possibly rearranged) right-hand side and
        a flag telling whether the column axis was moved to the front.
    """
    # get the linear operator (including the MXE part)
    if E is not None:
        # A: (*BA, nr, nr) linop
        # B: (*BB, nr, ncols)
        # E: (*BE, ncols)
        # M: (*BM, nr, nr) linop
        if M is not None:
            _BAs, BBs, BEs, _BMs = normalize_bcast_dims(
                A.shape[:-2], B.shape[:-2], E.shape[:-1], M.shape[:-2])
        else:
            _BAs, BBs, BEs = normalize_bcast_dims(
                A.shape[:-2], B.shape[:-2], E.shape[:-1])

        # put the column axis in front for both E and B
        E_padded = E.reshape(*BEs, *E.shape[-1:])
        E_cols_first = E_padded.unsqueeze(0).transpose(-1, 0)  # (ncols, *BEs, 1)
        E_new = E_cols_first.unsqueeze(-1)  # (ncols, *BEs, 1, 1)
        B_padded = B.reshape(*BBs, *B.shape[-2:])  # (*BBs, nr, ncols)
        B_new = B_padded.unsqueeze(0).transpose(-1, 0)  # (ncols, *BBs, nr, 1)

        def A_fcn(x):
            # x: (ncols, *BX, nr, 1)
            Ax = A.mm(x)  # (ncols, *BAX, nr, 1)
            if M is not None:
                Mx = M.mm(x)  # (ncols, *BMX, nr, 1)
            else:
                Mx = x
            return Ax - Mx * E_new  # (ncols, *BMXE, nr, 1)

        def AT_fcn(x):
            # x: (ncols, *BX, nr, 1)
            ATx = A.rmm(x)
            if M is not None:
                MTx = M.rmm(x)
            else:
                MTx = x
            return ATx - MTx * E_new

        col_swapped = True
    else:
        def A_fcn(x):
            return A.mm(x)

        def AT_fcn(x):
            return A.rmm(x)

        B_new = B
        col_swapped = False

    # a non-hermitian operator is handled through the AT.A form below
    if need_hermit and not (A.is_hermitian and (M is None or M.is_hermitian)):
        posdef = False

    # estimate if it's posdef with power iteration
    if posdef is None:
        nr, ncols = B.shape[-2:]
        if col_swapped:
            x0shape = (ncols, *batchdims, nr, 1)
        else:
            x0shape = (*batchdims, nr, ncols)
        x0 = torch.randn(x0shape, dtype=A.dtype, device=A.device)
        x0 = x0 / x0.norm(dim=-2, keepdim=True)

        largest_eival = _get_largest_eival(A_fcn, x0)  # (*, 1, nc)
        negeival = largest_eival <= 0

        if torch.all(negeival):
            # if the largest eigenvalue is negative, then it's not posdef
            posdef = False
        else:
            offset = torch.clamp(largest_eival, min=0.0)

            def A_shifted(x):
                # the operator shifted by the (non-negative) largest eigenvalue
                return A_fcn(x) - offset * x

            mostneg_eival = _get_largest_eival(A_shifted, x0)  # (*, 1, nc)
            within_offset = -mostneg_eival <= offset
            # NOTE(review): entries flagged by `negeival` are exempted from
            # this check, so a batch mixing such entries with positive ones
            # can still be reported as posdef -- confirm this is intended.
            all_ok = torch.all(torch.logical_or(within_offset, negeival))
            posdef = bool(all_ok.item())

    if posdef:
        return A_fcn, AT_fcn, B_new, col_swapped

    # get the linear operation if it is not a posdef (A -> AT.A)
    def A_new_fcn(x):
        return AT_fcn(A_fcn(x))

    B2 = AT_fcn(B_new)
    return A_new_fcn, A_new_fcn, B2, col_swapped