Exemplo n.º 1
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build one training sample from a raw k-space slice.

        The reference image is reconstructed from the fully-sampled k-space;
        an (optionally) undersampled copy of the k-space gives the network input.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Not used; the target is recomputed from ``kspace``.
            attrs (dict): Acquisition related information stored in the HDF5 object.
                Its 'max' and 'norm' entries are read.
            fname (str): File name; seeds the mask when ``self.use_seed`` is set.
            slice (int): Serial number of the slice (not used here).
        Returns:
            (tuple): tuple containing:
                image: zero-filled input image (not cropped), normalized with
                    ``transforms.normalize_instance_complex``.
                target_train: magnitude target cropped to 320x320, renormalized
                    with the input magnitude statistics when RENORM is set and
                    clamped to [-6, 6] when CLAMP is set.
                mean, std: statistics returned by ``normalize_instance_complex``.
                mask: the sampling mask, or an all-ones tensor when
                    ``self.use_mask`` is False.
                mean_abs, std_abs: statistics of the cropped magnitude input image.
                target: magnitude target cropped to 320x320, unnormalized.
                attrs['max']: value read from the HDF5 attributes.
                norm (numpy.float32): ``attrs['norm']`` cast to float32.
        """
        kspace = transforms.to_tensor(kspace)
        # The reference image comes from the fully-sampled k-space; the
        # ``target`` argument is intentionally overwritten.
        target = transforms.ifft2(kspace)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            # Without this assignment ``mask`` was undefined on this path and
            # the return statement raised UnboundLocalError. "No masking" is
            # equivalent to multiplying by ones.
            mask = kspace.new_ones(1)
            masked_kspace = kspace
        image = transforms.ifft2(masked_kspace)

        # Magnitude statistics of the cropped input; used below to optionally
        # renormalize the target so it lives on the same scale as the input.
        image_abs = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        image_abs = transforms.complex_abs(image_abs)
        image_abs, mean_abs, std_abs = transforms.normalize_instance(image_abs,
                                                                     eps=1e-11)

        # The network input itself is not cropped.
        image, mean, std = transforms.normalize_instance_complex(image,
                                                                 eps=1e-11)

        # NOTE(review): the target crop is fixed at 320x320 while the input
        # statistics use self.resolution; confirm these are meant to agree.
        target = transforms.complex_center_crop(target, (320, 320))
        target = transforms.complex_abs(target)
        target_train = target

        if RENORM:
            target_train = transforms.normalize(target_train, mean_abs,
                                                std_abs)

        if CLAMP:
            target_train = target_train.clamp(-6, 6)

        return image, target_train, mean, std, mask, mean_abs, std_abs, target, attrs[
            'max'], attrs['norm'].astype(np.float32)
Exemplo n.º 2
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Turn one raw k-space slice into a normalized (input, target) pair.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image, or None when no target is available.
            attrs (dict): Acquisition related information stored in the HDF5 object (unused).
            fname (str): File name; seeds the mask when ``self.use_seed`` is set.
            slice (int): Serial number of the slice; passed through unchanged.
        Returns:
            (tuple): ``(image, target, mean, std, fname, slice)`` where
                image (torch.Tensor): zero-filled magnitude image (RSS-combined for
                    multicoil), center-cropped, instance-normalized and clamped to [-6, 6].
                target (torch.Tensor): target cropped to the same size, normalized
                    with the input's mean/std and clamped to [-6, 6];
                    ``torch.Tensor([0])`` when no target was given.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
        """
        kspace = transforms.to_tensor(kspace)

        # Undersample k-space only when a mask function is configured.
        if not self.mask_func:
            masked_kspace = kspace
        else:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func,
                                                     seed)

        zero_filled = transforms.ifft2(masked_kspace)

        # The trailing axis holds real/imag parts, so height/width sit at -3/-2.
        # Never crop larger than the image (or the target, if present).
        crop_h = min(self.resolution, zero_filled.shape[-3])
        crop_w = min(self.resolution, zero_filled.shape[-2])
        if target is not None:
            crop_h = min(crop_h, target.shape[-2])
            crop_w = min(crop_w, target.shape[-1])
        crop_size = (crop_h, crop_w)

        magnitude = transforms.complex_abs(
            transforms.complex_center_crop(zero_filled, crop_size))
        if self.which_challenge == 'multicoil':
            magnitude = transforms.root_sum_of_squares(magnitude)

        image, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
        image = image.clamp(-6, 6)

        if target is None:
            return image, torch.Tensor([0]), mean, std, fname, slice

        # The target is normalized with the *input* statistics.
        cropped_target = transforms.center_crop(transforms.to_tensor(target),
                                                crop_size)
        normalized_target = transforms.normalize(cropped_target, mean, std,
                                                 eps=1e-11)
        return image, normalized_target.clamp(-6, 6), mean, std, fname, slice
Exemplo n.º 3
0
def train_step(model, data, device):
    """
    Compute the reconstruction loss for one training batch.

    ``data`` must unpack into exactly 8 items. Item 1 is the target and items
    5 and 6 are the magnitude mean/std used to renormalize the model output
    when RENORM is set. The loss is smooth-L1 when SMOOTH is set, plain L1
    otherwise.
    """
    _, target, _, _, _, mean_abs, std_abs, _ = data
    target = target.to(device)

    output = generate(model, data, device)

    if RENORM:
        # Two unsqueezes let per-sample statistics broadcast over the image
        # axes (assumes mean_abs/std_abs are 1-D per batch -- TODO confirm).
        output = transforms.normalize(
            output,
            mean_abs.unsqueeze(1).unsqueeze(2).to(device),
            std_abs.unsqueeze(1).unsqueeze(2).to(device),
        )

    criterion = F.smooth_l1_loss if SMOOTH else F.l1_loss
    return criterion(output, target)
Exemplo n.º 4
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build a masked, normalized input image and its normalized target.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image (must not be None here).
            attrs (dict): Acquisition related information stored in the HDF5 object;
                its 'norm' entry is read.
            fname (str): File name; seeds the mask when ``self.use_seed`` is set.
            slice (int): Serial number of the slice (unused).
        Returns:
            (tuple): ``(image, target_clamped, mean, std, norm, target)`` where
                image (torch.Tensor): masked zero-filled magnitude image, cropped to
                    ``self.resolution``, instance-normalized and clamped to [-6, 6].
                target_clamped (torch.Tensor): normalized target clamped to [-6, 6]
                    (for training).
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (numpy.float32): ``attrs['norm']`` cast to float32.
                target (torch.Tensor): normalized but unclamped target (for viz).
        """
        kspace = transforms.to_tensor(kspace)

        seed = tuple(map(ord, fname)) if self.use_seed else None
        mask = transforms.get_mask(kspace, self.mask_func, seed)

        # Zero-filled reconstruction of the undersampled k-space, cropped square.
        side = self.resolution
        image = transforms.complex_center_crop(transforms.ifft2(mask * kspace),
                                               (side, side))
        image = transforms.complex_abs(image)
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # The target shares the input's statistics; keep both the clamped copy
        # (training) and the unclamped one (visualization).
        target = transforms.normalize(transforms.to_tensor(target), mean, std,
                                      eps=1e-11)
        target_train = target.clamp(-6, 6)
        return (image, target_train, mean, std,
                attrs['norm'].astype(np.float32), target)
Exemplo n.º 5
0
    def inference(self, model, data, device):
        """
        Reconstruct a batch and score it with the uncertainty model.

        ``data`` must unpack into exactly 7 items; items 2 and 3 are the mean
        and std used to re-normalize the reconstruction. Everything except the
        last item is forwarded to ``self.module.inference``.

        Returns:
            (tuple): ``(output, target, confidence)``. ``confidence`` is the
            negated uncertainty-model prediction for ``self.loss == 'l1'``, the
            raw prediction for ``'ssim'``, squeezed along axis 1.
        Raises:
            ValueError: if ``self.loss`` is neither 'l1' nor 'ssim'.
        """
        model.unc_model.eval()
        model.r_model.eval()
        _, _, mean, std, _, _, _ = data

        with torch.no_grad():
            output, target = self.module.inference(model.r_model, data[:-1],
                                                   device)

        # Put the reconstruction back on the normalized scale before scoring.
        mean = mean.unsqueeze(1).unsqueeze(2).to(device)
        std = std.unsqueeze(1).unsqueeze(2).to(device)
        normalized = transforms.normalize(output, mean, std, eps=1e-11)
        prediction = model.unc_model(normalized.unsqueeze(1))

        if self.loss == 'ssim':
            confidence = prediction
        elif self.loss == 'l1':
            # A higher predicted L1 error means lower confidence.
            confidence = -1 * prediction
        else:
            raise ValueError('Invalid loss')

        return output, target, confidence.squeeze(1)
Exemplo n.º 6
0
def train_step_generator(generator, discriminator, data, device,
                         adv_weight=0.01):
    """
    Compute the generator losses for one GAN training step.

    Args:
        generator: reconstruction network passed to ``generate``.
        discriminator: network called on the output with a new axis inserted
            at position 1.
        data: batch that must unpack into exactly 8 items; item 1 is the
            target and items 5/6 are the magnitude mean/std used when RENORM
            is set.
        device: torch device the target and statistics are moved to.
        adv_weight (float): weight applied to the adversarial loss; defaults
            to 0.01, the value previously hard-coded.

    Returns:
        (tuple): ``(adv_weight * adversarial_loss, consistency_loss)`` where
        the consistency loss is the L1 loss between output and target.
    """
    _, target, _, _, _, mean_abs, std_abs, _ = data
    target = target.to(device)

    output = generate(generator, data, device)
    if RENORM:
        mean_abs = mean_abs.unsqueeze(1).unsqueeze(2).to(device)
        std_abs = std_abs.unsqueeze(1).unsqueeze(2).to(device)
        output = transforms.normalize(output, mean_abs, std_abs)
    consistency_loss = F.l1_loss(output, target)

    p_output = discriminator(output.unsqueeze(1))
    if WGAN:
        # WGAN generator objective: raise the critic's score on the output.
        disc_loss = -torch.mean(p_output)
    else:
        # Generator wants its outputs labelled "real" (1). ones_like creates
        # the labels directly on p_output's device with its dtype, instead of
        # allocating a float32 CPU tensor and copying it over.
        disc_loss = F.binary_cross_entropy_with_logits(
            p_output, torch.ones_like(p_output))

    return adv_weight * disc_loss, consistency_loss
Exemplo n.º 7
0
def train_step(model, data, device):
    """
    Compute the combined image/k-space/SSIM training loss for one batch.

    ``data`` must unpack into exactly 11 items (target, target k-space,
    magnitude mean/std, data range and norm are read from it). The total loss is
    ``image_loss + MULTI * kspace_loss``. ``SSIM_WEIGHT * ssim`` is subtracted
    when the SSIM value is below SSIM_THRES. The result is scaled by 1e8 when
    RENORM is not set.
    """
    (_, target, target_kspace, _, _, _, mean_abs, std_abs, _, data_range,
     norm) = data
    target = target.to(device)
    target_kspace = target_kspace.to(device)
    data_range = data_range.float().to(device)

    output_consistent, _, output_kspace = generate(model, data, device)
    if RENORM:
        output_consistent = transforms.normalize(
            output_consistent,
            mean_abs.unsqueeze(1).unsqueeze(2).to(device),
            std_abs.unsqueeze(1).unsqueeze(2).to(device),
        )

    criterion = F.smooth_l1_loss if SMOOTH else F.l1_loss

    if not DIVIDE_NORM:
        image_loss = criterion(output_consistent, target)
    else:
        # Scale both sides by the per-sample norm before comparing.
        norm = norm.unsqueeze(1).unsqueeze(2).to(device)
        image_loss = criterion(output_consistent / norm, target / norm)

    kspace_loss = criterion(output_kspace, target_kspace)
    loss = image_loss + MULTI * kspace_loss

    ssim_criterion = pytorch_ssim.SSIM(window_size=7)
    ssim_value = ssim_criterion(output_consistent.unsqueeze(1),
                                target.unsqueeze(1), data_range)
    if ssim_value.item() < SSIM_THRES:
        loss -= SSIM_WEIGHT * ssim_value

    return loss if RENORM else 1e8 * loss
Exemplo n.º 8
0
def test_normalize(shape, mean, stddev):
    """normalize() must map the data mean to (mean_in - mean) / stddev and scale the std by 1/stddev."""
    tensor = create_input(shape)
    normalized = transforms.normalize(tensor, mean, stddev).numpy()
    reference = tensor.numpy()
    expected_mean = (reference.mean() - mean) / stddev
    expected_std = reference.std() / stddev
    assert np.isclose(normalized.mean(), expected_mean)
    assert np.isclose(normalized.std(), expected_std)
Exemplo n.º 9
0
def get_attack_loss_new(model, ori_target, loss_f=torch.nn.MSELoss(reduction='none'), 
    xs=np.random.randint(low=100, high=320-100, size=(16,)), 
    ys=np.random.randint(low=100, high=320-100, size=(16,)), 
    shape=(320, 320), n_pixel_range=(10, 11), train=False, optimizer=None):
    """Evaluate (and optionally train) ``model`` on targets with injected small features.

    For each ``(x, y)`` in ``zip(xs, ys)`` a perturbation image is built with
    ``perturb_noise_init``. Unless ``args.fnaf_eval_control`` is set, it is
    added to ``ori_target``. The (perturbed) target then goes through a
    simulated acquisition:

    1. ``fft2``;
    2. ``apply_mask`` with the module-level ``mask_f`` and a fresh random seed;
    3. ``ifft2``;
    4. ``complex_abs``.

    The result is normalized with ``normalize_instance`` and clamped to
    [-6, 6]. The perturbed target is normalized with the same mean/std and
    clamped. ``loss_f`` is evaluated between model output and target, both
    multiplied by ``adjusted_mask`` of the pixels where the perturbation is
    positive.

    Args:
        model: network called on the normalized, undersampled input image.
        ori_target: batch of clean target images. ``unsqueeze(1)`` is applied
            to it and the stacked perturbations are added in place, so it
            presumably must be (len(xs), H, W) -- TODO confirm.
        loss_f: element-wise loss. The default ``MSELoss(reduction='none')``
            keeps per-pixel values.
        xs, ys: perturbation positions, one pair per perturbation. The
            defaults lie in [100, 220).
        shape: image shape passed to ``perturb_noise_init``.
        n_pixel_range: passed to ``perturb_noise_init``.
        train: if True, back-propagate and call ``optimizer.step()``.
        optimizer: optimizer stepped when ``train`` is True.

    Returns:
        If ``train``: numpy array of the masked loss averaged over the last
        two axes (presumably one value per batch element). Otherwise: the
        ``loss_f`` output as a tensor on ``args.device``, not reduced here.

    NOTE(review): the ``xs``/``ys`` defaults are evaluated once, when the
    function is defined, so every call that omits them reuses the same
    "random" positions. Confirm this is intended.
    NOTE(review): the return type differs between ``train=True`` (numpy,
    reduced) and ``train=False`` (device tensor, unreduced). Check whether
    the reduction at the end of the ``if train:`` block was meant to run in
    both cases.
    NOTE(review): ``optimizer.zero_grad()`` is not called here, so gradients
    accumulate across calls unless the caller clears them.
    """
    
    # Clone so the in-place ``+=`` below cannot modify ori_target's storage
    # (unsqueeze is a view and ``.to`` may return the same tensor).
    input_o = ori_target.unsqueeze(1).to(args.device)
    input_o = input_o.clone()
    

    # The perturbation intensity range is taken from the clean image's
    # min/max (passed as pixel_value_range).
    p_max = input_o.max().cpu()
    p_min = input_o.min().cpu()
    perturb_noise = [perturb_noise_init(x=x, y=y, shape=shape, n_pixel_range=n_pixel_range, pixel_value_range=(p_min, p_max)) for x, y in zip(xs, ys)]
    perturb_noise = np.stack(perturb_noise)
            
    # perturb the target to get the perturbed image

    # New random seed on every call, so the sampling mask varies per call.
    seed = np.random.randint(999999999)
    
    
    perturb_noise = transforms.to_tensor(perturb_noise).unsqueeze(1).to(args.device)
    
    # Control runs skip the perturbation and evaluate on the clean image.
    if not args.fnaf_eval_control:
        input_o += perturb_noise
    # Ground truth is the (perturbed) fully-sampled image.
    target = input_o.clone()
    
    # Simulate an undersampled acquisition of the perturbed image. The
    # complex64 cast presumably makes to_tensor emit a real/imag axis --
    # TODO confirm against transforms.to_tensor.
    input_o = np.complex64(input_o.cpu().numpy())
    input_o = transforms.to_tensor(input_o)
    input_o = transforms.fft2(input_o)
    # ``mask`` here is the k-space sampling mask; it is rebound further down.
    input_o, mask = transforms.apply_mask(input_o, mask_f, seed)
    input_o = transforms.ifft2(input_o)
    
    image = transforms.complex_abs(input_o).to(args.device)
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)
    
    # Target shares the input's normalization statistics.
    target = transforms.normalize(target, mean, std, eps=1e-11)
    target = target.clamp(-6, 6)


    # apply the perturbed image to the model to get the loss
    # Gradients are only tracked when training.
    if train:
        output = model(image).squeeze(1)
    else:
        with torch.no_grad():
            output = model(image).squeeze(1)
            
    perturb_noise = perturb_noise.squeeze(1)
    # NOTE(review): perturb_noise_tensor is never used below.
    perturb_noise_tensor = perturb_noise
    
    perturb_noise = perturb_noise.cpu().numpy()
        
    # Restrict the loss to pixels where the perturbation is positive, after
    # whatever adjustment ``adjusted_mask`` applies (defined elsewhere).
    mask = adjusted_mask((perturb_noise > 0).astype(np.double))
    

        
    target = target.squeeze(1)
    mask_0 = transforms.to_tensor(mask).to(args.device)

    loss = loss_f((output*mask_0), (target*mask_0))

    if train:
        # Objective: masked loss averaged over the masked pixels, plus the
        # mean loss over the full image.
        b_loss = loss.sum() / mask_0.sum() * 1 + loss_f(output, target).mean()
        b_loss.backward()
        optimizer.step()
        loss = loss.detach()

        loss = loss.mean(-1).mean(-1).cpu().numpy()

    
    
    return loss
Exemplo n.º 10
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Build one sample: zero-filled input, target, and a stored reconstruction.

        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array or None): Target image; may be None.
            attrs (dict): Acquisition related information stored in the HDF5 object.
                Its 'norm' entry is read when ``target`` is not None.
            fname (str): File name. It seeds the mask when ``self.use_seed`` is
                set and is looked up in the reconstruction directories.
            slice (int): Serial number of the slice; indexes the stored reconstruction.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled magnitude input image. It is
                    instance-normalized when ``self.normalize`` is set, and also
                    clamped to [-6, 6] when CLAMP is set.
                target_train: target normalized with the input statistics
                    (clamped when CLAMP is set) if ``self.normalize``, the raw
                    target otherwise, or ``[]`` without a target.
                mean (float): Mean used for normalization, or -1.0 if disabled.
                std (float): Std used for normalization, or -1.0 if disabled.
                norm: ``attrs['norm']`` as float32, or -1.0 without a target.
                target: unnormalized target tensor (for visualization), or ``[]``.
                image_updated: slice ``slice`` of the 'reconstruction' dataset from
                    the first of reconstructions_{train,val,test}/fname that
                    exists under ``self.reconstruction_root`` (a hard-coded
                    default when that attribute is absent), or ``[]`` if none do.
        """
        kspace = transforms.to_tensor(kspace)
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            masked_kspace = kspace

        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)

        # Normalize input; -1.0 marks "normalization disabled".
        if self.normalize:
            image, mean, std = transforms.normalize_instance(image, eps=1e-11)
            if CLAMP:
                image = image.clamp(-6, 6)
        else:
            mean = -1.0
            std = -1.0

        # Normalize target with the input statistics. The raw target is
        # returned too (for visualization); target_train is for training.
        if target is not None:
            target = transforms.to_tensor(target)
            target_train = target
            if self.normalize:
                target_train = transforms.normalize(target,
                                                    mean,
                                                    std,
                                                    eps=1e-11)
                if CLAMP:
                    target_train = target_train.clamp(-6, 6)
            norm = attrs['norm'].astype(np.float32)
        else:
            target_train = []
            target = []
            norm = -1.0

        # Load a precomputed reconstruction of this slice, searching the
        # train, val and test output directories in that order. The root
        # can be overridden via ``self.reconstruction_root``; the default is
        # the path that was previously hard-coded three times.
        recon_root = getattr(
            self, 'reconstruction_root',
            '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/')
        image_updated = []
        for split_dir in ('reconstructions_train', 'reconstructions_val',
                          'reconstructions_test'):
            updated_fname = os.path.join(recon_root, split_dir) + '/' + fname
            if os.path.exists(updated_fname):
                with h5py.File(updated_fname, 'r') as hf:
                    image_updated = transforms.to_tensor(
                        hf['reconstruction'][slice])
                break

        return image, target_train, mean, std, norm, target, image_updated