def visualize(args, epoch, model, data_loader, writer):
    """Log target, reconstruction and error images for the first batch.

    Only the first batch of ``data_loader`` is processed. ``args`` is not read.
    Images are written to ``writer`` under the tags 'Target',
    'Reconstruction' and 'Error' at step ``epoch``.
    """
    def save_image(image, tag):
        # Rescale in place to [0, 1]. The in-place update matters: the 'Error'
        # image below is computed from the already-rescaled tensors.
        image -= image.min()
        peak = image.max()
        # A constant image (e.g. an all-zero error map) would otherwise be
        # divided by zero and turn into NaNs.
        if peak > 0:
            image /= peak
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            stacked_kspace_square, _, stacked_image_square, _, _ = data
            ksp_shifted_mc = transforms.ifftshift(stacked_kspace_square,
                                                  dim=(-2, -1))
            output = model(ksp_shifted_mc.cuda())
            output_rss = stack_to_rss(output)
            target = stack_to_rss(stacked_image_square)

            # FIXME - not ready yet
            save_image(target, 'Target')
            save_image(output_rss, 'Reconstruction')
            save_image((target - output_rss.cpu()), 'Error')
            break
    def perform(self, out_img_cmplx, ksp, sens, mask):
        """Data-consistency step: blend predicted k-space with measured ``ksp``.

        The image estimate is combined with ``sens`` (a new axis is inserted at
        dim 1, presumably the coil axis), Fourier-transformed, ifftshift-ed over
        dims (-3, -2), zeroed beyond ``ceil(0.85 * Nz)`` along dim -2, merged
        with ``ksp`` according to ``mask`` (and ``self.noise_lvl``), inverse
        transformed, and combined over dim 1 with the conjugated sensitivities.

        Args:
            out_img_cmplx: image estimate with real/imag in the last dim.
            ksp: measured k-space; presumably in the same ifftshift-ed layout
                as ``k_shift`` below — TODO confirm.
            sens: coil sensitivities with real/imag in the last dim.
            mask: sampling mask (presumably binary); where 1 the measured data
                is used (blended with the prediction when ``self.noise_lvl``
                is set), where 0 the prediction is kept.

        Returns:
            (Sx, Ss): ``Sx`` is the inverse-transformed k-space multiplied by
            the sensitivities with negated imaginary part, summed over dim 1;
            ``Ss`` is ``sens`` multiplied the same way by itself, summed over
            dim 1 (presumably |sens|^2 if ``T.complex_multiply`` is a complex
            product).
        """

        # Per-coil images: the estimate (new axis at dim 1) times each sensitivity.
        x = T.complex_multiply(out_img_cmplx[..., 0].unsqueeze(1),
                               out_img_cmplx[..., 1].unsqueeze(1),
                               sens[..., 0], sens[..., 1])

        # Legacy torch.fft(input, signal_ndim) function API; not available in
        # recent PyTorch releases (torch.fft is a module there).
        k = (torch.fft(x, 2, normalized=True)).squeeze(1)
        k_shift = T.ifftshift(k, dim=(-3, -2))

        # Zero (in place) every index >= ceil(0.85 * Nz) along dim -2. The
        # 5-index slice assumes k_shift is 5-D.
        sr = 0.85
        Nz = k_shift.shape[-2]
        Nz_sampled = int(np.ceil(Nz * sr))
        k_shift[:, :, :, Nz_sampled:, :] = 0

        v = self.noise_lvl
        if v is not None:  # noisy case
            # NOTE(review): the formula below differs from the implemented
            # expression (which weights the prediction by v and ksp by 1 - v).
            # out = (1 - mask) * k + mask * (k + v * k0) / (1 + v)
            out = (1 - mask) * k_shift + mask * (v * k_shift + (1 - v) * ksp)

        else:

            out = (1 - mask) * k_shift + mask * ksp

        # NOTE(review): the ifftshift applied to k above is not undone before
        # this inverse FFT — confirm this is intended.
        x = torch.ifft(out, 2, normalized=True)

        # Coil combination with the sensitivities' imaginary part negated.
        Sx = T.complex_multiply(x[..., 0], x[..., 1], sens[..., 0],
                                -sens[..., 1]).sum(dim=1)

        Ss = T.complex_multiply(sens[..., 0], sens[..., 1], sens[..., 0],
                                -sens[..., 1]).sum(dim=1)

        return Sx, Ss
def evaluate(args, epoch, model, data_loader, writer):
    """Run the model over ``data_loader`` and report the mean summed-MSE loss.

    For every batch the square k-space is ifftshift-ed over dims (-2, -1),
    passed through ``model`` on the GPU and compared with the square image
    target using ``F.mse_loss(..., reduction='sum')``.

    Returns:
        (mean_loss, seconds): the average per-batch loss and the elapsed
        wall-clock time. When ``writer`` is given, the mean loss is also
        logged as 'Dev_Loss' at step ``epoch``.
    """
    model.eval()
    batch_losses = []
    t0 = time.perf_counter()
    with torch.no_grad():
        for batch in tqdm(data_loader):
            kspace_sq, _, image_sq, _, _ = batch
            shifted = transforms.ifftshift(kspace_sq, dim=(-2, -1))
            prediction = model(shifted.cuda())
            batch_loss = F.mse_loss(prediction,
                                    image_sq.cuda(),
                                    reduction='sum')
            batch_losses.append(batch_loss.item())
        mean_loss = np.mean(batch_losses)
        if writer is not None:
            writer.add_scalar('Dev_Loss', mean_loss, epoch)
    return mean_loss, time.perf_counter() - t0
Пример #4
0
def data_for_training(rawdata, sensitivity, mask_func, norm=True):
    """Prepare one undersampled training example, normalized per slice.

    The slice is scaled by the largest complex magnitude of its zero-filled
    reconstruction (floored at 1e-6); pass ``norm=False`` to use a scale of 1.

    Args:
        rawdata: complex k-space array; ``transpose(2, 0, 1)`` moves its last
            axis to the front, which becomes the coil axis.
        sensitivity: coil sensitivity array, transposed the same way.
        mask_func: called with ``[1, Ny, Nx, ps]`` (the k-space shape with the
            coil axis set to 1) and must return a mask tensor.
        norm: whether to apply the magnitude normalization.

    Returns:
        (sense_und, sense_gt, rawdata_und, masks, sensitivity): coil-combined
        undersampled and fully-sampled images, normalized undersampled
        k-space, the mask tiled with ``repeat(coils, Ny, 1, ps)``, and the
        sensitivities as a tensor.
    """
    kspace = T.to_tensor(np.complex64(rawdata.transpose(2, 0, 1)))
    sens_maps = T.to_tensor(sensitivity.transpose(2, 0, 1))
    n_coils, n_y, n_x, n_ch = kspace.shape

    # (-1)**(row + col) with 1-based indices, broadcast over coils and the
    # trailing channel axis.
    col_grid, row_grid = np.meshgrid(np.arange(1, n_x + 1),
                                     np.arange(1, n_y + 1))
    sign = torch.from_numpy((-1)**(col_grid + row_grid)).view(
        1, n_y, n_x, 1).float()
    shifted = T.ifftshift(kspace, dim=(-3, -2)) * sign

    # Mask shape: same as k-space but with a single coil.
    mask_shape = np.array(shifted.shape)
    mask_shape[:-3] = 1
    mask = T.ifftshift(mask_func(mask_shape))

    # Zero every unsampled location.
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), shifted)
    masks = mask.repeat(n_coils, n_y, 1, n_ch)

    img_gt = T.ifft2(shifted)
    img_und = T.ifft2(masked_kspace)

    if not norm:
        scale = 1
    else:
        # No ground truth exists at inference time, so scale by the
        # zero-filled reconstruction.
        scale = T.complex_abs(img_und).max()
        if scale < 1e-6:
            scale = 1e-6

    img_gt = img_gt / scale
    img_und = img_und / scale
    rawdata_und = masked_kspace / scale

    sense_gt = cobmine_all_coils(img_gt, sens_maps)
    sense_und = cobmine_all_coils(img_und, sens_maps)
    return sense_und, sense_gt, rawdata_und, masks, sens_maps
Пример #5
0
def data_for_training(rawdata, sensitivity, mask, norm=True):
    """Prepare one undersampled training example from tensor inputs.

    The slice is scaled by the largest complex magnitude of its zero-filled
    reconstruction (floored at 1e-6); pass ``norm=False`` to use a scale of 1.

    Args:
        rawdata: k-space tensor unpacked as ``(coils, Ny, Nx, ps)``.
        sensitivity: coil sensitivities, passed to ``cobmine_all_coils``.
        mask: sampling mask tensor; it is ifftshift-ed, given leading and
            trailing singleton axes and tiled to ``(coils, ..., ps)``.
        norm: whether to apply the magnitude normalization.

    Returns:
        (sense_und, sense_gt, sense_und_kspace, rawdata_und, mask,
        sensitivity): coil-combined undersampled and fully-sampled images,
        ``T.fft2`` of the undersampled image, normalized undersampled
        k-space, the tiled float mask, and ``sensitivity`` unchanged.
    """
    n_coils, n_y, n_x, n_ch = rawdata.shape

    # (-1)**(row + col) with 1-based indices, broadcast over coils and the
    # trailing channel axis.
    col_grid, row_grid = np.meshgrid(np.arange(1, n_x + 1),
                                     np.arange(1, n_y + 1))
    sign = torch.from_numpy((-1)**(col_grid + row_grid)).view(
        1, n_y, n_x, 1).float()
    shifted = T.ifftshift(rawdata, dim=(-3, -2)) * sign

    tiled_mask = T.ifftshift(mask)[None, ..., None].float()
    tiled_mask = tiled_mask.repeat(n_coils, 1, 1, n_ch)

    masked_kspace = shifted * tiled_mask

    img_gt = T.ifft2(shifted)
    img_und = T.ifft2(masked_kspace)

    if not norm:
        scale = 1
    else:
        # No ground truth exists at inference time, so scale by the
        # zero-filled reconstruction.
        scale = T.complex_abs(img_und).max()
        if scale < 1e-6:
            scale = 1e-6

    img_gt = img_gt / scale
    img_und = img_und / scale
    rawdata_und = masked_kspace / scale

    sense_gt = cobmine_all_coils(img_gt, sensitivity)
    sense_und = cobmine_all_coils(img_und, sensitivity)
    sense_und_kspace = T.fft2(sense_und)

    return (sense_und, sense_gt, sense_und_kspace, rawdata_und, tiled_mask,
            sensitivity)
Пример #6
0
def imagenormalize(data, divisor=None):
    """Return k-space generated from the normalized zero-filled image.

    ``data`` is taken to image space with ``transforms.ifft2``, normalized
    with ``normalize(image, divisor)``, and transformed back to centred
    k-space (ifftshift, unnormalized ``torch.fft``, fftshift).

    Returns:
        (kspace, divisor): k-space of the *normalized* image and the divisor
        reported by ``normalize``.
    """
    image = transforms.ifft2(data)
    nimage, divisor = normalize(image, divisor)
    # Transform the normalized image (not the raw one) back to k-space.
    kspace = transforms.ifftshift(nimage, dim=(-3, -2))
    kspace = torch.fft(kspace, 2)
    kspace = transforms.fftshift(kspace, dim=(-3, -2))
    return kspace, divisor
Пример #7
0
def mnormalize(masked_kspace):
    """Normalize masked k-space through image space.

    The zero-filled image ``transforms.ifft2(masked_kspace)`` is normalized
    with ``transforms.normalize_instance`` and transformed back to centred
    k-space (ifftshift, unnormalized ``torch.fft``, fftshift).

    Returns:
        (masked_kspace_fni, mean, std): k-space of the normalized image and
        the image mean/std, so the inverse path ``ifft -> image * std + mean``
        recovers the zero-filled image.
    """
    image = transforms.ifft2(masked_kspace)
    nimage, mean, std = transforms.normalize_instance(image, eps=1e-11)
    masked_kspace_fni = transforms.ifftshift(nimage, dim=(-3, -2))
    masked_kspace_fni = torch.fft(masked_kspace_fni, 2)
    masked_kspace_fni = transforms.fftshift(masked_kspace_fni, dim=(-3, -2))
    return masked_kspace_fni, mean, std
Пример #8
0
def onormalize(original_kspace, mean, std, eps=1e-11):
    """Normalize fully-sampled k-space through image space with given stats.

    The image ``transforms.ifft2(original_kspace)`` is normalized with
    ``transforms.normalize(image, mean, std, eps=eps)`` and transformed back
    to centred k-space (ifftshift, unnormalized ``torch.fft``, fftshift).

    Args:
        original_kspace: k-space tensor with real/imag in the last dim.
        mean, std: image-space statistics to normalize with.
        eps: forwarded to ``transforms.normalize``.

    Returns:
        k-space of the normalized image.
    """
    image = transforms.ifft2(original_kspace)
    nimage = transforms.normalize(image, mean, std, eps=eps)
    original_kspace_fni = transforms.ifftshift(nimage, dim=(-3, -2))
    original_kspace_fni = torch.fft(original_kspace_fni, 2)
    return transforms.fftshift(original_kspace_fni, dim=(-3, -2))
Пример #9
0
def nkspacetoimage(args, kspace_fni, mean, std, eps=1e-11):
    """Map normalized centred k-space back to a normalized magnitude image.

    Steps: ifftshift over dims (-3, -2), unnormalized ``torch.ifft``,
    fftshift back, undo the normalization (``image * std + mean``), keep
    index 0 of the leading dimension, centre-crop to
    ``(args.resolution, args.resolution)``, take the complex magnitude,
    re-normalize with ``transforms.normalize_instance`` and clamp to [-6, 6].

    Args:
        args: must provide ``resolution``.
        kspace_fni: k-space whose last dimension (real/imag) has size 2.
        mean, std: statistics used to denormalize the image.
        eps: forwarded to ``transforms.normalize_instance``.
    """
    assert kspace_fni.size(-1) == 2
    image = transforms.ifftshift(kspace_fni, dim=(-3, -2))
    image = torch.ifft(image, 2)
    image = transforms.fftshift(image, dim=(-3, -2))
    # Undo the image-space normalization.
    image = (image * std) + mean
    # Keep only the first entry of the leading (presumably batch) dimension.
    image = image[0]

    image = transforms.complex_center_crop(image,
                                           (args.resolution, args.resolution))
    image = transforms.complex_abs(image)
    # The new mean/std are not needed; don't shadow the arguments with them.
    image, _, _ = transforms.normalize_instance(image, eps=eps)
    return image.clamp(-6, 6)
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    """Train ``model`` for one epoch with a summed L1 loss.

    Each batch's square k-space is ifftshift-ed over dims (-2, -1), fed to
    ``model`` on the GPU and compared with the square image target.

    Returns:
        (avg_loss, seconds): an exponential moving average of the batch loss
        (weight 0.01 on the newest value) and the epoch's wall-clock time.
    """
    model.train()
    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()
    num_batches = len(data_loader)
    global_step = epoch * num_batches
    for step, data in enumerate(tqdm(data_loader)):
        stacked_kspace_square, _, stacked_image_square, _, _ = data
        ksp_shifted_mc = transforms.ifftshift(stacked_kspace_square,
                                              dim=(-2, -1))
        output = model(ksp_shifted_mc.cuda())

        loss = F.l1_loss(output, stacked_image_square.cuda(), reduction='sum')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() forces a device-to-host sync; read the value once per step.
        loss_value = loss.item()
        avg_loss = (0.99 * avg_loss + 0.01 * loss_value
                    if step > 0 else loss_value)
        if writer is not None:
            writer.add_scalar('TrainLoss', loss_value, global_step + step)

        if step % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{step:4d}/{num_batches:4d}] '
                f'Loss = {loss_value:.4g} Avg Loss = {avg_loss:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s', )
        start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
Пример #11
0
                                    self.translation()).float()


class C3Convert:
    """Multiply k-space elementwise by a C3 multiplier from ``transforms.c3_torch``."""

    def __init__(self, shp=(320, 320)):
        # ``shp`` presumably should match the spatial size of the k-space
        # later passed to ``apply`` — TODO confirm.
        multiplier = transforms.c3_torch(shp)
        self.c3m = multiplier

    def apply(self, ksp):
        """Return ``ksp * self.c3m``; ``ksp`` is expected as (bat, w, h, 2)."""
        scaled = ksp * self.c3m
        return scaled


# The legacy torch.fft / torch.ifft calls below expect tensors laid out as
# (bat, w, h, 2), with real/imag in the last dimension.


def ifft_c3(kspc3):
    """Undo the centring over dims (-3, -2), then apply an orthonormal 2-D inverse FFT."""
    uncentred = transforms.ifftshift(kspc3, dim=(-3, -2))
    return torch.ifft(uncentred, 2, normalized=True)


def fft_c3(im):
    """Apply an orthonormal 2-D FFT, then centre the result over dims (-3, -2)."""
    spectrum = torch.fft(im, 2, normalized=True)
    return transforms.fftshift(spectrum, dim=(-3, -2))


class NoTransform:
    """Identity transform: hands back the k-space and ignores every other field."""

    def __call__(self, kspace, target, attrs, fname, slice):
        return kspace


class BasicMaskingTransform:
    def __init__(self,
Пример #12
0
def test_ifftshift(shape):
    """Check ``transforms.ifftshift`` against ``np.fft.ifftshift`` for ``shape``."""
    # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
    data = np.arange(np.prod(shape)).reshape(shape)
    out_torch = transforms.ifftshift(torch.from_numpy(data)).numpy()
    out_numpy = np.fft.ifftshift(data)
    assert np.allclose(out_torch, out_numpy)
Пример #13
0
from data import transforms as T


def c3_multiplier_npy(shape=(320, 320)):
    """Return a (rows, cols) integer checkerboard whose (i, j) entry is (-1)**(i + j).

    Only ``shape[0]`` and ``shape[1]`` are used.
    """
    rows, cols = shape[0], shape[1]
    # Build the sign from the parity of i + j directly. The previous
    # ``np.resize([1, -1]) * transpose`` trick gave all ones for odd sizes
    # and could not broadcast for non-square shapes.
    parity = np.add.outer(np.arange(rows), np.arange(cols)) % 2
    return np.where(parity == 0, 1, -1)


def c3_torch(shp):
    """Return the C3 multiplier for ``shp`` as a float32 tensor.

    The 2-D pattern from ``c3_multiplier_npy`` is repeated along a new
    trailing axis of size 2.
    """
    pattern = c3_multiplier_npy(shp)
    stacked = np.dstack([pattern] * 2)
    return torch.from_numpy(stacked).float()

def ifft_c3(kspc3):
    """Undo the centring over dims (-3, -2), then apply an orthonormal 2-D inverse FFT."""
    uncentred = T.ifftshift(kspc3, dim=(-3, -2))
    return torch.ifft(uncentred, 2, normalized=True)


def fft_c3(im):
    """Apply an orthonormal 2-D FFT, then centre the result over dims (-3, -2)."""
    spectrum = torch.fft(im, 2, normalized=True)
    return T.fftshift(spectrum, dim=(-3, -2))


# Spatial size of the module-level C3 multiplier.
shp = 360, 360

# Multiplier for ``shp``, built once at import time.
c3m = c3_torch(shp)


def tosquare(ksp, shp, scale=100000):
    """Centre-crop k-space to ``shp`` in image space and apply the C3 multiplier.

    ``ksp`` is taken to image space with ``T.ifft2``, centre-cropped to
    ``shp`` with ``T.complex_center_crop``, and returned to k-space with
    ``T.fft2``. The result is multiplied elementwise by the checkerboard
    multiplier for ``shp`` and then by ``scale``.

    Args:
        ksp: k-space tensor accepted by ``T.ifft2``.
        shp: target crop size; the multiplier is built from ``shp[0]`` and
            ``shp[1]``.
        scale: constant gain applied to the result (default 100000).
    """
    # The module-level multiplier has a fixed size and cannot broadcast
    # against any other crop; reuse it only when the sizes match.
    size = (shp[0], shp[1])
    multiplier = c3m if tuple(c3m.shape[:2]) == size else c3_torch(size)
    image = T.ifft2(ksp)
    cropped = T.complex_center_crop(image, shp)
    return multiplier * T.fft2(cropped) * scale
Пример #14
0
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space. The ``[:, :, 0]`` indexing
                below assumes the cropped image is (rows, cols, 2), i.e.
                single-coil data — TODO confirm before using multicoil input.
            target (numpy.array): Not used; the target is recomputed from
                ``kspace``.
            attrs (dict): Not used by this transform.
            fname (str): File name; seeds the mask when ``self.use_seed`` is set.
            slice (int): Not used.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled image cropped to
                    ``self.resolution``, laid out as (2, rows, cols)
                    (real, imaginary) and scaled so its largest complex
                    magnitude is 6.0.
                target (torch.Tensor): Fully-sampled image in the same layout,
                    scaled by the same factor as ``image``.
        """
        kspace = transforms.to_tensor(kspace)
        # Crop in image space, then return to k-space so that masking happens
        # at the target resolution.
        gt = transforms.ifft2(kspace)
        gt = transforms.complex_center_crop(gt, (self.resolution, self.resolution))
        kspace = transforms.fft2(gt)

        # Apply mask (deterministic per file when use_seed is set).
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Shared scale: maps the zero-filled image's peak magnitude to 6.0.
        image_mod = transforms.complex_abs(image).max()
        image_r = image[:, :, 0] * 6.0 / image_mod
        image_i = image[:, :, 1] * 6.0 / image_mod

        image = np.stack((image_r, image_i), axis=-1)
        image = image.transpose((2, 0, 1))
        image = transforms.to_tensor(image)

        target = transforms.ifft2(kspace)
        target = transforms.complex_center_crop(target, (self.resolution, self.resolution))
        # Normalize target with the input's scale so both stay comparable.
        target_r = target[:, :, 0] * 6.0 / image_mod
        target_i = target[:, :, 1] * 6.0 / image_mod

        target = np.stack((target_r, target_i), axis=-1)
        target = target.transpose((2, 0, 1))
        target = transforms.to_tensor(target)

        return image, target