Example #1
def compute_metrics(args, model, data):
    # Load input, sensitivity maps, and target images onto device
    input, maps, init, target, mean, std, norm = data
    input = input.to(args.device)
    maps = maps.to(args.device)
    init = init.to(args.device)
    target = target.to(args.device)
    mean = mean.to(args.device)
    std = std.to(args.device)

    # Forward pass through network
    output = model(input, maps, init_image=init)

    # Undo normalization from pre-processing
    output = output * std + mean
    target = target * std + mean
    scale = cplx.abs(target).max()

    # Compute image quality metrics from complex-valued images
    cplx_error = cplx.abs(output - target)
    cplx_l1 = torch.mean(cplx_error)
    cplx_l2 = torch.sqrt(torch.mean(cplx_error**2))
    cplx_psnr = 20 * torch.log10(scale / cplx_l2)

    # Compute image quality metrics from magnitude images
    mag_error = torch.abs(cplx.abs(output) - cplx.abs(target))
    mag_l1 = torch.mean(mag_error)
    mag_l2 = torch.sqrt(torch.mean(mag_error**2))
    mag_psnr = 20 * torch.log10(scale / mag_l2)

    return cplx_l1, cplx_l2, cplx_psnr, mag_psnr
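
A note on the cplx helpers used throughout these examples: complex images appear to be stored as real tensors with a trailing dimension of size 2 (real, imag), matching the shape comments later in this file. Below is a minimal standalone sketch of the magnitude and PSNR computations above, with a hypothetical stand-in for cplx.abs (not the repo's actual module):

import torch

def cplx_abs(x):
    # Magnitude of a [..., 2] real/imag tensor (stand-in for cplx.abs).
    return torch.sqrt(x[..., 0]**2 + x[..., 1]**2)

def psnr(output, target):
    # Peak SNR in dB, with the peak taken as the target's max magnitude,
    # matching the definition used in compute_metrics above.
    rmse = torch.sqrt(torch.mean(cplx_abs(output - target)**2))
    return 20 * torch.log10(cplx_abs(target).max() / rmse)

target = torch.randn(1, 64, 64, 2)
output = target + 0.01 * torch.randn_like(target)
print(psnr(output, target).item())  # high PSNR for a small perturbation
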
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    def save_image(image, tag):
        image = image.permute(0,3,1,2)
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            # Load all data arrays
            input, maps, L, R, target, mean, std, norm = data
            input = input.to(args.device)
            maps = maps.to(args.device)
            L = L.to(args.device).squeeze(0)
            R = R.to(args.device).squeeze(0)
            target = target.to(args.device)

            # Data dimensions (for my own reference)
            #  image size:  [batch_size, nx,   ny, nt, nmaps, 2]
            #  kspace size: [batch_size, nkx, nky, nt, ncoils, 2]
            #  maps size:   [batch_size, nkx,  ny,  1, ncoils, nmaps, 2]

            # Compute DL recon
            output, summary_data = model(input, maps, initial_guess=(L, R))

            # Get initial guess
            init = summary_data['init_image']

            # Slice images
            init = init[:,:,:,10,0,None]
            output = output[:,:,:,10,0,None]
            target = target[:,:,:,10,0,None]
            mask = cplx.get_mask(input[:,-1,:,:,0,:]) # [b, y, t, 2]

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images), '%s_Images' % tag)
            save_image(cplx.angle(all_images), '%s_Phase' % tag)
            save_image(cplx.abs(output - target), '%s_Error' % tag)
            save_image(mask.permute(0,2,1,3), '%s_Mask' % tag)

            # Save scalars to summary
            for i in range(args.num_grad_steps):
                step_size_L = summary_data['step_size_L_%d' % i]
                writer.add_scalar('step_sizes/L%d' % i, step_size_L.item(), epoch)
                step_size_R = summary_data['step_size_R_%d' % i]
                writer.add_scalar('step_sizes/R%d' % i, step_size_R.item(), epoch)

            break  # only visualize the first batch
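
cplx.get_mask is not shown in these snippets; presumably it recovers the sampling mask from the undersampled k-space by marking nonzero entries. A hypothetical sketch under that assumption:

import torch

def get_mask(kspace, eps=0.0):
    # Mark k-space locations with nonzero magnitude; keep a trailing
    # real/imag dim so the mask can be sliced and displayed like the
    # other [..., 2] tensors (hypothetical re-implementation).
    mag = torch.sqrt(kspace[..., 0]**2 + kspace[..., 1]**2)
    mask = (mag > eps).float()
    return torch.stack((mask, torch.zeros_like(mask)), dim=-1)
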
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    def save_image(image, tag, shape=None):
        image = image.permute(0, 3, 1, 2)
        image -= image.min()
        image /= image.max()
        if shape is not None:
            image = torch.nn.functional.interpolate(image,
                                                    size=shape,
                                                    mode='bilinear',
                                                    align_corners=True)
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            # Load all data arrays
            input, maps, target, mean, std, norm = data
            input = input.to(args.device)
            maps = maps.to(args.device)
            target = target.to(args.device)

            # Compute zero-filled recon
            A = T.SenseModel(maps)
            zf = A(input, adjoint=True)

            # Compute DL recon
            output = model(input, maps)

            # Slice images [b, y, z, e, 2]
            init = zf[:, :, :, 0, None]
            output = output[:, :, :, 0, None]
            target = target[:, :, :, 0, None]
            mask = cplx.get_mask(input[:, :, :, 0])  # [b, y, t, 2]

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images),
                       '%s_Images' % tag,
                       shape=[320, 3 * 320])
            save_image(cplx.angle(all_images),
                       '%s_Phase' % tag,
                       shape=[320, 3 * 320])
            save_image(cplx.abs(output - target),
                       '%s_Error' % tag,
                       shape=[320, 320])
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            break  # only visualize the first batch
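
The resize to shape=[320, 3 * 320] keeps each of the three concatenated panels (init | output | target) square, since they are stacked along what becomes the image width after the permute. A quick standalone check of the interpolate + make_grid plumbing:

import torch
import torchvision

image = torch.rand(1, 1, 160, 3 * 160)   # [b, c, h, w]: three panels side by side
image = torch.nn.functional.interpolate(
    image, size=[320, 3 * 320], mode='bilinear', align_corners=True)
grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
print(grid.shape)  # torch.Size([3, 324, 964]) with the default 2-pixel padding
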
def preprocess(kspace, maps, args):
    # Batch size dimension must be the same!
    assert kspace.shape[0] == maps.shape[0]
    batch_size = kspace.shape[0]

    # Convert everything from numpy arrays to tensors
    kspace = cplx.to_tensor(kspace)
    maps = cplx.to_tensor(maps)

    # Initialize SENSE signal model (from ESPIRiT maps)
    A = T.SenseModel(maps)

    # Compute normalization factor (95th percentile of magnitude in view-shared data)
    averaged_kspace = T.time_average(kspace, dim=3)
    image = A(averaged_kspace, adjoint=True)
    magnitude_vals = cplx.abs(image).reshape(batch_size, -1)
    k = int(round(0.05 * magnitude_vals[0].numel()))
    scale = torch.min(torch.topk(magnitude_vals, k, dim=1).values,
                      dim=1).values

    # Normalize k-space data
    kspace /= scale[:, None, None, None, None, None]

    # Compute network initialization
    if args.slwin_init:
        init_image = A(T.sliding_window(kspace, dim=3, window_size=5),
                       adjoint=True)
    else:
        init_image = A(kspace, adjoint=True)

    return kspace.unsqueeze(1), maps.unsqueeze(1), init_image.unsqueeze(1)
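
preprocess assumes two temporal helpers from the transform module T that are not shown here: time_average (collapse the time dimension of view-shared k-space) and sliding_window (a centered temporal moving average). Plausible sketches, assuming unsampled k-space entries are exactly zero:

import torch

def time_average(kspace, dim):
    # Average each k-space location over time, counting only sampled
    # (nonzero) entries to avoid biasing toward zero (hypothetical sketch).
    mask = (kspace != 0).float()
    count = mask.sum(dim, keepdim=True).clamp(min=1)
    return kspace.sum(dim, keepdim=True) / count

def sliding_window(kspace, dim, window_size):
    # View-share neighboring frames with a centered moving average.
    out = torch.zeros_like(kspace)
    nt = kspace.size(dim)
    for t in range(nt):
        lo = max(0, t - window_size // 2)
        hi = min(nt, t + window_size // 2 + 1)
        window = kspace.narrow(dim, lo, hi - lo)
        out.narrow(dim, t, 1).copy_(time_average(window, dim))
    return out
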
Example #5
def visualize(args, epoch, model, data_loader, writer, is_training=True):
    def save_image(image, tag):
        image = image.permute(0, 3, 1, 2)
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=1, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            # Load all data arrays
            input, maps, init, target, mean, std, norm = data
            input = input.to(args.device)
            maps = maps.to(args.device)
            init = init.to(args.device)
            target = target.to(args.device)

            # Data dimensions (for my own reference)
            #  image size:  [batch_size, nx,   ny, nt, nmaps, 2]
            #  kspace size: [batch_size, nkx, nky, nt, ncoils, 2]
            #  maps size:   [batch_size, nkx,  ny,  1, ncoils, nmaps, 2]

            # Initialize signal model
            A = T.SenseModel(maps)

            # Compute DL recon
            output = model(input, maps, init_image=init)

            # Slice images
            init = init[:, :, :, 10, 0, None]
            output = output[:, :, :, 10, 0, None]
            target = target[:, :, :, 10, 0, None]
            mask = cplx.get_mask(input[:, -1, :, :, 0, :])  # [b, y, t, 2]

            # Save images to summary
            tag = 'Train' if is_training else 'Val'
            all_images = torch.cat((init, output, target), dim=2)
            save_image(cplx.abs(all_images), '%s_Images' % tag)
            save_image(cplx.angle(all_images), '%s_Phase' % tag)
            save_image(cplx.abs(output - target), '%s_Error' % tag)
            save_image(mask.permute(0, 2, 1, 3), '%s_Mask' % tag)

            break  # only visualize the first batch
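
All of these examples lean on T.SenseModel as the forward/adjoint SENSE signal model: the forward operator maps an image to multi-coil k-space, and the adjoint coil-combines an inverse FFT. A minimal single-frame sketch using native complex tensors and orthonormal FFTs; the actual repo stores real/imag in a trailing dim and handles the extra batch/time/map dimensions:

import torch

class SenseModelSketch:
    """Hypothetical single-frame SENSE operator.
    maps: [ny, nx, ncoils] complex coil sensitivities."""
    def __init__(self, maps):
        self.maps = maps

    def __call__(self, x, adjoint=False):
        if adjoint:
            # k-space [ny, nx, ncoils] -> image [ny, nx]:
            # inverse FFT per coil, then conjugate coil combination.
            img = torch.fft.ifft2(x, dim=(0, 1), norm='ortho')
            return (img * self.maps.conj()).sum(dim=-1)
        # image [ny, nx] -> k-space [ny, nx, ncoils]:
        # modulate by maps, then FFT per coil.
        return torch.fft.fft2(x[..., None] * self.maps, dim=(0, 1), norm='ortho')

maps = torch.randn(64, 64, 8, dtype=torch.complex64)
A = SenseModelSketch(maps)
kspace = A(torch.randn(64, 64, dtype=torch.complex64))
print(A(kspace, adjoint=True).shape)  # torch.Size([64, 64])
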
def compute_metrics(args, model, data):
    # Load input, sensitivity maps, and target images onto device
    input, maps, target, mean, std, norm = data
    input = input.to(args.device)
    maps = maps.to(args.device)
    target = target.to(args.device)
    mean = mean.to(args.device)
    std = std.to(args.device)
    # Forward pass through network
    output = model(input, maps)
    # Undo normalization from pre-processing
    output = output * std + mean
    target = target * std + mean
    # Compute metrics
    abs_error = cplx.abs(output - target)
    l1 = torch.mean(abs_error)
    l2 = torch.sqrt(torch.mean(abs_error**2))
    psnr = 20 * torch.log10(cplx.abs(target).max() / l2)
    return l1, l2, psnr
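
For reference, the PSNR used in both metric functions is 20 * log10(peak / RMSE), with the peak taken as the target's maximum magnitude. A quick numeric sanity check:

import math

peak, rmse = 1.0, 0.01
print(20 * math.log10(peak / rmse))  # 40.0 dB
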
Example #7
    def __call__(self, kspace, maps, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (float): L2 norm of the entire volume.
        """
        seed = None if not self.use_seed else tuple(map(ord, fname))

        # Convert everything from numpy arrays to tensors
        kspace = cplx.to_tensor(kspace).unsqueeze(0)
        maps = cplx.to_tensor(maps).unsqueeze(0)
        target = cplx.to_tensor(target).unsqueeze(0)
        norm = torch.sqrt(torch.mean(cplx.abs(target)**2))

        # Apply random data augmentation
        kspace, target = self.augment(kspace, target, seed)

        # Undersample k-space data
        masked_kspace, mask = ss.subsample(kspace, self.mask_func, seed)

        # Initialize SENSE signal model (from ESPIRiT maps)
        A = T.SenseModel(maps)

        # Compute normalization factor (95th percentile of magnitude in view-shared data)
        averaged_kspace = T.time_average(masked_kspace, dim=3)
        image = A(averaged_kspace, adjoint=True)
        magnitude_vals = cplx.abs(image).reshape(-1)
        k = int(round(0.05 * magnitude_vals.numel()))
        scale = torch.min(torch.topk(magnitude_vals, k).values)

        # Normalize k-space and target images
        masked_kspace /= scale
        target /= scale
        mean = torch.tensor([0.0], dtype=torch.float32)
        std = scale

        # Compute network initialization
        if self.slwin_init:
            init_image = A(T.sliding_window(masked_kspace,
                                            dim=3,
                                            window_size=5),
                           adjoint=True)
        else:
            init_image = A(masked_kspace, adjoint=True)

        # Get rid of batch dimension...
        masked_kspace = masked_kspace.squeeze(0)
        maps = maps.squeeze(0)
        init_image = init_image.squeeze(0)
        target = target.squeeze(0)

        return masked_kspace, maps, init_image, target, mean, std, norm
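
The scale here is the smallest of the top 5% magnitude values, i.e. roughly the 95th percentile of the zero-filled magnitude image. A standalone check against torch.quantile (available in recent PyTorch versions):

import torch

x = torch.rand(100000)
k = int(round(0.05 * x.numel()))
scale_topk = torch.min(torch.topk(x, k).values)
scale_quantile = torch.quantile(x, 0.95)
print(scale_topk.item(), scale_quantile.item())  # agree up to interpolation
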
    def __call__(self, kspace, maps, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (float): L2 norm of the entire volume.
        """
        # Convert everything from numpy arrays to tensors
        kspace = cplx.to_tensor(kspace).unsqueeze(0)
        maps = cplx.to_tensor(maps).unsqueeze(0)
        target = cplx.to_tensor(target).unsqueeze(0)
        norm = torch.sqrt(torch.mean(cplx.abs(target)**2))

        # Apply mask in k-space
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = ss.subsample(kspace,
                                           self.mask_func,
                                           seed,
                                           mode='2D')

        # Normalize data by the magnitude of the zero-filled reconstruction
        # (95th percentile of the magnitude image)
        A = T.SenseModel(maps)
        image = A(masked_kspace, adjoint=True)
        magnitude_vals = cplx.abs(image).reshape(-1)
        k = int(round(0.05 * magnitude_vals.numel()))
        scale = torch.min(torch.topk(magnitude_vals, k).values)

        # Alternatives kept for reference:
        # ... mask-weighted zero-filled reconstruction:
        #     A = T.SenseModel(maps, weights=mask)
        #     image = A(masked_kspace, adjoint=True)
        # ... power within the central calibration region:
        #     calib_size = 10
        #     calib_region = cplx.center_crop(masked_kspace, [calib_size, calib_size])
        #     scale = torch.mean(cplx.abs(calib_region)**2)
        #     scale = scale * (calib_size**2 / kspace.size(-3) / kspace.size(-2))

        masked_kspace /= scale
        target /= scale
        mean = torch.tensor([0.0], dtype=torch.float32)
        std = scale

        # Get rid of batch dimension...
        masked_kspace = masked_kspace.squeeze(0)
        maps = maps.squeeze(0)
        target = target.squeeze(0)

        return masked_kspace, maps, target, mean, std, norm
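
The calibration-region alternative above relies on cplx.center_crop, which is not shown in these snippets. A hypothetical sketch, assuming the two k-space spatial dims sit just before the trailing real/imag dim:

import torch

def center_crop(x, shape):
    # Crop the central [cy, cx] block of the two spatial dims (assumed
    # to be dims -3 and -2, before the real/imag dim).
    cy, cx = shape
    y0 = (x.size(-3) - cy) // 2
    x0 = (x.size(-2) - cx) // 2
    return x[..., y0:y0 + cy, x0:x0 + cx, :]

kspace = torch.randn(1, 128, 128, 2)
print(center_crop(kspace, [10, 10]).shape)  # torch.Size([1, 10, 10, 2])
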
def save_image(image, tag):
    image = cplx.abs(image).permute(0, 3, 1, 2)  # magnitude
    image -= image.min()
    image /= image.max()
    grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
    writer.add_image(tag, grid, epoch)
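
Usage sketch for this save_image variant, with the closed-over writer and epoch created locally and an already-magnitude batch, so it runs without the cplx module (names and paths here are illustrative, not from the repo):

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/demo')      # illustrative log directory
epoch = 0
image = torch.rand(8, 1, 64, 64)         # [b, c, h, w] magnitude images
image = (image - image.min()) / (image.max() - image.min())
grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
writer.add_image('Demo_Magnitude', grid, epoch)
writer.close()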