def llr_compress(images, nb, block_size, overlapping):
    """Compress an image series with a locally low-rank (LLR) approximation.

    The series is cut into square spatial blocks with ``T.ArrayToBlocks``.
    Each block is reshaped into a (block_size**2 * ne) x nt matrix and
    truncated to its ``nb`` largest singular values. The truncated blocks are
    then mapped back to image space with the adjoint blocking operator.

    Args:
        images: Complex tensor with a trailing real/imag axis of size 2. The
            blocking operator is built from ``images.shape``. Presumably
            [1, nx, ny, nt, ne, 2] -- TODO confirm against callers.
        nb: Number of singular values/vectors kept per block. It is clamped
            to the number of singular values actually available.
        block_size: Edge length of the square spatial blocks.
        overlapping: Forwarded unchanged to ``T.ArrayToBlocks``.

    Returns:
        The low-rank image tensor produced by the adjoint blocking operator.
    """
    # Initialize blocking operator
    block_op = T.ArrayToBlocks(block_size, images.shape, overlapping)

    # Extract spatial patches across images:
    # [num_patches, block_size, block_size, nt, ne, 2]
    patches = block_op(images)
    # `nt`/`ne` were previously referenced without ever being defined here
    # (NameError). Read them from the patch layout that the permute/reshape
    # below already relies on. `num_patches` also avoids shadowing `np`.
    num_patches, _, _, nt, ne, _ = patches.shape

    # Reshape into a batch of 2D matrices: [num_patches, bs*bs*ne, nt, 2]
    patches = patches.permute(0, 1, 2, 4, 3, 5)
    patches = patches.reshape((num_patches, ne * block_size**2, nt, 2))

    # Perform SVD to get left and right singular vectors
    U, S, V = cplx.svd(patches, compute_uv=True)

    # Keeping more components than exist would break the reshapes below.
    nb = min(nb, S.shape[-1])

    # Truncate singular values and corresponding singular vectors
    U = U[:, :, :nb, :]  # [N, Px*Py*E, B, 2]
    S = S[:, :nb]  # [N, B]
    V = V[:, :, :nb, :]  # [N, T, B, 2]

    # Split sqrt(S) between the spatial (L) and temporal (R) factors, then
    # form sum_b L_b * conj(R_b) over the basis axis.
    S_sqrt = S.reshape((num_patches, 1, 1, 1, 1, nb, 1)).sqrt()
    L = U.reshape((num_patches, block_size, block_size, 1, ne, nb, 2)) * S_sqrt
    R = V.reshape((num_patches, 1, 1, nt, 1, nb, 2)) * S_sqrt
    blocks = torch.sum(cplx.mul(L, cplx.conj(R)), dim=-2)

    return block_op(blocks, adjoint=True)
def glr_compress(images, nb):
    """Compress an image series with a globally low-rank (GLR) approximation.

    For each batch entry, the whole series is reshaped into one
    (nx * ny * ne) x nt matrix and truncated to its ``nb`` largest singular
    values.

    Args:
        images: Tensor unpacked as [batch, nx, ny, nt, ne, 2], where the
            trailing axis holds the real/imag parts. The original code only
            supported batch == 1, because it hard-coded 1 in every reshape;
            any batch size now works.
        nb: Number of singular values/vectors kept. It is clamped to the
            number of singular values actually available.

    Returns:
        Low-rank tensor of shape [batch, nx, ny, nt, ne, 2].
    """
    batch_size, nx, ny, nt, ne, _ = images.shape
    images = images.permute(0, 1, 2, 4, 3, 5)
    images = images.reshape((batch_size, nx * ny * ne, nt, 2))

    # Perform SVD to get left and right singular vectors
    U, S, V = cplx.svd(images, compute_uv=True)

    # Keeping more components than exist would break the reshapes below.
    nb = min(nb, S.shape[-1])

    # Truncate singular values and corresponding singular vectors
    U = U[:, :, :nb, :]  # [batch, Nx*Ny*E, nb, 2]
    S = S[:, :nb]  # [batch, nb]
    V = V[:, :, :nb, :]  # [batch, T, nb, 2]

    # Split sqrt(S) between spatial (L) and temporal (R) factors and
    # recombine them as sum_b L_b * conj(R_b) over the basis axis.
    S_sqrt = S.reshape((batch_size, 1, 1, 1, 1, nb, 1)).sqrt()
    L = U.reshape((batch_size, nx, ny, 1, ne, nb, 2)) * S_sqrt
    R = V.reshape((batch_size, 1, 1, nt, 1, nb, 2)) * S_sqrt
    return torch.sum(cplx.mul(L, cplx.conj(R)), dim=-2)
def forward(self, kspace, maps, initial_guess=None, mask=None):
    """Reconstruct an image via unrolled alternating updates of L and R.

    Args:
        kspace (torch.Tensor): Input tensor of shape
            [batch_size, height, width, time, num_coils, 2]
        maps (torch.Tensor): Input tensor of shape
            [batch_size, height, width, 1, num_coils, num_emaps, 2]
        initial_guess: Optional ``(L, R)`` pair. When omitted, the factors
            are initialized with ``decompose_LR`` on the zero-filled image.
        mask (torch.Tensor): Input tensor of shape
            [batch_size, height, width, time, 1, 1]. When omitted, it is
            obtained from ``cplx.get_mask(kspace)``.

    Intermediate variables:
        Spatial basis vectors:
            [batch_size, block_size, block_size, 1, num_emaps, num_basis, 2]
        Temporal basis vectors:
            [batch_size, 1, 1, time, 1, num_basis, 2]

    Returns:
        A tuple ``(image, summary)``:

        - ``image`` has shape [batch_size, height, width, time, num_emaps, 2].
        - ``summary`` maps 'init_image', 'image_<i>', 'step_size_L_<i>' and
          'step_size_R_<i>' to the corresponding intermediate values.

    Raises:
        ValueError: If ``maps`` does not carry ``self.num_emaps`` maps.
    """
    summary = {}
    if maps.size()[-2] != self.num_emaps:
        raise ValueError(
            'Incorrect number of ESPIRiT maps! Re-prep data...')

    image_shape = kspace.shape[0:4] + (self.num_emaps, 2)
    if mask is None:
        mask = cplx.get_mask(kspace)

    # Encoding operator (sensitivity maps + sampling weights) and blocking.
    sense = SenseModel(maps, weights=mask)
    block_op = ArrayToBlocks(self.block_size, image_shape,
                             overlapping=self.overlapping)

    # Zero-filled reconstruction: adjoint of the encoding operator on kspace.
    zf_image = sense(kspace, adjoint=True)

    # Seed the spatial (L) and temporal (R) factors.
    if initial_guess is not None:
        L, R = initial_guess
    else:
        L, R = decompose_LR(zf_image, block_op=block_op)
    image = self.compose_LR(L, R, block_op)
    summary['init_image'] = image

    unrolled = zip(self.sp_resnets, self.t_resnets)
    for it, (sp_net, t_net) in enumerate(unrolled):
        L_prev, R_prev = L, R

        # Gradients of ||Y - A L R'||_2 with respect to L and R, evaluated
        # at the current image estimate.
        residual = sense(sense(image), adjoint=True) - zf_image
        grad_x = block_op(residual).unsqueeze(-2)
        grad_L = torch.sum(cplx.mul(grad_x, R_prev), keepdim=True, dim=3)
        grad_R = torch.sum(cplx.mul(cplx.conj(grad_x), L_prev),
                           keepdim=True, dim=(1, 2, 4))

        # NOTE(review): the gradient is *added* to the previous factors;
        # presumably get_step_sizes returns negative steps -- confirm.
        step_L, step_R = self.get_step_sizes(L_prev, R_prev)
        L = L_prev + step_L * grad_L
        R = R_prev + step_R * grad_R

        # Learned refinement of each factor (spatial net, then temporal net).
        L, R = self.reshape_LR(L, L_prev.shape, R, R_prev.shape,
                               beforeNet=True)
        L, R = sp_net(L), t_net(R)
        L, R = self.reshape_LR(L, L_prev.shape, R, R_prev.shape,
                               beforeNet=False)

        image = self.compose_LR(L, R, block_op)

        summary['image_%d' % it] = image
        summary['step_size_L_%d' % it] = step_L
        summary['step_size_R_%d' % it] = step_R

    return image, summary
def compose_LR(self, L, R, block_op):
    """Assemble an image from spatial (L) and temporal (R) block factors.

    Computes ``L * conj(R)``, sums it over the second-to-last (basis) axis
    and maps the resulting blocks back to image space with the adjoint of
    ``block_op``.
    """
    products = cplx.mul(L, cplx.conj(R))
    blocks = torch.sum(products, dim=-2)
    return block_op(blocks, adjoint=True)
# Extract spatial patches across images patches = block_op(images) np = patches.shape[0] # Reshape into batch of 2D matrices patches = patches.permute(0, 1, 2, 4, 3, 5) patches = patches.reshape((np, ne * blk_size**2, nt, 2)) # Perform SVD to get left and right singular vectors U, S, V = cplx.svd(patches, compute_uv=True) # Truncate singular values and corresponding singular vectors U = U[:, :, :nb, :] # [N, Px*Py*E, B, 2] S = S[:, :nb] # [N, B] V = V[:, :, :nb, :] # [N, T, B, 2] # Combine and reshape matrices S_sqrt = S.reshape((np, 1, 1, 1, 1, nb, 1)).sqrt() L = U.reshape((np, blk_size, blk_size, 1, ne, nb, 2)) * S_sqrt R = V.reshape((np, 1, 1, nt, 1, nb, 2)) * S_sqrt blocks = torch.sum(cplx.mul(L, cplx.conj(R)), dim=-2) images = block_op(blocks, adjoint=True) # Write out images images = cplx.to_numpy(images.squeeze(0)) cfl.writecfl('svdinit_input', orig_images) cfl.writecfl('svdinit_output', images) cfl.writecfl('svdinit_error', orig_images - images)
def _forward_op(self, image):
    """Apply the forward encoding model to ``image``.

    Inserts a singleton axis at -3 and multiplies by ``self.maps``. The
    product is summed over axis -2 (presumably the ESPIRiT-map axis), passed
    through ``fft2`` and scaled by ``self.weights``.
    """
    coil_images = cplx.mul(image.unsqueeze(-3), self.maps).sum(-2)
    return self.weights * fft2(coil_images)
def _adjoint_op(self, kspace):
    """Apply the adjoint encoding model to ``kspace``.

    Scales by ``self.weights`` and applies ``ifft2``. The result is then
    multiplied (with a singleton axis at -2) by ``conj(self.maps)`` and
    summed over axis -3 (presumably the coil axis).
    """
    coil_images = ifft2(self.weights * kspace)
    weighted = cplx.mul(coil_images.unsqueeze(-2), cplx.conj(self.maps))
    return weighted.sum(-3)