import torch
import torchvision

# Repo-local helpers (module paths assumed): complex-tensor utilities
# and MRI transforms (SenseModel, ArrayToBlocks, decompose_LR).
from utils import complex_utils as cplx
from utils import transforms as T

def time_average(data, dim, eps=1e-6, keepdim=True):
    """
    Averages complex data along `dim`, normalizing by the number of
    nonzero (sampled) entries; `eps` guards against division by zero.
    """
    mask = cplx.get_mask(data)
    return data.sum(dim, keepdim=keepdim) / (mask.sum(dim, keepdim=keepdim) + eps)
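
A minimal usage sketch of time_average; the [batch, x, y, time, 2] layout and the time-axis index are illustrative assumptions, not fixed by the source:

# Average a complex image series over its time axis (dim=3 here).
series = torch.randn(1, 64, 64, 20, 2)                 # real/imag stacked in the last dim
avg = time_average(series, dim=3)                      # -> [1, 64, 64, 1, 2]
avg_flat = time_average(series, dim=3, keepdim=False)  # -> [1, 64, 64, 2]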
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    def save_image(image, tag):
        # Reorder to [batch, channels, height, width] for make_grid
        image = image.permute(0, 3, 1, 2)
        # Normalize to [0, 1] for display
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            # Load all data arrays
            input, maps, L, R, target, mean, std, norm = data
            input = input.to(args.device)
            maps = maps.to(args.device)
            L = L.to(args.device).squeeze(0)
            R = R.to(args.device).squeeze(0)
            target = target.to(args.device)

            # Data dimensions (for reference)
            #  image size:  [batch_size, nx,   ny, nt, nmaps, 2]
            #  kspace size: [batch_size, nkx, nky, nt, ncoils, 2]
            #  maps size:   [batch_size, nkx,  ny,  1, ncoils, nmaps, 2]

            # Compute DL recon
            output, summary_data = model(input, maps, initial_guess=(L, R))

            # Get initial guess
            init = summary_data['init_image']

            # Slice images: keep time frame 10 and the first ESPIRiT map
            init = init[:, :, :, 10, 0, None]
            output = output[:, :, :, 10, 0, None]
            target = target[:, :, :, 10, 0, None]
            mask = cplx.get_mask(input[:, -1, :, :, 0, :])  # [b, y, t, 2]

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images), '%s_Images' % tag)
            save_image(cplx.angle(all_images), '%s_Phase' % tag)
            save_image(cplx.abs(output - target), '%s_Error' % tag)
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            # Save scalars to summary
            for i in range(args.num_grad_steps):
                step_size_L = summary_data['step_size_L_%d' % i]
                writer.add_scalar('step_sizes/L%d' % i, step_size_L.item(), epoch)
                step_size_R = summary_data['step_size_R_%d' % i]
                writer.add_scalar('step_sizes/R%d' % i, step_size_R.item(), epoch)

            break
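
A hypothetical wiring of visualize into a training loop; the writer setup, loader names, and args.num_epochs are assumptions standing in for the repo's training script:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='summary')
for epoch in range(args.num_epochs):
    ...  # run one epoch of training / validation here
    visualize(args, epoch, model, train_loader, writer, is_training=True)
    visualize(args, epoch, model, val_loader, writer, is_training=False)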
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    def save_image(image, tag, shape=None):
        # Reorder to [batch, channels, height, width] for make_grid
        image = image.permute(0, 3, 1, 2)
        # Normalize to [0, 1] for display
        image -= image.min()
        image /= image.max()
        # Optionally resize panels to a common display shape
        if shape is not None:
            image = torch.nn.functional.interpolate(image,
                                                    size=shape,
                                                    mode='bilinear',
                                                    align_corners=True)
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            # Load all data arrays
            input, maps, target, mean, std, norm = data
            input = input.to(args.device)
            maps = maps.to(args.device)
            target = target.to(args.device)

            # Compute zero-filled recon
            A = T.SenseModel(maps)
            zf = A(input, adjoint=True)

            # Compute DL recon
            output = model(input, maps)

            # Slice images [b, y, z, e, 2]: keep the first emap
            init = zf[:, :, :, 0, None]
            output = output[:, :, :, 0, None]
            target = target[:, :, :, 0, None]
            mask = cplx.get_mask(input[:, :, :, 0])  # [b, y, z, 2] (first coil)

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images),
                       '%s_Images' % tag,
                       shape=[320, 3 * 320])
            save_image(cplx.angle(all_images),
                       '%s_Phase' % tag,
                       shape=[320, 3 * 320])
            save_image(cplx.abs(output - target),
                       '%s_Error' % tag,
                       shape=[320, 320])
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            break
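
For reference, the zero-filled starting point computed by A(input, adjoint=True) above is the adjoint of the SENSE forward model applied to the measured k-space:

    x_{\mathrm{zf}} = A^H y = \sum_c S_c^H F^H y_c

where S_c are the ESPIRiT coil sensitivity maps, F^H is the inverse Fourier transform, and y_c is the k-space measured by coil c.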
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    def save_image(image, tag):
        # Reorder to [batch, channels, height, width] for make_grid
        image = image.permute(0, 3, 1, 2)
        # Normalize to [0, 1] for display
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            # Load all data arrays
            input, maps, init, target, mean, std, norm = data
            input = input.to(args.device)
            maps = maps.to(args.device)
            init = init.to(args.device)
            target = target.to(args.device)

            # Data dimensions (for reference)
            #  image size:  [batch_size, nx,   ny, nt, nmaps, 2]
            #  kspace size: [batch_size, nkx, nky, nt, ncoils, 2]
            #  maps size:   [batch_size, nkx,  ny,  1, ncoils, nmaps, 2]

            # Initialize signal model
            A = T.SenseModel(maps)

            # Compute DL recon
            output = model(input, maps, init_image=init)

            # Slice images: keep time frame 10 and the first ESPIRiT map
            init = init[:, :, :, 10, 0, None]
            output = output[:, :, :, 10, 0, None]
            target = target[:, :, :, 10, 0, None]
            mask = cplx.get_mask(input[:, -1, :, :, 0, :])  # [b, y, t, 2]

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images), '%s_Images' % tag)
            save_image(cplx.angle(all_images), '%s_Phase' % tag)
            save_image(cplx.abs(output - target), '%s_Error' % tag)
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            break
    def forward(self, kspace, maps, init_image=None, mask=None):
        """
        Args:
            kspace (torch.Tensor): Input tensor of shape [batch_size, height, width, time, num_coils, 2]
            maps (torch.Tensor): Input tensor of shape   [batch_size, height, width,    1, num_coils, num_emaps, 2]
            init_image (torch.Tensor, optional): Initial image estimate; defaults to the zero-filled recon
            mask (torch.Tensor, optional): Sampling mask of shape [batch_size, height, width, time, 1, 1];
                inferred from the k-space nonzeros if not given

        Returns:
            (torch.Tensor): Output tensor of shape       [batch_size, height, width, time, num_emaps, 2]
        """

        if self.num_emaps != maps.size()[-2]:
            raise ValueError(
                'Incorrect number of ESPIRiT maps! Re-prep data...')

        if mask is None:
            mask = cplx.get_mask(kspace)  # infer sampling mask from k-space nonzeros
        kspace *= mask  # apply sampling mask (note: in-place on the input tensor)

        # Get data dimensions
        dims = tuple(kspace.size())

        # Declare signal model
        A = SenseModel(maps, weights=mask)

        # Compute zero-filled image reconstruction
        zf_image = A(kspace, adjoint=True)
        image = zf_image if init_image is None else init_image

        # Begin unrolled proximal gradient descent
        for resnet, step_size in zip(self.resnets, self.step_sizes):
            # Data-consistency update: the gradient of 0.5*||A(x) - y||^2 is
            # A^H(A(x) - y); step_size is a learned scalar (presumably trained
            # to be negative, making this a descent step)
            grad_x = A(A(image), adjoint=True) - zf_image
            image = image + step_size * grad_x

            # Prox update: fold the emap/complex dims into a channel axis,
            # run the ResNet, then restore the original image layout
            image = image.reshape(dims[0:4] + (self.num_emaps * 2, )).permute(
                0, 4, 3, 2, 1)
            image = resnet(image)
            image = image.permute(0, 4, 3, 2,
                                  1).reshape(dims[0:4] + (self.num_emaps, 2))

        return image
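
Each unrolled iteration above is one proximal gradient step on the SENSE least-squares objective f(x) = \tfrac{1}{2}\|Ax - y\|_2^2, with the ResNet acting as a learned proximal operator. A sketch in LaTeX (the learned step size \alpha_k absorbs the usual negative sign):

    x^{(k+1)} = \mathcal{R}_{\theta_k}\!\left(x^{(k)} + \alpha_k\, A^H\big(A x^{(k)} - y\big)\right), \qquad x^{(0)} = A^H y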
    def forward(self, kspace, maps, initial_guess=None, mask=None):
        """
        Args:
            kspace (torch.Tensor): Input tensor of shape [batch_size, height, width, time, num_coils, 2]
            maps (torch.Tensor): Input tensor of shape   [batch_size, height, width,    1, num_coils, num_emaps, 2]
            initial_guess (tuple, optional): Initial (L, R) basis vectors; estimated from the
                zero-filled recon via decompose_LR if not given
            mask (torch.Tensor, optional): Sampling mask of shape [batch_size, height, width, time, 1, 1];
                inferred from the k-space nonzeros if not given

        Intermediate variables:
            Spatial basis vectors:  [batch_size, block_size, block_size,    1, num_emaps, num_basis, 2]
            Temporal basis vectors: [batch_size,          1,          1, time,         1, num_basis, 2]

        Returns:
            (torch.Tensor): Output tensor of shape       [batch_size, height, width, time, num_emaps, 2]
        """
        summary_data = {}

        if self.num_emaps != maps.size()[-2]:
            raise ValueError(
                'Incorrect number of ESPIRiT maps! Re-prep data...')
        image_shape = kspace.shape[0:4] + (self.num_emaps, 2)

        if mask is None:
            mask = cplx.get_mask(kspace)

        # Declare linear operators
        A = SenseModel(maps, weights=mask)
        BlockOp = ArrayToBlocks(self.block_size,
                                image_shape,
                                overlapping=self.overlapping)

        # Compute zero-filled image reconstruction
        zf_image = A(kspace, adjoint=True)

        # Get initial guess for L, R basis vectors
        if initial_guess is None:
            L, R = decompose_LR(zf_image, block_op=BlockOp)
        else:
            L, R = initial_guess
        image = self.compose_LR(L, R, BlockOp)

        # Save initial image estimate into summary
        summary_data['init_image'] = image

        # Begin unrolled alternating minimization
        for i, (sp_resnet,
                t_resnet) in enumerate(zip(self.sp_resnets, self.t_resnets)):
            # Save previous L,R variables
            L_prev = L
            R_prev = R

            # Compute gradients of 0.5 * ||A(LR') - Y||_2^2 w.r.t. L and R:
            # with G the blockified image-domain gradient A^H(A(x) - y),
            # grad_L sums G * R over time and grad_R sums conj(G) * L over space
            grad_x = BlockOp(A(A(image), adjoint=True) -
                             zf_image).unsqueeze(-2)
            L = torch.sum(cplx.mul(grad_x, R_prev), keepdim=True, dim=3)
            R = torch.sum(cplx.mul(cplx.conj(grad_x), L_prev),
                          keepdim=True,
                          dim=(1, 2, 4))

            # L, R model updates; step sizes are learned scalars (presumably
            # trained to be negative, making these descent steps)
            step_size_L, step_size_R = self.get_step_sizes(L_prev, R_prev)
            L = L_prev + step_size_L * L
            R = R_prev + step_size_R * R

            # L, R network updates
            L, R = self.reshape_LR(L,
                                   L_prev.shape,
                                   R,
                                   R_prev.shape,
                                   beforeNet=True)
            L, R = sp_resnet(L), t_resnet(R)
            L, R = self.reshape_LR(L,
                                   L_prev.shape,
                                   R,
                                   R_prev.shape,
                                   beforeNet=False)

            # Get current image estimate
            image = self.compose_LR(L, R, BlockOp)

            # Save summary variables
            summary_data['image_%d' % i] = image
            summary_data['step_size_L_%d' % i] = step_size_L
            summary_data['step_size_R_%d' % i] = step_size_R

        return image, summary_data
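
A sketch of the alternating updates above (a reading of the code, up to complex-conjugation conventions), for the objective f(L, R) = \tfrac{1}{2}\|A(x(L, R)) - y\|_2^2 with the image x composed blockwise from L R':

    L^{(k+1)} = \mathrm{CNN}_{sp}\big(L^{(k)} + \alpha_L\, G^{(k)} R^{(k)}\big), \qquad
    R^{(k+1)} = \mathrm{CNN}_{t}\big(R^{(k)} + \alpha_R\, (G^{(k)})^H L^{(k)}\big)

where G^{(k)} = \mathcal{B}\big(A^H(A\, x^{(k)} - y)\big) is the blockified image-domain gradient and \alpha_L, \alpha_R are the learned step sizes returned by get_step_sizes.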