Example 1
def test_complex_abs(shape):
    # Complex tensors are stored as real tensors with a trailing dimension of 2.
    shape = shape + [2]
    input = create_input(shape)
    out_torch = transforms.complex_abs(input).numpy()

    # Reference: convert to a complex numpy array and take the magnitude.
    input_numpy = utils.tensor_to_complex_np(input)
    out_numpy = np.abs(input_numpy)
    assert np.allclose(out_torch, out_numpy)
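
The test relies on fastMRI's convention of storing complex data as a real tensor whose last dimension holds the real and imaginary parts. A minimal sketch of what `tensor_to_complex_np` is assumed to do under that convention (not necessarily the exact library code):

import torch

def tensor_to_complex_np(data: torch.Tensor):
    # Minimal sketch: the last dimension is assumed to hold (real, imaginary).
    data = data.numpy()
    return data[..., 0] + 1j * data[..., 1]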
Example 2
def cs_total_variation(args, kspace, acquisition, acceleration, num_low_freqs):
    """
    Run ESPIRIT coil sensitivity estimation and Total Variation Minimization based
    reconstruction algorithm using the BART toolkit.
    """

    if acquisition not in REG_PARAM[args.challenge]:
        raise ValueError(f'Invalid acquisition protocol: {acquisition}')
    if acceleration not in {4, 8}:
        raise ValueError(f'Invalid acceleration factor: {acceleration}')

    if args.challenge == 'singlecoil':
        kspace = kspace.unsqueeze(0)
    kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0)
    kspace = tensor_to_complex_np(kspace)

    # Estimate sensitivity maps
    sens_maps = bart.bart(1, f'ecalib -d0 -m1 -r {num_low_freqs}', kspace)

    # Use Total Variation Minimization to reconstruct the image
    reg_wt = REG_PARAM[args.challenge][acquisition][acceleration]
    pred = bart.bart(1, f'pics -d0 -S -R T:7:0:{reg_wt} -i {args.num_iters}',
                     kspace, sens_maps)
    pred = torch.from_numpy(np.abs(pred[0]))

    # Crop the predicted image to the selected resolution if it is larger
    smallest_width = min(args.resolution, pred.shape[-1])
    smallest_height = min(args.resolution, pred.shape[-2])
    return transforms.center_crop(pred, (smallest_height, smallest_width))
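
The regularization weight is looked up in a nested REG_PARAM table keyed by challenge, acquisition protocol, and acceleration factor. A hypothetical table, with placeholder protocol names and weights, just to show the shape of that mapping (the real tuned values are not part of this excerpt):

# Illustrative structure only; protocol names and weights are placeholders.
REG_PARAM = {
    'multicoil': {
        'CORPD_FBK': {4: 0.01, 8: 0.01},
        'CORPDFS_FBK': {4: 0.01, 8: 0.01},
    },
    'singlecoil': {
        'CORPD_FBK': {4: 0.01, 8: 0.01},
    },
}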
Example 3
def test_ifft2(shape):
    # Complex tensors are stored as real tensors with a trailing dimension of 2.
    shape = shape + [2]
    input = create_input(shape)
    out_torch = transforms.ifft2(input).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    # Reference: centered, orthonormal 2D inverse FFT computed with numpy.
    input_numpy = utils.tensor_to_complex_np(input)
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.ifft2(input_numpy, norm='ortho')
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))
    assert np.allclose(out_torch, out_numpy)
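
The numpy reference pins down the expected semantics of `transforms.ifft2`: an ifftshift, an orthonormal 2D inverse FFT over the last two spatial dimensions, then an fftshift. A minimal sketch of an implementation with that behaviour, written against the modern torch.fft API rather than the original fastMRI code:

import torch

def ifft2(data: torch.Tensor) -> torch.Tensor:
    # data: (..., H, W, 2) real tensor holding real/imaginary parts.
    assert data.size(-1) == 2
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.ifft2(x, dim=(-2, -1), norm='ortho')
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)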
Example 4
def cs_total_variation(args, kspace):
    """
    Run ESPIRIT coil sensitivity estimation and Total Variation Minimization based
    reconstruction algorithm using the BART toolkit.
    """

    if args.challenge == 'singlecoil':
        kspace = kspace.unsqueeze(0)
    kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0)
    kspace = tensor_to_complex_np(kspace)

    # Estimate sensitivity maps
    sens_maps = bart.bart(1, f'ecalib -d0 -m1', kspace)

    # Use Total Variation Minimization to reconstruct the image
    pred = bart.bart(
        1, f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}', kspace,
        sens_maps)
    pred = torch.from_numpy(np.abs(pred[0]))

    # Crop the predicted image to the correct size
    return transforms.center_crop(pred, (args.resolution, args.resolution))
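
This variant reads the regularization weight, iteration count, and crop resolution directly from `args` rather than from a lookup table. A hypothetical invocation, where the argument values and the `masked_kspace` tensor are illustrative only:

from argparse import Namespace

# Hypothetical arguments; attribute names mirror what the function reads.
args = Namespace(challenge='multicoil', reg_wt=0.01, num_iters=200, resolution=320)
# masked_kspace: torch tensor of shape (num_coils, rows, cols, 2)
reconstruction = cs_total_variation(args, masked_kspace)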
Example 5
                                         accelerations=[acc])
                    img_und, img_gt, rawdata_und, masks, sensitivity = data_for_training(
                        rawdata, coil_sensitivities, mask_func)

                    # add batch dimension
                    batch_img_und = img_und.unsqueeze(0).to(device)
                    batch_rawdata_und = rawdata_und.unsqueeze(0).to(device)
                    batch_masks = masks.unsqueeze(0).to(device)
                    batch_sensitivities = sensitivity.unsqueeze(0).to(device)

                    # deploy the model
                    rec = rec_net(batch_img_und, batch_rawdata_und,
                                  batch_masks, batch_sensitivities)

                    # convert to complex
                    batch_recon = tensor_to_complex_np(rec.to('cpu'))
                    batch_img_und = tensor_to_complex_np(
                        batch_img_und.to('cpu'))
                    img_gt = tensor_to_complex_np(img_gt.to('cpu'))

                    # squeeze batch dimension
                    batch_recon = np.squeeze(batch_recon, axis=0)
                    batch_img_und = np.squeeze(batch_img_und, axis=0)

                    output.append(batch_recon)
                    input0.append(batch_img_und)
                    target.append(img_gt)
                    normalization.append(np.max(np.abs(batch_img_und)))

                # postprocess images
                output = np.asarray(output)
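
The excerpt stops where postprocessing begins, after collecting complex reconstructions, undersampled inputs, ground truths, and a per-slice normalization constant. A hedged sketch of one plausible next step, converting to normalized magnitude images; the actual postprocessing code is not part of this excerpt:

import numpy as np

# Illustrative continuation only, not the original postprocessing code.
target = np.asarray(target)
normalization = np.asarray(normalization).reshape(-1, 1, 1)

output_mag = np.abs(output) / normalization   # normalized magnitude reconstructions
target_mag = np.abs(target) / normalization   # normalized magnitude ground truths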