def llr_compress(images, nb, block_size, overlapping):
    _, nx, ny, nt, ne, _ = images.shape

    # Initialize blocking operator
    block_op = T.ArrayToBlocks(block_size, images.shape, overlapping)

    # Extract spatial patches across images
    patches = block_op(images)
    num_patches = patches.shape[0]

    # Reshape into batch of 2D space-time matrices
    patches = patches.permute(0, 1, 2, 4, 3, 5)
    patches = patches.reshape((num_patches, ne * block_size**2, nt, 2))

    # Perform SVD to get left and right singular vectors
    U, S, V = cplx.svd(patches, compute_uv=True)

    # Truncate singular values and corresponding singular vectors
    U = U[:, :, :nb, :]  # [N, Px*Py*E, B, 2]
    S = S[:, :nb]        # [N, B]
    V = V[:, :, :nb, :]  # [N, T, B, 2]

    # Split sqrt(S) between the spatial and temporal factors, then recombine
    S_sqrt = S.reshape((num_patches, 1, 1, 1, 1, nb, 1)).sqrt()
    L = U.reshape((num_patches, block_size, block_size, 1, ne, nb, 2)) * S_sqrt
    R = V.reshape((num_patches, 1, 1, nt, 1, nb, 2)) * S_sqrt
    blocks = torch.sum(cplx.mul(L, cplx.conj(R)), dim=-2)

    # Map compressed blocks back to image space
    images = block_op(blocks, adjoint=True)

    return images
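
# Hedged usage sketch (not from the source): exercise llr_compress on random
# data shaped like the image series above. The nb/block_size/overlapping
# values and the _demo name are illustrative assumptions; cplx and T are the
# repo's complex-arithmetic and transform modules.
def _demo_llr_compress():
    images = torch.randn(1, 64, 64, 20, 1, 2)  # [1, nx, ny, nt, ne, 2]
    out = llr_compress(images, nb=4, block_size=16, overlapping=False)
    assert out.shape == images.shape  # adjoint blocking returns image shape
    return out
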

def glr_compress(images, nb):
    _, nx, ny, nt, ne, _ = images.shape

    # Reshape the full image series into a single 2D space-time matrix
    images = images.permute(0, 1, 2, 4, 3, 5)
    images = images.reshape((1, nx * ny * ne, nt, 2))

    # Perform SVD to get left and right singular vectors
    U, S, V = cplx.svd(images, compute_uv=True)

    # Truncate singular values and corresponding singular vectors
    U = U[:, :, :nb, :]  # [1, Nx*Ny*E, B, 2]
    S = S[:, :nb]        # [1, B]
    V = V[:, :, :nb, :]  # [1, T, B, 2]

    # Split sqrt(S) between the spatial and temporal factors, then recombine
    S_sqrt = S.reshape((1, 1, 1, 1, 1, nb, 1)).sqrt()
    L = U.reshape((1, nx, ny, 1, ne, nb, 2)) * S_sqrt
    R = V.reshape((1, 1, 1, nt, 1, nb, 2)) * S_sqrt
    images = torch.sum(cplx.mul(L, cplx.conj(R)), dim=-2)

    return images
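
# Hedged sanity check (illustrative assumption): global low-rank compression
# should agree with local low-rank compression when one non-overlapping block
# covers the whole (square) image, since both then truncate the same
# space-time matrix. Equality holds only up to numerical tolerance, and this
# assumes ArrayToBlocks treats a full-image block as an identity partition.
def _check_glr_vs_llr(images, nb):
    _, nx, ny, _, _, _ = images.shape
    assert nx == ny  # single square block spanning the image
    glr = glr_compress(images, nb)
    llr = llr_compress(images, nb, block_size=nx, overlapping=False)
    return torch.allclose(glr, llr, atol=1e-4)
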
    def forward(self, kspace, maps, initial_guess=None, mask=None):
        """
        Args:
            kspace (torch.Tensor): Input tensor of shape [batch_size, height, width, time, num_coils, 2]
            maps (torch.Tensor): Input tensor of shape   [batch_size, height, width,    1, num_coils, num_emaps, 2]
            initial_guess (tuple, optional): Initial (L, R) basis vectors; estimated from the zero-filled reconstruction if None
            mask (torch.Tensor): Input tensor of shape   [batch_size, height, width, time, 1, 1]

        Intermediate variables:
            Spatial basis vectors:  [batch_size, block_size, block_size,    1, num_emaps, num_basis, 2]
            Temporal basis vectors: [batch_size,          1,          1, time,         1, num_basis, 2]

        Returns:
            (torch.Tensor): Output tensor of shape       [batch_size, height, width, time, num_emaps, 2]
        """
        summary_data = {}

        if self.num_emaps != maps.size()[-2]:
            raise ValueError(
                'Incorrect number of ESPIRiT maps! Re-prep data...')
        image_shape = kspace.shape[0:4] + (self.num_emaps, 2)

        if mask is None:
            mask = cplx.get_mask(kspace)

        # Declare linear operators
        A = SenseModel(maps, weights=mask)
        BlockOp = ArrayToBlocks(self.block_size,
                                image_shape,
                                overlapping=self.overlapping)

        # Compute zero-filled image reconstruction
        zf_image = A(kspace, adjoint=True)

        # Get initial guess for L, R basis vectors
        if initial_guess is None:
            L, R = decompose_LR(zf_image, block_op=BlockOp)
        else:
            L, R = initial_guess
        image = self.compose_LR(L, R, BlockOp)

        # save into summary
        summary_data['init_image'] = image

        # Begin unrolled alternating minimization
        for i, (sp_resnet,
                t_resnet) in enumerate(zip(self.sp_resnets, self.t_resnets)):
            # Save previous L,R variables
            L_prev = L
            R_prev = R

            # Compute gradients of f(L, R) = ||Y - A(LR')||^2 w.r.t. L and R:
            #   grad_L = A^H(A(LR') - Y) R,  grad_R = (A^H(A(LR') - Y))' L
            grad_x = BlockOp(A(A(image), adjoint=True) -
                             zf_image).unsqueeze(-2)
            L = torch.sum(cplx.mul(grad_x, R_prev), keepdim=True, dim=3)
            R = torch.sum(cplx.mul(cplx.conj(grad_x), L_prev),
                          keepdim=True,
                          dim=(1, 2, 4))

            # L, R model updates
            step_size_L, step_size_R = self.get_step_sizes(L_prev, R_prev)
            L = L_prev + step_size_L * L
            R = R_prev + step_size_R * R

            # L, R network updates
            L, R = self.reshape_LR(L,
                                   L_prev.shape,
                                   R,
                                   R_prev.shape,
                                   beforeNet=True)
            L, R = sp_resnet(L), t_resnet(R)
            L, R = self.reshape_LR(L,
                                   L_prev.shape,
                                   R,
                                   R_prev.shape,
                                   beforeNet=False)

            # Get current image estimate
            image = self.compose_LR(L, R, BlockOp)

            # Save summary variables
            summary_data['image_%d' % i] = image
            summary_data['step_size_L_%d' % i] = step_size_L
            summary_data['step_size_R_%d' % i] = step_size_R

        return image, summary_data
    def compose_LR(self, L, R, block_op):
        # Form image blocks as L R' and map them back to image space
        patches = torch.sum(cplx.mul(L, cplx.conj(R)), dim=-2)
        return block_op(patches, adjoint=True)
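
# Hedged sketch (assumption; not shown in the source): decompose_LR, called in
# forward() above, presumably factors the zero-filled image blocks via a
# truncated SVD, mirroring llr_compress but returning the (L, R) factors
# instead of recombining them. The explicit nb/block_size/ne/nt arguments are
# illustrative; the real version likely reads them from module attributes,
# since forward() passes only the image and the block operator.
def decompose_LR(images, block_op, nb, block_size, ne, nt):
    patches = block_op(images)
    num_patches = patches.shape[0]
    patches = patches.permute(0, 1, 2, 4, 3, 5)
    patches = patches.reshape((num_patches, ne * block_size**2, nt, 2))
    U, S, V = cplx.svd(patches, compute_uv=True)
    S_sqrt = S[:, :nb].reshape((num_patches, 1, 1, 1, 1, nb, 1)).sqrt()
    L = U[:, :, :nb, :].reshape(
        (num_patches, block_size, block_size, 1, ne, nb, 2)) * S_sqrt
    R = V[:, :, :nb, :].reshape((num_patches, 1, 1, nt, 1, nb, 2)) * S_sqrt
    return L, R
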
# Compress with llr_compress (defined above), then write out the input,
# output, and error images in BART's .cfl format; orig_images is the
# uncompressed input series, assumed loaded earlier in the script
images = llr_compress(orig_images, nb, blk_size, overlapping)
images = cplx.to_numpy(images.squeeze(0))
orig_np = cplx.to_numpy(orig_images.squeeze(0))
cfl.writecfl('svdinit_input', orig_np)
cfl.writecfl('svdinit_output', images)
cfl.writecfl('svdinit_error', orig_np - images)
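
# Hedged extra (assumption, not in the source): a scalar summary of the
# compression error using numpy's flattened 2-norm over the complex arrays
import numpy as np
rel_err = np.linalg.norm(orig_np - images) / np.linalg.norm(orig_np)
print('relative LLR compression error: %.4f' % rel_err)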
    def _forward_op(self, image):
        # Image -> k-space: multiply by coil sensitivities, sum over emaps,
        # Fourier transform, then apply the sampling mask weights
        kspace = cplx.mul(image.unsqueeze(-3), self.maps)
        kspace = self.weights * fft2(kspace.sum(-2))
        return kspace

    def _adjoint_op(self, kspace):
        # K-space -> image: apply mask weights, inverse Fourier transform,
        # multiply by conjugate sensitivities, then sum over coils
        image = ifft2(self.weights * kspace)
        image = cplx.mul(image.unsqueeze(-2), cplx.conj(self.maps))
        return image.sum(-3)
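
    # Hedged sanity check (illustrative, not part of the source): verify that
    # _forward_op/_adjoint_op form an adjoint pair by testing
    # Re<A x, y> == Re<x, A^H y> on random inputs. With the trailing dim
    # holding (real, imag), the real part of the complex inner product equals
    # the plain elementwise dot product of the stacked tensors. The method
    # name and shape arguments here are assumptions.
    def _check_adjoint(self, image_shape, kspace_shape):
        x = torch.randn(image_shape)
        y = torch.randn(kspace_shape)
        lhs = torch.sum(self._forward_op(x) * y)
        rhs = torch.sum(x * self._adjoint_op(y))
        return torch.allclose(lhs, rhs, atol=1e-3)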