Example 1
def mse_gradient(x, data):
    """
    Calculates the gradient under a linear forward model
    :param x: image estimate
    :param data: [y,mask], y - zero-filled image reconstruction, mask - sub-sampling mask
    :return: image gradient
    """

    y, mask = data[0], data[1]
    x = real_to_complex(x)
    x = transforms.fft2(x)   # to k-space
    x = mask * x             # apply the sub-sampling mask
    x = transforms.ifft2(x)  # back to image space
    x = x - y                # residual against the zero-filled reconstruction

    x = complex_to_real(x)
    return x
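Since y is itself the zero-filled reconstruction, the returned value is the adjoint-applied residual of the masked Fourier forward model, i.e. the gradient of the MSE data term. A minimal usage sketch, assuming the repository's transforms, real_to_complex, and complex_to_real helpers are in scope; the step size and iteration count are purely illustrative:

# Hypothetical gradient-descent loop on the data-consistency term.
y = transforms.ifft2(mask * kspace)   # zero-filled reconstruction
x = complex_to_real(y)                # initial image estimate
for _ in range(50):
    x = x - 0.5 * mse_gradient(x, [y, mask])
x = real_to_complex(x)                # final image estimate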
Example 2
    def __call__(self, kspace, target, attrs, fname, slice, n_slices=-1):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int/list): Serial number(s) of the slice(s). Will be a list for volumes and an int for slices.
            n_slices (int): Number of slices to output.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                mask (torch.Tensor): k-space sampling mask
                metadata (torch.Tensor): 1-hot vector indicating measurement setup
                fname (pathlib.Path): Path to the input file
                slice (int): Serial number of the slice
        """
        kspace = transforms.to_tensor(kspace)

        if self.mask_func is not None:
            seed = tuple(map(ord, fname))
            masked_kspace, mask = transforms.apply_mask(
                kspace, self.mask_func, seed)
        else:
            # Data is already under-sampled: derive the mask from the
            # non-zero k-space entries and collapse the coil/complex dims.
            masked_kspace = kspace
            mask = kspace != 0
            mask = mask.to(kspace)[..., :1, :, :1]
            mask = mask[:1, ...]

        if self.which_challenge == 'multicoil':
            if masked_kspace.dim() == 5:
                masked_kspace = masked_kspace.transpose(0, 1)
                mask = mask.transpose(0, 1)
        else:
            masked_kspace = masked_kspace.unsqueeze(0)
            mask = mask.unsqueeze(0)

        return transforms.ifft2(masked_kspace), mask, \
               transforms.to_tensor(np.array(attrs['metadata'], np.float32)), \
               fname, slice
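A hedged sketch of a call site for this transform; the HDF5 field name and the variable names are assumptions modeled on the fastMRI-style loaders this code mirrors:

import h5py

# Hypothetical __getitem__ body; 'kspace' follows the fastMRI HDF5 layout,
# and `transform` is an instance of the class above.
with h5py.File(fname, 'r') as hf:
    kspace = hf['kspace'][slice_id]
    attrs = dict(hf.attrs)
image, mask, metadata, fname, slice_id = transform(
    kspace, target=None, attrs=attrs, fname=fname, slice=slice_id)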
Example 3
def get_attack_loss(args,
                    model,
                    ori_target,
                    fnaf_mask,
                    loss_f=torch.nn.MSELoss(reduction='none'),
                    xs=None,
                    ys=None,
                    shape=(320, 320),
                    n_pixel_range=(10, 11),
                    vis=False):
    """Perturb the target with localized noise, run the model on the
    re-masked measurements, and return the unreduced loss on that region."""

    # Draw the random perturbation centers here, not in the signature: a
    # default argument is evaluated once at definition time, and a scalar
    # default would also break the zip() below.
    if xs is None:
        xs = np.random.randint(low=100, high=shape[0] - 100, size=1)
    if ys is None:
        ys = np.random.randint(low=100, high=shape[1] - 100, size=1)

    ori_target, metadata, target_norm, ori_input = ori_target

    input_o = transforms.complex_abs(ori_input.clone())

    # Bound the injected pixel values by the image's intensity range.
    p_max = input_o.max()
    p_min = input_o.min()
    perturb_noise = [
        perturb_noise_init(x=x,
                           y=y,
                           shape=shape,
                           n_pixel_range=n_pixel_range,
                           pixel_value_range=(p_min, p_max))
        for x, y in zip(xs, ys)
    ]
    perturb_noise = np.stack(perturb_noise)

    seed = np.random.randint(9999999)

    # Lift the noise to a (batch, 1, H, W) tensor so it can be added to the image.
    perturb_noise = transforms.to_tensor(perturb_noise).unsqueeze(1)

    # Add the noise to the magnitude image (skipped for the unperturbed
    # control run), then re-simulate the measurement: FFT, sub-sampling
    # mask, inverse FFT.
    if not args.fnaf_eval_control:
        input_o += perturb_noise
    target = input_o.clone()

    input_o = np.complex64(input_o.numpy())
    input_o = transforms.to_tensor(input_o)
    input_o = transforms.fft2(input_o)
    input_o, mask = transforms.apply_mask(input_o, fnaf_mask, seed=seed)
    input_o = transforms.ifft2(input_o)

    # Apply the perturbed, re-masked input to the model.
    output = model.forward(y=input_o, mask=mask, metadata=metadata)
    output = estimate_to_image(output, args.resolution)

    # Give the noise a complex dimension and crop it the same way the
    # output is cropped, so the loss mask lines up with the output.
    perturb_noise = torch.stack([perturb_noise] * 2, -1)
    perturb_noise = estimate_to_image(perturb_noise, args.resolution).numpy()

    # Restrict the loss to the perturbed region.
    mask = adjusted_mask(perturb_noise != 0)
    mask = transforms.to_tensor(mask).to(args.device)

    target = torch.stack([target] * 2, -1)
    target = estimate_to_image(target, args.resolution).to(args.device)

    loss = loss_f(target * mask, output * mask)


    if vis and loss.max() >= 0.001:
        print('vis!')
        print(output.min(), output.max())
        print(target.min(), target.max())

    return loss
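A hedged invocation sketch; the 4-tuple layout follows the unpacking at the top of the function, and every variable name here is illustrative rather than taken from the repository:

# Hypothetical call; xs/ys hold one perturbation center per batch element.
batch = (target_image, metadata, target_norm, zero_filled_input)
n = target_image.shape[0]
xs = np.random.randint(low=100, high=220, size=n)
ys = np.random.randint(low=100, high=220, size=n)
loss = get_attack_loss(args, model, batch, fnaf_mask, xs=xs, ys=ys)
print(loss.mean().item())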
Example 4
    def __call__(self, kspace, target, attrs, fname, slice, n_slices=1):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int/list): Serial number(s) of the slice(s). Will be a list for volumes and an int for slices.
            n_slices (int): Number of slices to output; -1 selects the whole volume.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                mask (torch.Tensor): k-space sampling mask
                target (torch.Tensor): Target image converted to a torch Tensor.
                metadata (torch.Tensor): 1-hot vector indicating measurement setup
                norm (float): Norm of the volume, scaled by sqrt(n_slices)
                max (float): Maximum intensity value in the volume
        """
        seed = None if not self.use_seed else tuple(map(ord, fname))
        np.random.seed(seed)

        kspace = transforms.to_tensor(kspace)
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target,
                                        (self.resolution, self.resolution))

        if n_slices > 1:
            total_slices = kspace.size(0)
            slice_id = np.random.randint(0, max(total_slices - n_slices, 1))
            kspace = kspace[slice_id:slice_id + n_slices]
            target = target[slice_id:slice_id + n_slices]
        if n_slices < 0:
            # n_slices == -1 selects the whole volume.
            n_slices = kspace.shape[0]

        # Standardize to the training resolution: pad and crop in image
        # space, then return to k-space.
        if self.train_resolution is not None:
            kspace = transforms.ifft2(kspace)
            p = max(
                kspace.size(-3) - self.train_resolution[0],
                kspace.size(-2) - self.train_resolution[1]) // 2 + 1
            kspace = torch.nn.functional.pad(input=kspace,
                                             pad=(0, 0, p, p, p, p),
                                             mode='constant',
                                             value=0)
            kspace = transforms.complex_center_crop(kspace,
                                                    self.train_resolution)
            kspace = transforms.fft2(kspace)

        # Apply mask
        if self.mask_func is not None:
            masked_kspace, mask = transforms.apply_mask(
                kspace, self.mask_func, seed)
        else:
            # Data is already under-sampled: derive the mask from the
            # non-zero k-space entries and collapse the coil/complex dims.
            masked_kspace = kspace
            mask = kspace != 0
            mask = mask.to(kspace)[..., :1, :, :1]
            mask = mask[:1, ...]

        if self.which_challenge == 'multicoil':
            if masked_kspace.dim() == 5:
                masked_kspace = masked_kspace.transpose(0, 1)
                mask = mask.transpose(0, 1)
        else:
            masked_kspace = masked_kspace.unsqueeze(0)
            mask = mask.unsqueeze(0)

        # Scale the stored volume norm by the square root of the slice count.
        data_norm = attrs['norm'].astype(np.float32) * n_slices**0.5
        return transforms.ifft2(masked_kspace), mask, target, \
               transforms.to_tensor(np.array(attrs['metadata'], np.float32)), \
               data_norm, attrs['max'].astype(np.float32)
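As above, a hedged sketch of a call site; the HDF5 field names ('kspace', 'reconstruction_esc') and the n_slices choice are assumptions modeled on the fastMRI single-coil data layout:

import h5py

# Hypothetical training-time call on a whole volume.
with h5py.File(fname, 'r') as hf:
    kspace = hf['kspace'][()]
    target = hf['reconstruction_esc'][()]
    attrs = dict(hf.attrs)
image, mask, target, metadata, norm, max_val = transform(
    kspace, target, attrs, fname,
    slice=list(range(kspace.shape[0])), n_slices=4)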