Ejemplo n.º 1
0
def visualize(args, epoch, model, data_loader, writer):
    """Log images for the first batch of ``data_loader`` to ``writer``.

    The target is always logged. For ``epoch != 0`` the model output, the
    first two channels of the subsampled ("corrupted") input, their
    root-sum-of-squares combination and the absolute reconstruction error
    are logged as well.

    Args:
        args: Namespace providing ``device``.
        epoch (int): Step passed to ``writer.add_image``.
        model: Wrapped model exposing ``model.module.subsampling``.
        data_loader: Yields ``(input, target, mean, std, norm)`` batches.
        writer: Object with an ``add_image(tag, image, step)`` method
            (e.g. a TensorBoard ``SummaryWriter``).
    """
    def save_image(image, tag):
        # Min-max scale a *copy* to [0, 1] for display. Scaling in place would
        # modify the caller's tensor, so the 'Error' image would be computed
        # from the already-rescaled target/output instead of the raw values.
        image = image - image.min()
        peak = image.max()
        if peak > 0:  # constant image: avoid 0 / 0 -> NaN
            image = image / peak
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        # Only the first batch is visualized.
        batch = next(iter(data_loader), None)
        if batch is None:  # empty loader: nothing to visualize
            return
        inputs, target, mean, std, norm = batch
        inputs = inputs.to(args.device)
        target = target.unsqueeze(1).to(args.device)

        save_image(target, 'Target')
        if epoch != 0:
            output = model(inputs.clone())

            corrupted = model.module.subsampling(inputs)
            corrupted = corrupted[..., 0]  # complex to real
            cor_all = transforms.root_sum_of_squares(corrupted, dim=1).unsqueeze(1)

            save_image(output, 'Reconstruction')
            save_image(corrupted[:, 0:1, :, :], 'Corrupted0')
            save_image(corrupted[:, 1:2, :, :], 'Corrupted1')
            save_image(cor_all, 'Corrupted')
            save_image(torch.abs(target - output), 'Error')
Ejemplo n.º 2
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Convert raw k-space into the k-space of a normalized magnitude image.

        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image (not used by this transform)
            attrs (dict): Acquisition related information stored in the HDF5 object
                (not used by this transform)
            fname (pathlib.Path): Path to the input file (passed through)
            slice (int): Serial number of the slice (passed through)
        Returns:
            (tuple): tuple containing:
                kspace (torch.Tensor): ``transforms.rfft2`` of the normalized
                    magnitude image after clamping to [-6, 6]
                mean (float): Mean of the zero-filled magnitude image
                std (float): Standard deviation of the zero-filled magnitude image
                fname (pathlib.Path): Path to the input file
                slice (int): Serial number of the slice
        """
        side = self.resolution
        cropped = transforms.complex_center_crop(
            transforms.ifft2(transforms.to_tensor(kspace)), (side, side))
        # Multi-coil data is combined with root-sum-of-squares before the magnitude.
        if self.which_challenge == 'multicoil':
            cropped = transforms.root_sum_of_squares(cropped)
        magnitude, mean, std = transforms.normalize_instance(
            transforms.complex_abs(cropped), eps=1e-11)
        spectrum = transforms.rfft2(magnitude.clamp(-6, 6))
        return spectrum, mean, std, fname, slice
Ejemplo n.º 3
0
def resize(hparams, image, target):
    """Center-crop a complex image (and optional target) and normalize both.

    The crop size is ``hparams.resolution`` limited by the image's spatial
    size and, when a target is given, by the target's size as well.

    Args:
        hparams: Provides ``resolution`` and ``challenge``.
        image (torch.Tensor): Complex image; presumably real/imag live in the
            last dimension (it is indexed with ``shape[-3]``/``shape[-2]`` for
            height/width and passed to ``transforms.complex_abs``).
        target (numpy.ndarray or None): Target image.

    Returns:
        tuple: ``(image, image_abs, target, mean, std)``. ``image`` is the
        cropped complex image. ``image_abs`` is its magnitude (RSS-combined for
        multi-coil data), normalized and clamped to [-6, 6]. ``target`` is the
        cropped target normalized with the same mean/std and clamped, or
        ``torch.Tensor([0])`` when no target was given.
    """
    height = min(hparams.resolution, image.shape[-3])
    width = min(hparams.resolution, image.shape[-2])
    if target is not None:
        height = min(height, target.shape[-2])
        width = min(width, target.shape[-1])
    crop_size = (height, width)

    image = transforms.complex_center_crop(image, crop_size)
    magnitude = transforms.complex_abs(image)
    if hparams.challenge == "multicoil":
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return image, magnitude, target, mean, std
 def __call__(self, kspace, target, attrs, fname, slice):
     """
     Args:
         kspace (numpy.Array): k-space measurements
         target (numpy.Array): Target image (not used by this transform)
         attrs (dict): Acquisition related information stored in the HDF5 object
             (not used by this transform)
         fname (str): File name. When ``self.mask_func`` is set its characters
             are mapped through ``ord`` to build the mask seed, so it must be
             a string (iterating a ``pathlib.Path`` raises ``TypeError``).
         slice (int): Serial number of the slice
     Returns:
         (tuple): tuple containing:
             image (torch.Tensor): Normalized zero-filled input image, clamped
                 to [-6, 6]
             mean (float): Mean of the zero-filled image
             std (float): Standard deviation of the zero-filled image
             fname (str): File name, passed through
             slice (int): Serial number of the slice
     """
     kspace = transforms.to_tensor(kspace)
     if self.mask_func is not None:
         # The seed depends only on the file name, not on the slice index.
         seed = tuple(map(ord, fname))
         masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
     else:
         masked_kspace = kspace
     # Inverse Fourier Transform to get zero filled solution
     image = transforms.ifft2(masked_kspace)
     # Crop input image
     image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
     # Absolute value
     image = transforms.complex_abs(image)
     # Apply Root-Sum-of-Squares if multicoil data
     if self.which_challenge == 'multicoil':
         image = transforms.root_sum_of_squares(image)
     # Normalize input. The small eps keeps the result finite when the image
     # has zero standard deviation (presumably normalize_instance divides by
     # ``std + eps``); without it an all-zero slice produces NaN/inf.
     image, mean, std = transforms.normalize_instance(image, eps=1e-11)
     image = image.clamp(-6, 6)
     return image, mean, std, fname, slice
Ejemplo n.º 5
0
 def forward(self, masked_kspace, mask):
     """Refine k-space through every cascade and return the image estimate.

     Sensitivity maps are estimated once by ``self.sens_net`` and shared by
     all cascades. The result is the magnitude of the inverse FFT of the final
     k-space estimate, combined with root-sum-of-squares over dim 1.
     """
     sens_maps = self.sens_net(masked_kspace, mask)
     estimate = masked_kspace.clone()
     for block in self.cascades:
         estimate = block(estimate, masked_kspace, mask, sens_maps)
     coil_magnitudes = T.complex_abs(T.ifft2(estimate))
     return T.root_sum_of_squares(coil_magnitudes, dim=1)
Ejemplo n.º 6
0
    def __call__(self, kspace, target, challenge, fname, slice_index):
        """Build a training example from fully-sampled k-space.

        The zero-filled magnitude image of the masked k-space is only used to
        obtain the mean/std with which ``target`` is normalized.

        Returns:
            tuple: ``(original_kspace, masked_kspace, mask, target, fname,
            slice_index)``. ``target`` is normalized with the zero-filled
            image statistics and clamped to [-6, 6]. When ``self.polar`` is
            set, both k-space tensors are passed through ``cartesianToPolar``
            after masking.
        """
        full_kspace = transforms.to_tensor(kspace)
        if self.reduce:
            full_kspace = reducedimension(full_kspace, self.resolution)

        # Apply mask (deterministic per file when use_seed is set)
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = transforms.apply_mask(full_kspace,
                                                    self.mask_func, seed)

        # Zero-filled magnitude image; only its statistics are needed below.
        zero_filled = transforms.complex_abs(
            transforms.complex_center_crop(
                transforms.ifft2(masked_kspace),
                (self.resolution, self.resolution)))
        if challenge == 'multicoil':
            zero_filled = transforms.root_sum_of_squares(zero_filled)
        _, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

        target = transforms.normalize(
            transforms.to_tensor(target), mean, std, eps=1e-11).clamp(-6, 6)

        if self.polar:
            full_kspace = cartesianToPolar(full_kspace)
            masked_kspace = cartesianToPolar(masked_kspace)

        return full_kspace, masked_kspace, mask, target, fname, slice_index
Ejemplo n.º 7
0
    def __call__(self, kspace, mask, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset (replaced by the generated
                mask when ``self.mask_func`` is set; not returned)
            target (numpy.array or None): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Normalized zero-filled input image, clamped to [-6, 6].
                target (torch.Tensor): Target normalized with the image statistics and
                    clamped, or ``torch.Tensor([0])`` when no target is given.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name, passed through.
                slice (int): Serial number of the slice, passed through.
        """
        kspace = transforms.to_tensor(kspace)

        # Simulate undersampling only when a mask function is configured;
        # otherwise the stored k-space is used as-is.
        masked_kspace = kspace
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, mask = transforms.apply_mask(
                kspace, self.mask_func, seed)

        zero_filled = transforms.ifft2(masked_kspace)

        # Crop to self.resolution, never exceeding the image or target size.
        heights = [self.resolution, zero_filled.shape[-3]]
        widths = [self.resolution, zero_filled.shape[-2]]
        if target is not None:
            heights.append(target.shape[-2])
            widths.append(target.shape[-1])
        crop_size = (min(heights), min(widths))

        zero_filled = transforms.complex_abs(
            transforms.complex_center_crop(zero_filled, crop_size))
        if self.which_challenge == 'multicoil':
            zero_filled = transforms.root_sum_of_squares(zero_filled)
        image, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)
        image = image.clamp(-6, 6)

        if target is None:
            target = torch.Tensor([0])
        else:
            target = transforms.center_crop(transforms.to_tensor(target), crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
        return image, target, mean, std, fname, slice
Ejemplo n.º 8
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object;
                ``attrs['norm']`` is read.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): ``(image, target_train, mean, std, norm, target)`` where
                image (torch.Tensor): Normalized zero-filled input image (clamped to
                    [-6, 6] when ``CLAMP``).
                target_train (torch.Tensor): Target normalized with the image
                    mean/std (clamped when ``CLAMP``), used for training.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (numpy.float32): L2 norm of the entire volume (``attrs['norm']``).
                target (torch.Tensor): Unnormalized target, e.g. for visualization.
        """
        kspace = transforms.to_tensor(kspace)
        seed = tuple(map(ord, fname)) if self.use_seed else None
        if self.use_mask:
            masked_kspace = transforms.get_mask(kspace, self.mask_func, seed) * kspace
        else:
            masked_kspace = kspace

        # Zero-filled magnitude image at the configured resolution
        image = transforms.complex_abs(
            transforms.complex_center_crop(
                transforms.ifft2(masked_kspace),
                (self.resolution, self.resolution)))
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        target = transforms.to_tensor(target)
        target_train = transforms.normalize(target, mean, std, eps=1e-11)
        if CLAMP:
            image = image.clamp(-6, 6)
            target_train = target_train.clamp(-6, 6)

        norm = attrs['norm'].astype(np.float32)
        return image, target_train, mean, std, norm, target
Ejemplo n.º 9
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image (not used by this transform)
            attrs (dict): Acquisition related information stored in the HDF5 object
                (not used by this transform)
            fname (str): File name. When ``self.mask_func`` is set its characters
                are mapped through ``ord`` to build the mask seed, so it must be
                a string (iterating a ``pathlib.Path`` raises ``TypeError``).
            slice (int): Serial number of the slice
        Returns:
            (tuple): tuple containing:
                masked_kspace (torch.Tensor): (Masked) k-space whose dim 1 is
                    center-cropped or zero-padded to exactly ``self.kspace_x``
                image (torch.Tensor): Normalized zero-filled input image
                mean (float): Mean of the zero-filled image
                std (float): Standard deviation of the zero-filled image
                fname (str): File name, passed through
                slice (int): Serial number of the slice
        """
        kspace = transforms.to_tensor(kspace)
        if self.mask_func is not None:
            seed = tuple(map(ord, fname))
            masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input; eps keeps the result finite for a zero-std image
        # (presumably normalize_instance divides by ``std + eps``).
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # Center-crop or zero-pad dim 1 of k-space to exactly self.kspace_x.
        # Explicit [start, start + width) bounds are used because the symmetric
        # ``[extra // 2:-(extra // 2)]`` form yields an empty slice for
        # extra == 1 and the wrong width for any odd size difference.
        current = masked_kspace.shape[1]
        wanted = int(self.kspace_x)
        if current > wanted:
            start = (current - wanted) // 2
            masked_kspace = masked_kspace[:, start:start + wanted]
        elif current < wanted:
            start = (wanted - current) // 2
            padded_shape = list(masked_kspace.shape)
            padded_shape[1] = wanted
            # new_zeros keeps the dtype/device of the k-space tensor
            padded = masked_kspace.new_zeros(padded_shape)
            padded[:, start:start + current] = masked_kspace
            masked_kspace = padded

        # TODO: return the mask as well for exclusive updates
        return masked_kspace, image, mean, std, fname, slice
Ejemplo n.º 10
0
def kspacetoimage(kspace, args):
    """Return the normalized zero-filled magnitude image for ``kspace``.

    The inverse FFT is center-cropped to ``args.resolution`` x
    ``args.resolution``, converted to magnitude, RSS-combined when
    ``args.challenge == 'multicoil'``, normalized with
    ``transforms.normalize_instance`` (eps=1e-11) and clamped to [-6, 6].
    The normalization mean/std are discarded.
    """
    size = (args.resolution, args.resolution)
    image = transforms.complex_abs(
        transforms.complex_center_crop(transforms.ifft2(kspace), size))
    if args.challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    normalized, _, _ = transforms.normalize_instance(image, eps=1e-11)
    return normalized.clamp(-6, 6)
Ejemplo n.º 11
0
def save_zero_filled(data_dir, out_dir, which_challenge, resolution):
    """Write zero-filled reconstructions for every file in ``data_dir``.

    For each file, the ``kspace`` dataset is read with h5py and inverse
    Fourier transformed. The result is center-cropped to at most
    ``resolution`` in each spatial dimension and converted to magnitude.
    Multi-coil data is combined with root-sum-of-squares over dim 1.
    The images are keyed by file name and passed to
    ``save_reconstructions(..., out_dir)``.
    """
    reconstructions = {}

    for path in data_dir.iterdir():
        print("file:{}".format(path))
        with h5py.File(path, "r") as hf:
            kspace = transforms.to_tensor(hf['kspace'][()])

        image = transforms.ifft2(kspace)
        crop = (min(resolution, image.shape[-3]), min(resolution, image.shape[-2]))
        image = transforms.complex_abs(transforms.complex_center_crop(image, crop))
        if which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image, dim=1)

        reconstructions[path.name] = image

    save_reconstructions(reconstructions, out_dir)
Ejemplo n.º 12
0
def eval(args, model, data_loader):
    """Collect de-normalized RSS images of the model's subsampled inputs.

    Despite the name, the reconstruction network is not run. Each batch only
    goes through ``model.module.subsampling``. Index 0 of the last dimension
    is kept ("complex to real"), coils are combined with root-sum-of-squares
    over dim 1, and each image is mapped back with its ``std``/``mean``.

    Args:
        args: Namespace providing ``device``.
        model: Wrapped model exposing ``model.module.subsampling``.
        data_loader: Yields ``(input, target, mean, std, norm, fnames, slices)``.

    Returns:
        dict: file name -> ``np.stack`` of that file's images, ordered by slice.
    """
    model.eval()
    collected = defaultdict(list)
    with torch.no_grad():
        for inputs, _, mean, std, _, fnames, slices in data_loader:
            subsampled = model.module.subsampling(inputs.to(args.device)).to('cpu')
            combined = transforms.root_sum_of_squares(subsampled[..., 0], dim=1)

            for k in range(combined.shape[0]):
                # Undo the per-example normalization (written back in place).
                combined[k] = combined[k] * std[k] + mean[k]
                collected[fnames[k]].append(
                    (slices[k].numpy(), combined[k].numpy()))

    return {
        name: np.stack([image for _, image in sorted(entries)])
        for name, entries in collected.items()
    }
Ejemplo n.º 13
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array or None): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object;
                ``attrs['norm']`` is read when ``target`` is given.
            fname (str): File name; also used to look up a precomputed
                reconstruction on disk.
            slice (int): Serial number of the slice.
        Returns:
            (tuple): ``(image, target_train, mean, std, norm, target, image_updated)``:
                image (torch.Tensor): Zero-filled magnitude image; normalized (and
                    clamped when ``CLAMP``) if ``self.normalize``.
                target_train: Target normalized like ``image`` (the raw target when
                    ``self.normalize`` is false); ``[]`` when no target is given.
                mean (float): Mean used for normalization, ``-1.0`` when not normalizing.
                std (float): Standard deviation used for normalization, ``-1.0`` when
                    not normalizing.
                norm (float): ``attrs['norm']`` as float32, ``-1.0`` without a target.
                target: Unnormalized target tensor (e.g. for visualization); ``[]``
                    when no target is given.
                image_updated: ``reconstruction[slice]`` from the first precomputed
                    reconstruction file found for ``fname``, or ``[]`` if none exists.
        """
        kspace = transforms.to_tensor(kspace)
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            masked_kspace = kspace

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        # Normalize input; -1.0 marks "not normalized" for mean/std
        if self.normalize:
            image, mean, std = transforms.normalize_instance(image, eps=1e-11)
            if CLAMP:
                image = image.clamp(-6, 6)
        else:
            mean = -1.0
            std = -1.0

        # Normalize target; the raw target is returned too (for visualization)
        if target is not None:
            target = transforms.to_tensor(target)
            target_train = target
            if self.normalize:
                target_train = transforms.normalize(target, mean, std, eps=1e-11)
                if CLAMP:
                    target_train = target_train.clamp(-6, 6)
            norm = attrs['norm'].astype(np.float32)
        else:
            target_train = []
            target = []
            norm = -1.0

        # Look for a precomputed reconstruction of this volume. The split
        # folders are searched in order (train, val, test); the first hit wins.
        # NOTE(review): hard-coded, user-specific absolute path — consider
        # making it configurable.
        recon_root = '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/'
        image_updated = []
        for split in ('train', 'val', 'test'):
            updated_fname = recon_root + 'reconstructions_' + split + '/' + fname
            if os.path.exists(updated_fname):
                with h5py.File(updated_fname, 'r') as data:
                    image_updated = transforms.to_tensor(
                        data['reconstruction'][slice])
                break

        return image, target_train, mean, std, norm, target, image_updated
Ejemplo n.º 14
0
        def closure():
            """Run one optimization step and record diagnostics for iteration ``i``.

            All inputs come from the enclosing scope (``net``, ``optimizer``,
            the current iteration ``i``, the result arrays, flags, ...).
            Returns the training loss tensor after ``loss.backward()``.
            """
            optimizer.zero_grad()
            out = net(net_input.type(dtype))

            # training loss
            if mask_var is not None:
                loss = mse(out * mask_var, img_noisy_var * mask_var)
            elif apply_f:
                loss = mse(apply_f(out), img_noisy_var)
            else:
                loss = mse(out, img_noisy_var)

            loss.backward()
            mse_wrt_noisy[i] = loss.data.cpu().numpy()

            # the actual loss
            true_loss = mse(
                Variable(out.data, requires_grad=False).type(dtype),
                img_clean_var.type(dtype))
            mse_wrt_truth[i] = true_loss.data.cpu().numpy()

            if MRI_multicoil_reference is not None:
                out_chs = net(net_input.type(dtype)).data.cpu().numpy()[0]
                # Must NOT be named ``out_imgs``: any assignment to that name
                # inside closure() makes it local here and shadows the
                # enclosing-scope ``out_imgs`` buffer written in the
                # ``plot_after`` branch below (UnboundLocalError when this
                # branch is skipped, a wrong local write otherwise).
                coil_imgs = channels2imgs(out_chs)
                out_img_np = transform.root_sum_of_squares(
                    torch.tensor(coil_imgs), dim=0).numpy()
                # Replaces the MSE above with the L2 distance to the reference
                mse_wrt_truth[i] = np.linalg.norm(MRI_multicoil_reference -
                                                  out_img_np)

            if output_gradients:
                for ind, p in enumerate(
                        list(
                            filter(
                                lambda p: p.grad is not None and len(
                                    p.data.shape) > 2, net.parameters()))):
                    out_grads[ind, i] = p.grad.data.norm(2).item()

            if i % 10 == 0:
                out2 = net(Variable(net_input_saved).type(dtype))
                loss2 = mse(out2, img_clean_var)
                print(
                    'Iteration %05d    Train loss %f  Actual loss %f Actual loss orig %f'
                    % (i, loss.data, mse_wrt_truth[i], loss2.data),
                    '\r',
                    end='')

            if show_images:
                if i % 50 == 0:
                    print(i)
                    out_img_np = net(ni.type(dtype)).data.cpu().numpy()[0]
                    myimgshow(plt, out_img_np)
                    plt.show()

            if plot_after is not None:
                if i in plot_after:
                    # ``out_imgs`` refers to the enclosing-scope buffer
                    # (presumably pre-allocated with one row per plot_after entry).
                    out_imgs[plot_after.index(i), :] = net(
                        net_input_saved.type(dtype)).data.cpu().numpy()[0]

            if output_weights:
                out_weights[:, i] = np.array(
                    get_distances(init_weights, get_weights(net)))

            return loss
Ejemplo n.º 15
0
 def forward(self, input):
     """Subsample ``input``, reduce it to a one-channel RSS magnitude image
     (combining over dim 1) and pass that to the reconstruction model."""
     subsampled = self.subsampling(input)
     magnitude = transforms.complex_abs(subsampled)
     rss = transforms.root_sum_of_squares(magnitude, dim=1)
     return self.reconstruction_model(rss.unsqueeze(1))
Ejemplo n.º 16
0
# Notebook cells: turn one k-space slice into per-coil magnitude images,
# display a few of them, combine them with RSS, then start building an
# undersampling mask. ``slice_kspace``, ``T``, ``show_slices``, ``plt`` and
# ``np`` are defined outside this excerpt.
slice_kspace2 = T.to_tensor(
    slice_kspace)  # Convert from numpy array to pytorch tensor
slice_image = T.ifft2(
    slice_kspace2)  # Apply Inverse Fourier Transform to get the complex image
slice_image_abs = T.complex_abs(
    slice_image)  # Compute absolute value to get a real image

# In[10]:

# Display the magnitude images at indices 0, 5 and 10 of the first dimension.
show_slices(slice_image_abs, [0, 5, 10], cmap='gray')

# As we can see, each coil in a multi-coil MRI scan focuses on a different region of the image. These coil images can be combined into the full image using the Root-Sum-of-Squares (RSS) transform.

# In[11]:

# Combine along dim 0, the dimension indexed in the display above.
slice_image_rss = T.root_sum_of_squares(slice_image_abs, dim=0)

# In[12]:

plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray')

# So far, we have been looking at fully-sampled data. We can simulate under-sampled data by creating a mask and applying it to k-space.

# In[13]:

from common.subsample import MaskFunc
mask_func = MaskFunc(center_fractions=[0.04],
                     accelerations=[8])  # Create the mask function object

# In[14]:
Ejemplo n.º 17
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build a 2-channel (real, imaginary) zero-filled input and target.

        Both images are scaled by ``6 / max|zero-filled image|`` so that the
        largest zero-filled magnitude becomes 6.

        Args:
            kspace (numpy.array): Input k-space. Real/imaginary parts are read
                as ``[:, :, 0]`` / ``[:, :, 1]``, so this appears to assume
                single-coil data of shape (rows, cols, 2) after ``to_tensor``.
            target (numpy.array): Ignored; the target is recomputed from the
                fully-sampled ``kspace``.
            attrs (dict): Not used.
            fname (str): File name; seeds the mask when ``self.use_seed``.
            slice (int): Not used.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Scaled zero-filled image, channels first
                    (presumably (2, H, W) after the transpose below).
                target (torch.Tensor): Scaled fully-sampled image, same layout.
        """
        kspace = transforms.to_tensor(kspace)
        # Crop the fully-sampled image, then return to k-space so the mask is
        # applied to the k-space of the cropped field of view.
        gt = transforms.ifft2(kspace)
        gt = transforms.complex_center_crop(gt, (self.resolution, self.resolution))
        kspace = transforms.fft2(gt)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))

        # Scale so the largest zero-filled magnitude maps to 6.
        # NOTE(review): an all-zero slice makes image_mod == 0 (division by zero).
        image_mod = transforms.complex_abs(image).max()
        image_r = image[:, :, 0] * 6.0 / image_mod
        image_i = image[:, :, 1] * 6.0 / image_mod

        # (real, imag) stacked last, then moved to the channel axis
        image = np.stack((image_r, image_i), axis=-1)
        image = image.transpose((2, 0, 1))
        image = transforms.to_tensor(image)

        target = transforms.ifft2(kspace)
        target = transforms.complex_center_crop(target, (self.resolution, self.resolution))
        # Normalize target with the same scale as the input
        target_r = target[:, :, 0] * 6.0 / image_mod
        target_i = target[:, :, 1] * 6.0 / image_mod

        target = np.stack((target_r, target_i), axis=-1)
        target = target.transpose((2, 0, 1))
        target = transforms.to_tensor(target)

        # Only image and target are returned, so the mask, norm, scale factor
        # and re-transformed k-space are not computed.
        return image, target
Ejemplo n.º 18
0
def test_root_sum_of_squares(shape, dim):
    """``transforms.root_sum_of_squares`` must match sqrt(sum(x**2)) in NumPy."""
    tensor = create_input(shape)
    actual = transforms.root_sum_of_squares(tensor, dim).numpy()
    expected = np.sqrt(np.sum(tensor.numpy() ** 2, dim))
    assert np.allclose(actual, expected)
Ejemplo n.º 19
0
    def __call__(self, kspace, target, attrs, fname, slice):
        kspace_rect = transforms.to_tensor(kspace)  ##rectangular kspace

        image_rect = transforms.ifft2(kspace_rect)  ##rectangular FS image
        image_square = transforms.complex_center_crop(
            image_rect,
            (self.resolution, self.resolution))  ##cropped to FS square image

        kspace_square = self.c3object.apply(
            transforms.fft2(image_square)) * 10000  ##kspace of square iamge
        image_square2 = ifft_c3(kspace_square)  ##for training domain_transform

        if self.augmentation:
            kspace_square = self.augmentation.apply(kspace_square)

        # image_square = ifft_c3(kspace_square)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace_square, mask = transforms.apply_mask(
            kspace_square, self.mask_func, seed)  ##ZF square kspace

        # Inverse Fourier Transform to get zero filled solution
        # image = transforms.ifft2(masked_kspace)
        us_image_square = ifft_c3(
            masked_kspace_square)  ## US square complex image

        # Crop input image
        # image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        # image = transforms.complex_abs(image)
        us_image_square_abs = transforms.complex_abs(
            us_image_square)  ## US square real image
        us_image_square_rss = transforms.root_sum_of_squares(
            us_image_square_abs, dim=0)

        stacked_kspace_square = []
        for i in (range(len(kspace_square[:, 0, 0, 0]))):
            stacked_kspace_square.append(kspace_square[i, :, :, 0])
            stacked_kspace_square.append(kspace_square[i, :, :, 1])

        stacked_kspace_square = torch.stack(stacked_kspace_square)

        stacked_masked_kspace_square = []
        # masked_kspace_square = transforms.to_tensor(masked_kspace_square)
        # for i in range(len(masked_kspace_square[:,0,0,0])):
        # stacked_masked_kspace_square.stack(masked_kspace_square[i,:,:,0],masked_kspace_square[i,:,:,1])

        for i in (range(len(masked_kspace_square[:, 0, 0, 0]))):
            stacked_masked_kspace_square.append(masked_kspace_square[i, :, :,
                                                                     0])
            stacked_masked_kspace_square.append(masked_kspace_square[i, :, :,
                                                                     1])

        stacked_masked_kspace_square = torch.stack(
            stacked_masked_kspace_square)

        stacked_image_square = []
        for i in (range(len(image_square[:, 0, 0, 0]))):
            stacked_image_square.append(image_square2[i, :, :, 0])
            stacked_image_square.append(image_square2[i, :, :, 1])

        stacked_image_square = torch.stack(stacked_image_square)




        return stacked_kspace_square,stacked_masked_kspace_square , stacked_image_square , \
            us_image_square_rss ,   \
            target *10000 \
            #mean, std, attrs['norm'].astype(np.float32)
        '''