Example #1
def visualize(args, epoch, model, inference, data_loader, writer):
    
    def save_image(image, tag):
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for iter, data in enumerate(data_loader):
            output, target = inference(model, data, device=args.device)

            # HACK to make images look good in tensorboard
            output, mean_o, std_o = transforms.normalize_instance(output)
            output = output.clamp(-6, 6)
            output = transforms.unnormalize(output, mean_o, std_o)
            target, mean_t, std_t = transforms.normalize_instance(target)
            target = target.clamp(-6, 6)
            target = transforms.unnormalize(target, mean_t, std_t)

            output = output.unsqueeze(1) # [batch_sz, h, w] --> [batch_sz, 1, h, w]
            target = target.unsqueeze(1) # [batch_sz, h, w] --> [batch_sz, 1, h, w]
            if isinstance(output, dict):
                for k, output_val in output.items():
                    # save_image(input, 'Input_{}'.format(k))
                    save_image(target, 'Target_{}'.format(k))
                    save_image(output_val, 'Reconstruction_{}'.format(k))
                    save_image(torch.abs(target - output_val), 'Error_{}'.format(k))
            else:
                # save_image(input, 'Input')
                save_image(target, 'Target')
                save_image(output, 'Reconstruction')
                save_image(torch.abs(target - output), 'Error')
            break
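A hypothetical way to wire visualize() into a training script; SummaryWriter is a standard choice for the writer argument, while args, model, inference and val_loader are assumptions about the surrounding code, not part of the example above.

from torch.utils.tensorboard import SummaryWriter

# Hypothetical wiring; args, model, inference and val_loader are assumed to be
# constructed elsewhere in the training script.
writer = SummaryWriter(log_dir='summary/example')
for epoch in range(10):
    visualize(args, epoch, model, inference, val_loader, writer)
writer.close()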
Example #2
def evaluate(args):
    args.target_path = f'/home/tomerweiss/Datasets/singlecoil_{args.data_split}'
    args.predictions_path = f'/home/liyon/PILOT/summary/{args.test_name}/rec'
    print('/home/liyon/PILOT/summary/' + args.test_name + '/rec')
    print("args: {}".format(args))
    metrics = Metrics(METRIC_FUNCS)
    for tgt_file in pathlib.Path(args.target_path).iterdir():
        if tgt_file.is_dir():
            continue
        with h5py.File(tgt_file, 'r') as target, \
                h5py.File(args.predictions_path + '/' + tgt_file.name,
                          'r') as recons:
            if args.acquisition and args.acquisition != target.attrs[
                    'acquisition']:
                continue
            recons = recons['reconstruction'][()].squeeze(0)
            target = target['reconstruction_esc'][()]

            target = transforms.to_tensor(target[:])
            target = torch.nn.functional.avg_pool2d(target,
                                                    args.resolution_degrading)
            target = center_crop_3d(target, recons.shape)
            target, mean, std = transforms.normalize_instance(target,
                                                              eps=1e-11)
            recons, meanr, stdr = transforms.normalize_instance(recons,
                                                                eps=1e-11)

            target = target.numpy()

            metrics.push(target, recons)

    return metrics
Example #3
def mnormalize(masked_kspace):
    #getting image from masked data
    image = transforms.ifft2(masked_kspace)
    #normalizing the image
    nimage, mean, std = transforms.normalize_instance(image, eps=1e-11)
    #getting kspace data from normalized image
    masked_kspace_fni = transforms.ifftshift(nimage, dim=(-3, -2))
    masked_kspace_fni = torch.fft(masked_kspace_fni, 2)
    masked_kspace_fni = transforms.fftshift(masked_kspace_fni, dim=(-3, -2))
    masked_kspace_fni, mean, std = transforms.normalize_instance(masked_kspace,
                                                                 eps=1e-11)
    return masked_kspace_fni, mean, std
Example #4
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Normalized zero-filled input image.
                target_train (torch.Tensor): Target image for training (optionally renormalized and clamped).
                target_kspace_train (torch.Tensor): k-space of the normalized complex target.
                mean (float): Mean used for complex normalization.
                std (float): Standard deviation used for complex normalization.
                mask (torch.Tensor): Sampling mask, or None when no mask is applied.
                mean_abs (float): Mean of the magnitude image.
                std_abs (float): Standard deviation of the magnitude image.
                target_inference (torch.Tensor): Unmodified target for evaluation.
                max (float): Maximum value stored in the HDF5 attributes.
                norm (float): L2 norm of the entire volume.
        """

        target_inference = transforms.to_tensor(target)

        kspace = transforms.to_tensor(kspace)
        target = transforms.ifft2(kspace)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            masked_kspace = kspace
            mask = None  # no mask applied; returned as None below

        image = transforms.ifft2(masked_kspace)
        image_crop = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        _, mean, std = transforms.normalize_instance_complex(image_crop,
                                                             eps=1e-11)

        image_abs = transforms.complex_abs(image_crop)
        image_abs, mean_abs, std_abs = transforms.normalize_instance(image_abs,
                                                                     eps=1e-11)

        image = transforms.normalize(image, mean, std)

        target_image_complex_norm = transforms.normalize(target, mean, std)
        target_kspace_train = transforms.fft2(target_image_complex_norm)

        target = transforms.complex_center_crop(target, (320, 320))
        target = transforms.complex_abs(target)
        target_train = target

        if RENORM:
            target_train = transforms.normalize(target_train, mean_abs,
                                                std_abs)

        if CLAMP:
            image = image.clamp(-6, 6)
            target_train = target_train.clamp(-6, 6)

        return image, target_train, target_kspace_train, mean, std, mask, mean_abs, std_abs, target_inference, attrs[
            'max'], attrs['norm'].astype(np.float32)
Example #5
 def __call__(self, k_space, mask, target, attrs, f_name, slice):
     """
     Args:
         k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         mask (numpy.array): Mask from the test dataset
         target (numpy.array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object.
         f_name (str): File name
         slice (int): Serial number of the slice.
     Returns:
         (tuple): tuple containing:
             k_space (torch.Tensor): k-space (resolution x resolution x 2)
             target (torch.Tensor): Target image converted to a torch Tensor.
             f_name (str): File name
             slice (int): Serial number of the slice.
     """
     k_space = transforms.to_tensor(k_space)
     full_image = transforms.ifft2(k_space)
     cropped_image = transforms.complex_center_crop(
         full_image, (self.resolution, self.resolution))
     k_space = transforms.fft2(cropped_image)
     # Normalize input
     cropped_image, mean, std = transforms.normalize_instance(cropped_image,
                                                              eps=1e-11)
     cropped_image = cropped_image.clamp(-6, 6)
     # Normalize target
     target = transforms.to_tensor(target)
     target = transforms.center_crop(target,
                                     (self.resolution, self.resolution))
     target = transforms.normalize(target, mean, std, eps=1e-11)
     target = target.clamp(-6, 6)
     return k_space, target, f_name, slice
Example #6
    def __call__(self, image, target, attrs, fname, slice):
        """
        Args:
            image (numpy.array): Input image.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name.
                slice (int): Serial number of the slice.
        """

        image = transforms.to_tensor(image)

        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        target = transforms.to_tensor(target)
        # Normalize target
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
        return image, target, mean, std, fname, slice
Example #7
    def __call__(self, img, fname, slice):
        image = transforms.to_tensor(img)
        image = transforms.center_crop(image.permute(2,0,1), (self.resolution, self.resolution))
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        return image, mean, std, fname
Example #8
    def __call__(self, kspace, target, challenge, fname, slice_index):
        original_kspace = transforms.to_tensor(kspace)

        if self.reduce:
            original_kspace = reducedimension(original_kspace, self.resolution)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(original_kspace,
                                                    self.mask_func, seed)

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)

        target = transforms.to_tensor(target)
        # Normalize target
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)

        if self.polar:
            original_kspace = cartesianToPolar(original_kspace)
            masked_kspace = cartesianToPolar(masked_kspace)

        return original_kspace, masked_kspace, mask, target, fname, slice_index
Example #9
    def __call__(self, kspace, target, attrs, fname, slice):
        kspace = transforms.to_tensor(kspace)
        image = transforms.ifft2_regular(kspace)
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # image = transforms.complex_abs(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        # image, mean, std = transforms.normalize_instance_per_channel(image, eps=1e-11)
        # image = image.clamp(-6, 6)
        # kspace = transforms.fft2(image)

        target = transforms.to_tensor(target)
        target, mean, std = transforms.normalize_instance(target, eps=1e-11)
        # # target = transforms.normalize(target, mean, std)
        # target = target.clamp(-6, 6)
        mean = std = 0
        return image, target, mean, std, attrs['norm'].astype(np.float32)
Example #10
def resize(hparams, image, target):
    smallest_width = min(hparams.resolution, image.shape[-2])
    smallest_height = min(hparams.resolution, image.shape[-3])
    if target is not None:
        smallest_width = min(smallest_width, target.shape[-1])
        smallest_height = min(smallest_height, target.shape[-2])
    crop_size = (smallest_height, smallest_width)
    image = transforms.complex_center_crop(image, crop_size)
    # Absolute value
    image_abs = transforms.complex_abs(image)
    # Apply Root-Sum-of-Squares if multicoil data
    if hparams.challenge == "multicoil":
        image_abs = transforms.root_sum_of_squares(image_abs)
    # Normalize input
    image_abs, mean, std = transforms.normalize_instance(image_abs, eps=1e-11)
    image_abs = image_abs.clamp(-6, 6)

    # Normalize target
    if target is not None:
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target, crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
    else:
        target = torch.Tensor([0])
    return image, image_abs, target, mean, std
Example #11
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object
            fname (pathlib.Path): Path to the input file
            slice (int): Serial number of the slice
        Returns:
            (tuple): tuple containing:
                kspace (torch.Tensor): k-space of the normalized zero-filled input image
                mean (float): Mean of the zero-filled image
                std (float): Standard deviation of the zero-filled image
                fname (pathlib.Path): Path to the input file
                slice (int): Serial number of the slice
        """
        kspace = transforms.to_tensor(kspace)
        image = transforms.ifft2(kspace)
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        image = transforms.complex_abs(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
        kspace = transforms.rfft2(image)
        return kspace, mean, std, fname, slice
Example #12
 def __call__(self, kspace, target, attrs, fname, slice):
     """
     Args:
         kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         target (numpy.array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object.
         fname (str): File name
         slice (int): Serial number of the slice.
     Returns:
         (tuple): tuple containing:
             masked_kspace (torch.Tensor): Normalized masked k-space.
             kspace (torch.Tensor): Full k-space normalized with the same mean and std.
             mean (float): Mean used for k-space normalization.
             std (float): Standard deviation used for k-space normalization.
             mean_image (float): Mean of the zero-filled image.
             std_image (float): Standard deviation of the zero-filled image.
             mask (torch.Tensor): Sampling mask, or None when no mask is applied.
     """
     kspace = transforms.to_tensor(kspace)
     # Apply mask
     seed = None if not self.use_seed else tuple(map(ord, fname))
     if self.use_mask:
         mask = transforms.get_mask(kspace, self.mask_func, seed)
         masked_kspace = mask * kspace
     else:
         masked_kspace = kspace
         mask = None  # no mask applied; returned as None below
     image = transforms.ifft2(masked_kspace)
     _, mean_image, std_image = transforms.normalize_instance(image, eps=1e-11)
     masked_kspace, mean, std = transforms.normalize_instance_complex(masked_kspace)
     kspace = transforms.normalize(kspace, mean, std)
     
     return masked_kspace, kspace, mean, std, mean_image, std_image, mask
Example #13
 def __call__(self, kspace, target, attrs, fname, slice):
     """
     Args:
         kspace (numpy.Array): k-space measurements
         target (numpy.Array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object
         fname (pathlib.Path): Path to the input file
         slice (int): Serial number of the slice
     Returns:
         (tuple): tuple containing:
             image (torch.Tensor): Normalized zero-filled input image
             mean (float): Mean of the zero-filled image
             std (float): Standard deviation of the zero-filled image
             fname (pathlib.Path): Path to the input file
             slice (int): Serial number of the slice
     """
     kspace = transforms.to_tensor(kspace)
     if self.mask_func is not None:
         seed = tuple(map(ord, fname))
         masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
     else:
         masked_kspace = kspace
     # Inverse Fourier Transform to get zero filled solution
     image = transforms.ifft2(masked_kspace)
     # Crop input image
     image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
     # Absolute value
     image = transforms.complex_abs(image)
     # Apply Root-Sum-of-Squares if multicoil data
     if self.which_challenge == 'multicoil':
         image = transforms.root_sum_of_squares(image)
     # Normalize input
     image, mean, std = transforms.normalize_instance(image)
     image = image.clamp(-6, 6)
     return image, mean, std, fname, slice
Example #14
    def __call__(self, image):
        """
        Args:
            image (numpy.array): DICOM image
        Returns:
            image (torch.Tensor): Zero-filled input image.
        """

        # image = np.rot90(image, axes=(0, 1)).copy()
        image = np.flip(image, 0)

        # if image.shape[0] < self.resolution or image.shape[1] < self.resolution:
        #     return None
        # # Crop center
        # image = transforms.center_crop(image, (self.resolution, self.resolution))

        res_crop = min(image.shape[0], image.shape[1])
        image = transforms.center_crop(image, (res_crop, res_crop))
        image = cv2.resize(image, dsize=(self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
        
        # Normalize input
        image = transforms.to_tensor(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        return image  
Example #15
def evaluate(args):
    args.target_path = f'Datasets/brainT1/{args.data_split}'
    args.predictions_path = f'/summary/{args.test_name}/rec'
    print('summary/' + args.test_name + '/rec')
    print("args: {}".format(args))
    metrics = Metrics(METRIC_FUNCS)
    for tgt_file in pathlib.Path(args.target_path).iterdir():
        if tgt_file.is_dir():
            continue
        with h5py.File(tgt_file, 'r') as target, \
                h5py.File(args.predictions_path + '/' + tgt_file.name,
                          'r') as recons:
            if args.acquisition and args.acquisition != target.attrs[
                    'acquisition']:
                continue
            recons = recons['reconstruction'][()].squeeze(0)
            target = target['data'][()]

            target = transforms.to_tensor(target[:])
            if recons.shape[0] == recons.shape[2]:
                # This means we look at the 3d resolution
                min_to_crop = min(target.shape[0], target.shape[1],
                                  target.shape[2])
                min_to_crop -= min_to_crop % 2
                target = pytorch_nufft.transforms.center_crop_3d(
                    target, (min_to_crop, min_to_crop, min_to_crop))
                target = ndimage.zoom(target, recons.shape[0] / min_to_crop)
                target = torch.from_numpy(target).float()
            else:
                target = torch.nn.functional.avg_pool2d(
                    target, args.resolution_degrading)
                args.resolution = min(
                    target.shape[0] // args.resolution_degrading,
                    recons.shape[0])
                target = pytorch_nufft.transforms.center_crop_3d(
                    target, recons.shape)
            target, mean, std = transforms.normalize_instance(target,
                                                              eps=1e-11)
            recons, meanr, stdr = transforms.normalize_instance(recons,
                                                                eps=1e-11)

            target = target.numpy()

            metrics.push(target, recons)

    return metrics
Example #16
def test_normalize_instance(shape):
    input = create_input(shape)
    output, mean, stddev = transforms.normalize_instance(input)
    output = output.numpy()
    assert np.isclose(input.numpy().mean(), mean, rtol=1e-2)
    assert np.isclose(input.numpy().std(), stddev, rtol=1e-2)
    assert np.isclose(output.mean(), 0, rtol=1e-2, atol=1e-3)
    assert np.isclose(output.std(), 1, rtol=1e-2, atol=1e-3)
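The test above pins down the contract that all of these examples rely on: normalize_instance standardizes a tensor and returns the statistics it used, so related tensors can later be normalized or unnormalized with the same mean and std. A minimal sketch of these helpers, assuming fastMRI-style semantics rather than the library's actual source:

import torch

def normalize(data, mean, std, eps=0.0):
    # Standardize with externally supplied statistics.
    return (data - mean) / (std + eps)

def unnormalize(data, mean, std, eps=0.0):
    # Invert normalize() using the same statistics.
    return data * (std + eps) + mean

def normalize_instance(data, eps=0.0):
    # Standardize a single tensor with its own mean/std and return the stats.
    mean = data.mean()
    std = data.std()
    return normalize(data, mean, std, eps), mean, std

# Usage mirroring the test above:
x = torch.randn(4, 32, 32) * 3 + 5
y, mean, std = normalize_instance(x)
print(float(y.mean()), float(y.std()))  # close to 0 and 1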
Example #17
    def __call__(self, ds_slice, gt_slice, attrs, file_name, s_idx, acc_fac):
        assert gt_slice is None
        with torch.autograd.no_grad():
            ds_slice, mean, std = normalize_instance(to_tensor(ds_slice))
            ds_slice = ds_slice.clamp(min=-6, max=6).unsqueeze(dim=0)

        assert isinstance(file_name, str) and isinstance(s_idx, int), 'Incorrect types!'
        extra_params = dict(mean=mean, std=std, acc_fac=acc_fac, attrs=attrs)
        return ds_slice, file_name, s_idx, extra_params
Example #18
    def __call__(self, kspace, mask, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
        """

        kspace = transforms.to_tensor(kspace)

        # Apply mask
        if self.mask_func:
            seed = None if not self.use_seed else tuple(map(ord, fname))
            masked_kspace, mask = transforms.apply_mask(
                kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image to given resolution if larger
        smallest_width = min(self.resolution, image.shape[-2])
        smallest_height = min(self.resolution, image.shape[-3])
        if target is not None:
            smallest_width = min(smallest_width, target.shape[-1])
            smallest_height = min(smallest_height, target.shape[-2])

        crop_size = (smallest_height, smallest_width)
        image = transforms.complex_center_crop(image, crop_size)
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
        # Normalize target
        if target is not None:
            target = transforms.to_tensor(target)
            target = transforms.center_crop(target, crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)
        else:
            target = torch.Tensor([0])
        return image, target, mean, std, fname, slice
Example #19
    def __call__(self, kspace, target, attrs, fname, slice):
        kspace_rect = transforms.to_tensor(kspace)  ##rectangular kspace

        image_rect = transforms.ifft2(kspace_rect)  ##rectangular FS image
        image_square = transforms.complex_center_crop(
            image_rect,
            (self.resolution, self.resolution))  ##cropped to FS square image
        kspace_square = self.c3object.apply(
            transforms.fft2(image_square))  #* 10000  ##kspace of square image

        if self.augmentation:
            kspace_square = self.augmentation.apply(kspace_square)

        image_square = ifft_c3(kspace_square)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace_square, mask = transforms.apply_mask(
            kspace_square, self.mask_func, seed)  ##ZF square kspace

        # Inverse Fourier Transform to get zero filled solution
        # image = transforms.ifft2(masked_kspace)
        image_square_us = ifft_c3(
            masked_kspace_square)  ## US square complex image

        # Crop input image
        # image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        # image = transforms.complex_abs(image)
        image_square_abs = transforms.complex_abs(
            image_square_us)  ## US square real image

        # Apply Root-Sum-of-Squares if multicoil data
        # if self.which_challenge == 'multicoil':
        #     image = transforms.root_sum_of_squares(image)
        # Normalize input
        # image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        _, mean, std = transforms.normalize_instance(image_square_abs,
                                                     eps=1e-11)
        # image = image.clamp(-6, 6)

        # target = transforms.to_tensor(target)
        target = image_square.permute(2, 0, 1)
        # Normalize target
        # target = transforms.normalize(target, mean, std, eps=1e-11)
        # target = target.clamp(-6, 6)
        # return image, target, mean, std, attrs['norm'].astype(np.float32)

        # return masked_kspace_square.permute((2,0,1)), image, image_square.permute(2,0,1), mean, std, attrs['norm'].astype(np.float32)

        # ksp, zf, target, me, st, nor
        return masked_kspace_square.permute((2,0,1)), image_square_us.permute((2,0,1)), \
            target,  \
            mean, std, attrs['norm'].astype(np.float32)
Example #20
    def __call__(self, ds_slice, gt_slice, attrs, file_name, s_idx, acc_fac):
        with torch.autograd.no_grad():
            ds_slice, mean, std = normalize_instance(to_tensor(ds_slice))
            gt_slice = normalize(to_tensor(gt_slice), mean, std)

            ds_slice = ds_slice.clamp(min=-6, max=6).unsqueeze(dim=0)
            gt_slice = gt_slice.clamp(min=-6, max=6).unsqueeze(dim=0)

            ds_slice, gt_slice = self.augment_data(ds_slice, gt_slice)

        return ds_slice, gt_slice, 0
Example #21
    def __call__(self, kspace, target, attrs, fname, slice):
        kspace = transforms.to_tensor(kspace)
        image = transforms.ifft2_regular(kspace)
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # image = transforms.complex_abs(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
        kspace = transforms.fft2(image)

        return image, mean, std, fname, slice
Example #22
def train_step(model, data, device):
    input, target, mean, std, norm, _, image_updated = data
    # input, mean, std = transforms.normalize_instance(input, eps=1e-11)
    target, _, _ = transforms.normalize_instance(target, eps=1e-11)
    if CLAMP:
        target = target.clamp(-6, 6)
    if len(image_updated) != 0:
        input = image_updated
    input, mean, std = transforms.normalize_instance(input, eps=1e-11)
    if CLAMP:
        input = input.clamp(-6, 6)
    input = input.unsqueeze(0).unsqueeze(1).to(device)
    target = target.to(device)
    output = model(input).squeeze(1).squeeze(0)

    if SMOOTH:
        loss = F.smooth_l1_loss(output, target)
    else:
        loss = F.l1_loss(output, target)
    return loss
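train_step expects the already-unpacked data tuple produced by the dataset transform and returns a scalar loss. A hypothetical outer loop that would drive it; the optimizer choice is an assumption, and model, train_loader and device are assumed to come from the surrounding script:

import torch

# Hypothetical driver for train_step(); model, train_loader and device are
# assumed to be constructed elsewhere in the training script.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()
for data in train_loader:
    optimizer.zero_grad()
    loss = train_step(model, data, device)
    loss.backward()
    optimizer.step()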
Example #23
    def __call__(self, kspace, target, attrs, fname):


        target = transforms.to_tensor(target[:])
        target = torch.nn.functional.avg_pool2d(target, self.resolution_degrading)
        self.resolution = min(320 // self.resolution_degrading, self.resolution)
        target = pytorch_nufft.transforms.center_crop_3d(target, (self.depth, self.resolution, self.resolution))
        target, mean, std = transforms.normalize_instance(target, eps=1e-11)
        target = target.clamp(-6, 6)
        kspace = pytorch_nufft.transforms.rfft3_regular(target)

        return kspace, target, mean, std
Example #24
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (float): L2 norm of the entire volume.
        """
        kspace = transforms.to_tensor(kspace)
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            masked_kspace = kspace

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        if CLAMP:
            image = image.clamp(-6, 6)

        # Normalize target
        target = transforms.to_tensor(target)
        target_train = transforms.normalize(target, mean, std, eps=1e-11)
        if CLAMP:
            target_train = target_train.clamp(-6, 6)  # return target (for viz) and target_train (for training)

        return image, target_train, mean, std, attrs['norm'].astype(
            np.float32), target
Example #25
def data_transform(kspace, mask_function, target, data_attributes, filename,
                   slice_num):
    """
    Perform preprocessing of the kspace image, in order to get a proper input for the net. Should be invoked from
    the SliceData class.
    Args:
        - kspace: complete sampled kspace image
        - mask_func: masking function to apply mask to kspace (TODO not working: we are passing from outside)
        - target: the target image to be reconstructed from the kspace
        - data_attributes: attributes of the whole HDF5 file
        - filename: name of the source file
        - slice_num: index of the slice

    Returns:
        - kspace_t: normalized full k-space tensor
        - masked_image: masked image cropped to 320 x 320, channel-first and normalized
        - target: normalized target with an added channel dimension
        - mask: mask generated by masking function
        - max_value: highest entry in the target volume (for SSIM loss)
        - slice_num: index of the slice
    """

    kspace_t = transforms.to_tensor(kspace)
    kspace_t = transforms.normalize_instance(kspace_t)[0]

    masked_kspace, mask = transforms.apply_mask(
        data=kspace_t, mask_func=mask_func
    )  # apply mask: returns masked space and generated mask
    masked_image = fastmri.ifft2c(
        masked_kspace
    )  # Apply Inverse Fourier Transform to get the complex image
    masked_image = transforms.complex_center_crop(
        masked_image, (320, 320))  # center crop masked image
    masked_image = masked_image.permute(
        2, 0, 1)  # permute the masked image to PyTorch's channel-first c x h x w format
    masked_image = transforms.normalize_instance(masked_image)[0]  # normalize

    target = transforms.to_tensor(target)
    target = transforms.normalize_instance(target)[0]  # normalize
    target = torch.unsqueeze(target, 0)  # add dimension

    return kspace_t, masked_image, target, mask, data_attributes[
        'max'], slice_num
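The docstring notes that data_transform is meant to be invoked from a SliceData class. A hypothetical sketch of that wiring; the SliceData name comes from the docstring, while the HDF5 keys and constructor arguments are assumptions:

import pathlib

import h5py
from torch.utils.data import Dataset

class SliceData(Dataset):
    def __init__(self, root, transform, mask_func):
        self.transform = transform
        self.mask_func = mask_func
        self.examples = []
        # Index every slice of every HDF5 file under root.
        for fname in sorted(pathlib.Path(root).iterdir()):
            with h5py.File(fname, 'r') as hf:
                num_slices = hf['kspace'].shape[0]
            self.examples += [(fname, i) for i in range(num_slices)]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        fname, slice_num = self.examples[i]
        with h5py.File(fname, 'r') as hf:
            kspace = hf['kspace'][slice_num]
            target = hf['reconstruction_esc'][slice_num]
            attrs = dict(hf.attrs)
        # Matches data_transform's signature above.
        return self.transform(kspace, self.mask_func, target, attrs,
                              fname.name, slice_num)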
Example #26
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.Array): k-space measurements
            target (numpy.Array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object
            fname (pathlib.Path): Path to the input file
            slice (int): Serial number of the slice
        Returns:
            (tuple): tuple containing:
                masked_kspace (torch.Tensor): Masked k-space, clipped or zero-padded to the expected width
                image (torch.Tensor): Normalized zero-filled input image
                mean (float): Mean of the zero-filled image
                std (float): Standard deviation of the zero-filled image
                fname (pathlib.Path): Path to the input file
                slice (int): Serial number of the slice
        """
        kspace = transforms.to_tensor(kspace)
        if self.mask_func is not None:
            seed = tuple(map(ord, fname))
            masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image)
        image = image.clamp(-6, 6)

        # difference between kspace actual and target dim
        extra = int(masked_kspace.shape[1] - self.kspace_x)

        # clip kspace at input dim
        if extra > 0:
            masked_kspace = masked_kspace[:, (extra//2):-(extra//2), :]

        # zero pad if necessary
        elif extra < 0:
            empty_kspace = torch.zeros((masked_kspace.shape[0], self.kspace_x, masked_kspace.shape[2]))
            empty_kspace[:, -(extra//2):(extra//2), :] = masked_kspace
            masked_kspace = empty_kspace

        #TODO return mask as well for exclusive updates
        return masked_kspace, image, mean, std, fname, slice
Example #27
def inference(model, data, device):
    input, target, mean, std, norm, unnormalized_target, image_updated = data
    if len(target) != 0:
        target, _, _ = transforms.normalize_instance(target, eps=1e-11)
    if len(image_updated) != 0:
        input = image_updated
    input, mean, std = transforms.normalize_instance(input, eps=1e-11)
    if CLAMP:
        input = input.clamp(-6, 6)

    input = input.unsqueeze(0).unsqueeze(1).to(device)
    if len(unnormalized_target) != 0:
        unnormalized_target = unnormalized_target.to(device)
    output = model(input).squeeze(1).squeeze(0)

    mean = mean.unsqueeze(0).unsqueeze(1).unsqueeze(2).to(device)
    std = std.unsqueeze(0).unsqueeze(1).unsqueeze(2).to(device)
    output = transforms.unnormalize(output, mean, std)
    # if len(target) != 0:
    #     target = transforms.unnormalize(target, mean, std)
    # if len(target) != 0:
    #     target = target * std + mean
    return output, unnormalized_target
Example #28
    def to_spatial(self, kspace, resolution):
        '''
        kspace: PyTorch tensor, post enhancement
        '''
        # Inverse Fourier Transform to get interpolated solution
        image = transforms.ifft2(kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (resolution, resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        return image
Example #29
def kspacetoimage(kspace, args):
    # Inverse Fourier Transform to get zero filled solution
    image = transforms.ifft2(kspace)
    # Crop input image
    image = transforms.complex_center_crop(image,
                                           (args.resolution, args.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Apply Root-Sum-of-Squares if multicoil data
    if args.challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    # Normalize input
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    return image
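A hypothetical call site for kspacetoimage; the Namespace fields mirror the attributes the function reads (resolution, challenge), and the dummy k-space shape and the resulting output shape are only illustrative, assuming fastMRI-style transforms semantics:

import argparse

import torch

args = argparse.Namespace(resolution=320, challenge='singlecoil')
kspace = torch.randn(1, 640, 372, 2)   # dummy single-coil k-space, last dim = real/imag
image = kspacetoimage(kspace, args)    # zero-filled, cropped, normalized, clamped image
print(image.shape)                     # expected (1, 320, 320) for this input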
Example #30
def nkspacetoimage(args, kspace_fni, mean, std, eps=1e-11):
    #nkspace to image
    assert kspace_fni.size(-1) == 2
    image = transforms.ifftshift(kspace_fni, dim=(-3, -2))
    image = torch.ifft(image, 2)
    image = transforms.fftshift(image, dim=(-3, -2))
    #denormalizing the nimage
    image = (image * std) + mean
    image = image[0]

    image = transforms.complex_center_crop(image,
                                           (args.resolution, args.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Normalize input
    image, mean, std = transforms.normalize_instance(image, eps=eps)
    image = image.clamp(-6, 6)
    return image