Beispiel #1
0
def evaluate(args, recons_key):
    """
    Accumulate reconstruction metrics over the files in ``args.target_path``.

    For each target file, the file of the same name is opened from
    ``args.predictions_path``. Files are skipped when ``args.acquisition`` or
    ``args.acceleration`` is set and differs from the matching attribute
    stored in the target file.

    Args:
        args: Namespace-like object. This function reads ``target_path``,
            ``predictions_path``, ``acquisition`` and ``acceleration``.
        recons_key: Name of the dataset in each target file that holds the
            reference volume.

    Returns:
        The ``Metrics`` accumulator after every retained volume was pushed.
    """
    metrics = Metrics(METRIC_FUNCS)

    for tgt_file in args.target_path.iterdir():
        with h5py.File(tgt_file, "r") as tgt_h5, h5py.File(
                args.predictions_path / tgt_file.name, "r") as pred_h5:
            # Optional filters; an unset (falsy) filter never rejects a file.
            if (args.acquisition
                    and args.acquisition != tgt_h5.attrs["acquisition"]):
                continue
            if (args.acceleration
                    and tgt_h5.attrs["acceleration"] != args.acceleration):
                continue

            # Read both datasets fully into memory.
            reference = tgt_h5[recons_key][()]
            prediction = pred_h5["reconstruction"][()]

            # Crop the reference to a square whose side is its last dimension.
            reference = transforms.center_crop(
                reference, (reference.shape[-1],) * 2)

            # The prediction is cropped, then reshaped, to the side length of
            # the *cropped* reference.
            side = reference.shape[-1]
            prediction = transforms.center_crop(prediction, (side, side))
            prediction = prediction.reshape(-1, side, side)

            metrics.push(reference, prediction)

    return metrics
Beispiel #2
0
def cs_total_variation(args, kspace, reg_wt, crop_size, num_low_freqs):
    """
    Reconstruct an image with the BART toolkit: ESPIRiT coil sensitivity
    estimation followed by Total Variation Minimization.

    Args:
        args (argparse.Namespace): Run options; ``challenge`` and
            ``num_iters`` are read here.
        kspace: k-space data exposing ``unsqueeze`` and ``permute``. A leading
            dimension is added first when ``args.challenge`` is
            "singlecoil"; the result is then permuted as a 4-D tensor.
        reg_wt (float): Regularization parameter inserted into the ``pics``
            command.
        crop_size (tuple): Size to crop the final image to.
        num_low_freqs: When not None, passed to ``ecalib`` through ``-r``.

    Returns:
        The result of ``T.center_crop`` applied to the magnitude of the first
        element of the BART reconstruction (converted to a torch tensor).
    """
    if args.challenge == "singlecoil":
        kspace = kspace.unsqueeze(0)

    # Reorder the dimensions, add a leading one, and convert to complex numpy.
    kspace = tensor_to_complex_np(kspace.permute(1, 2, 0, 3).unsqueeze(0))

    # estimate sensitivity maps
    ecalib_cmd = "ecalib -d0 -m1"
    if num_low_freqs is not None:
        ecalib_cmd = f"{ecalib_cmd} -r {num_low_freqs}"
    sens_maps = bart.bart(1, ecalib_cmd, kspace)

    # use Total Variation Minimization to reconstruct the image
    pics_cmd = f"pics -d0 -S -R T:7:0:{reg_wt} -i {args.num_iters}"
    recon = bart.bart(1, pics_cmd, kspace, sens_maps)
    pred = torch.from_numpy(np.abs(recon[0]))

    # check for FLAIR 203: fall back to a square crop of the available width
    width = pred.shape[1]
    if width < crop_size[1]:
        crop_size = (width, width)

    return T.center_crop(pred, crop_size)
def to_cropped_image(masked_kspace, target, attrs):
    """
    Turn masked k-space into a cropped, normalized magnitude image and
    normalize the target with the same statistics.

    Args:
        masked_kspace: k-space input handed to ``fastmri.ifft2c``.
        target: Reference image or None. When given, its last two dimensions
            define the crop size.
        attrs: Mapping providing ``"recon_size"``; only consulted when
            ``target`` is None.

    Returns:
        tuple: ``(image, target)``. ``image`` is clamped to [-6, 6] after
        instance normalization. ``target`` is cropped, normalized with the
        image's mean/std and clamped to [-6, 6], or ``torch.Tensor([0])``
        when no target was supplied.
    """
    # inverse Fourier transform to get zero filled solution
    image = fastmri.ifft2c(masked_kspace)

    # crop input to correct size
    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    # check for FLAIR 203: the image is narrower than the requested crop
    available = image.shape[-2]
    if available < crop_size[1]:
        crop_size = (available, available)

    # crop, then take the absolute value
    image = fastmri.complex_abs(T.complex_center_crop(image, crop_size))

    # normalize input
    image, mean, std = T.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        return image, torch.Tensor([0])

    # normalize target with the statistics of the input image
    if isinstance(target, np.ndarray):
        target = T.to_tensor(target)
    target = T.center_crop(target, crop_size)
    target = T.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)

    return image, target
def run_varnet_model(batch, model, device):
    """
    Run ``model`` on one batch and return the cropped first output.

    Args:
        batch: Object exposing ``masked_kspace``, ``mask``, ``crop_size``,
            ``slice_num`` and ``fname``.
        model: Callable invoked as ``model(masked_kspace, mask)``.
        device: Device the two model inputs are moved to.

    Returns:
        tuple: ``(output, slice_num, fname)`` where ``output`` is element 0
        of the center-cropped model output (moved back to the CPU),
        ``slice_num`` is ``int(batch.slice_num[0])`` and ``fname`` is
        ``batch.fname[0]``.
    """
    crop_size = batch.crop_size

    kspace_on_device = batch.masked_kspace.to(device)
    mask_on_device = batch.mask.to(device)
    prediction = model(kspace_on_device, mask_on_device).cpu()

    # detect FLAIR 203: output narrower than the requested crop
    width = prediction.shape[-1]
    if width < crop_size[1]:
        crop_size = (width, width)

    cropped = T.center_crop(prediction, crop_size)[0]

    return cropped, int(batch.slice_num[0]), batch.fname[0]
Beispiel #5
0
def run_varnet_model(batch, model, device):
    """
    Run ``model`` on one 8-element batch and return the cropped first output.

    Args:
        batch: Sequence unpacked as ``(masked_kspace, sensitivity_map, mask,
            _, fname, slice_num, _, crop_size)``.
        model: Callable invoked as
            ``model(masked_kspace, sensitivity_map, mask)``.
        device: Device the three model inputs are moved to.

    Returns:
        tuple: ``(output, slice_num, fname)`` where ``output`` is element 0
        of the center-cropped model output (moved back to the CPU),
        ``slice_num`` is ``int(slice_num[0])`` and ``fname`` is ``fname[0]``.
    """
    (kspace, sens_map, sampling_mask, _,
     names, slice_indices, _, crop_sizes) = batch
    # always have a batch size of 1 for varnet
    crop_size = crop_sizes[0]

    prediction = model(
        kspace.to(device),
        sens_map.to(device),
        sampling_mask.to(device),
    ).cpu()

    # detect FLAIR 203: output narrower than the requested crop
    width = prediction.shape[-1]
    if width < crop_size[1]:
        crop_size = (width, width)

    cropped = T.center_crop(prediction, crop_size)[0]

    return cropped, int(slice_indices[0]), names[0]
def run_inference(checkpoint, data_path, output_path):
    """
    Run a VarNet restored from ``checkpoint`` over every slice found under
    ``data_path`` and save the per-file reconstructions.

    Each batch is first attempted on CUDA; if that raises ``RuntimeError``
    the same batch is re-run on the CPU. Outputs are grouped by file name,
    ordered by slice number, stacked, and written with
    ``fastmri.save_reconstructions`` to ``output_path / "reconstructions"``.
    The elapsed wall-clock time is printed at the end.

    Args:
        checkpoint: Checkpoint file passed to ``torch.load``; it must contain
            a ``"state_dict"`` entry whose model weights are stored under
            keys starting with ``"varnet."``.
        data_path: Root directory handed to ``SliceDataset``.
        output_path: Directory (must support ``/``) under which the
            ``reconstructions`` folder is written.
    """
    varnet = VarNet()

    # NOTE(security): torch.load unpickles the file, so only trusted
    # checkpoints should be loaded.
    # map_location="cpu" lets a checkpoint saved on a GPU be restored on a
    # machine without CUDA, matching the CPU fallback used further down.
    checkpoint_state = torch.load(checkpoint, map_location="cpu")["state_dict"]

    # Keep only the entries that really carry the "varnet." prefix and strip
    # it. A substring test would also accept keys that merely contain
    # "varnet" somewhere and would then chop off unrelated characters.
    prefix = "varnet."
    state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint_state.items()
        if key.startswith(prefix)
    }

    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    data_transform = DataTransform()

    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge="multicoil",
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        crop_size = crop_size[0]  # always have a batch size of 1 for varnet
        fname = fname[0]  # always have batch size of 1 for varnet

        with torch.no_grad():
            try:
                device = torch.device("cuda")

                output = run_model(masked_kspace, mask, varnet, fname, device)
            except RuntimeError:
                # Best-effort fallback: retry the same batch on the CPU.
                print("running on cpu")
                device = torch.device("cpu")

                output = run_model(masked_kspace, mask, varnet, fname, device)

            # detect FLAIR 203: the output is narrower than the requested
            # crop, so crop to the largest square that fits instead
            if output.shape[-1] < crop_size[1]:
                crop_size = (output.shape[-1], output.shape[-1])

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    # Order the slices of every file by slice number only; the output arrays
    # themselves are never compared.
    reconstructions = {
        name: np.stack([
            out
            for _, out in sorted(slices, key=lambda item: int(item[0]))
        ])
        for name, slices in outputs.items()
    }

    fastmri.save_reconstructions(reconstructions, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
    def test_step(self, batch, batch_idx):
        """
        Run the model on one test batch and return its cropped output.

        Args:
            batch: Sequence unpacked as ``(masked_kspace, mask, _, fname,
                slice_num, _, crop_size)``.
            batch_idx: Index of the batch; not used here.

        Returns:
            dict: ``"fname"`` and ``"slice"`` taken from the batch, and
            ``"output"``, the center-cropped prediction as a numpy array.
        """
        kspace, sampling_mask, _, fname, slice_num, _, crop_sizes = batch
        # always have a batch size of 1 for varnet
        crop_size = crop_sizes[0]

        prediction = self(kspace, sampling_mask)

        # check for FLAIR 203: prediction narrower than the requested crop
        width = prediction.shape[-1]
        if width < crop_size[1]:
            crop_size = (width, width)

        cropped = transforms.center_crop(prediction, crop_size)

        result = {"fname": fname, "slice": slice_num}
        result["output"] = cropped.cpu().numpy()
        return result
Beispiel #8
0
    def test_step(self, batch, batch_idx):
        """
        Run the model on one test batch and return its cropped output.

        Args:
            batch: Object exposing ``masked_kspace``, ``mask``,
                ``num_low_frequencies``, ``crop_size``, ``fname`` and
                ``slice_num``.
            batch_idx: Index of the batch; not used here.

        Returns:
            dict: ``"fname"`` and ``"slice"`` taken from the batch, and
            ``"output"``, the center-cropped prediction as a numpy array.
        """
        prediction = self(
            batch.masked_kspace,
            batch.mask,
            batch.num_low_frequencies,
        )

        # check for FLAIR 203: prediction narrower than the requested crop
        crop_size = batch.crop_size
        width = prediction.shape[-1]
        if width < crop_size[1]:
            crop_size = (width, width)

        cropped = transforms.center_crop(prediction, crop_size)

        result = {"fname": batch.fname, "slice": batch.slice_num}
        result["output"] = cropped.cpu().numpy()
        return result
def test_center_crop(shape, target_shape):
    """
    Check that ``transforms.center_crop`` yields exactly ``target_shape``.

    Args:
        shape: Shape passed to ``create_input`` to build the test input.
        target_shape: Requested crop size; compared against the result shape
            as a list.
    """
    cropped = transforms.center_crop(create_input(shape), target_shape)
    result_shape = list(cropped.numpy().shape)

    assert result_shape == target_shape
Beispiel #10
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Build a low-resolution magnitude image from k-space together with a
        target normalized by the same statistics.

        Args:
            kspace (numpy.array): Input k-space, documented as shape
                (num_coils, rows, cols, 2) for multi-coil data or
                (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset. Not used by this
                transform.
            target (numpy.array): Target image, or None.
            attrs (dict): Acquisition related information stored in the HDF5
                object; ``"recon_size"`` is read when ``target`` is None.
            fname (str): File name.
            slice_num (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Low-resolution magnitude image obtained
                    by keeping the central 160x160 region of the cropped
                    image's k-space, normalized and clamped to [-6, 6].
                target (torch.Tensor): Target cropped to the crop size,
                    normalized with the image's mean/std and clamped to
                    [-6, 6]; ``torch.Tensor([0])`` when no target is given.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name.
                slice_num (int): Serial number of the slice.
        """
        image = fastmri.ifft2c(transforms.to_tensor(kspace))

        # crop input to correct size
        if target is None:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
        else:
            crop_size = (target.shape[-2], target.shape[-1])

        # check for FLAIR 203: image narrower than the requested crop
        available = image.shape[-2]
        if available < crop_size[1]:
            crop_size = (available, available)

        image = transforms.complex_center_crop(image, crop_size)

        # Low-resolution version: back to k-space, keep only the central
        # 160x160 region, and return to image space.
        central_kspace = transforms.complex_center_crop(
            fastmri.fft2c(image), (160, 160))
        low_res = fastmri.complex_abs(fastmri.ifft2c(central_kspace))

        # normalize input
        low_res, mean, std = transforms.normalize_instance(low_res, eps=1e-11)
        low_res = low_res.clamp(-6, 6)

        # normalize target with the statistics of the low-resolution image
        if target is None:
            target = torch.Tensor([0])
        else:
            target = transforms.center_crop(
                transforms.to_tensor(target), crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)

        return low_res, target, mean, std, fname, slice_num
Beispiel #11
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Convert one k-space slice into a normalized zero-filled image and a
        target normalized with the same statistics.

        Args:
            kspace (numpy.array): Input k-space, documented as shape
                (num_coils, rows, cols, 2) for multi-coil data or
                (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset; replaced by the
                generated mask when ``self.mask_func`` is set.
            target (numpy.array): Target image, or None.
            attrs (dict): Acquisition related information stored in the HDF5
                object; ``"recon_size"`` is read when ``target`` is None.
            fname (str): File name; also the source of the mask seed when
                ``self.use_seed`` is set.
            slice_num (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image, normalized and
                    clamped to [-6, 6].
                target (torch.Tensor): Target image converted to a torch
                    Tensor, cropped, normalized and clamped to [-6, 6];
                    ``torch.Tensor([0])`` when no target is given.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                fname (str): File name.
                slice_num (int): Serial number of the slice.
        """
        kspace = transforms.to_tensor(kspace)

        # apply mask
        if not self.mask_func:
            masked_kspace = kspace
        else:
            # Seeding from the file name's character codes gives every file
            # a reproducible mask.
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, mask = transforms.apply_mask(
                kspace, self.mask_func, seed)

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # crop input to correct size
        if target is None:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
        else:
            crop_size = (target.shape[-2], target.shape[-1])

        # check for FLAIR 203: image narrower than the requested crop
        available = image.shape[-2]
        if available < crop_size[1]:
            crop_size = (available, available)

        # crop, then take the absolute value
        image = fastmri.complex_abs(
            transforms.complex_center_crop(image, crop_size))

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target with the statistics of the input image
        if target is None:
            target = torch.Tensor([0])
        else:
            target = transforms.center_crop(
                transforms.to_tensor(target), crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)

        return image, target, mean, std, fname, slice_num
Beispiel #12
0
# Demo setup: build a reference image and a randomly under-sampled k-space
# for a single slice. `kspace` must already be defined before this point.

# NOTE: `slice` shadows the builtin of the same name; the name is kept
# because code outside this excerpt may refer to it.
slice = 20

crop_size = (320, 320)

device = 'cuda'

# Centered 2-D inverse FFT over the last two axes of the selected slice.
target = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace[slice], axes=(-2, -1)), axes=(-2, -1)), axes=(-2, -1))

# Scale so that the largest magnitude is 1.
target = target / np.max(np.abs(target))

# Root-sum-of-squares over axis 0 of the cropped data. `target` is complex
# here, so the magnitudes are squared; squaring the complex values directly
# would not produce a magnitude image.
target = np.sqrt(np.sum(np.abs(T.center_crop(target, crop_size)) ** 2, 0))

# Random sampling mask: center fraction 0.08, acceleration 4.
mask_func = create_mask_for_mask_type(mask_type_str="random", center_fractions=[0.08], accelerations=[4])

_kspace = T.to_tensor(kspace)[slice]

masked_kspace, mask = T.apply_mask(_kspace, mask_func)

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str,
               int, float]:
        """
        Convert one k-space slice into a normalized zero-filled image and,
        outside test mode, a target normalized with the same statistics.

        Args:
            kspace: Input k-space, documented as shape (num_coils, rows, cols)
                for multi-coil data or (rows, cols) for single coil data.
            mask: Mask from the test dataset; replaced by the generated mask
                when ``self.mask_func`` is set.
            target: Target image, or None.
            attrs: Acquisition related information stored in the HDF5 object;
                ``"max"`` is read when present and ``"recon_size"`` is read
                when ``target`` is None outside test mode.
            fname: File name; also the source of the mask seed when
                ``self.use_seed`` is set.
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                image: Zero-filled input image, normalized and clamped to
                    [-6, 6].
                target: Target image converted to a torch.Tensor, cropped,
                    normalized and clamped to [-6, 6]; ``torch.Tensor([0])``
                    in test mode or when no target is given.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname: File name.
                slice_num: Serial number of the slice.
                max_value: ``attrs["max"]`` when present, otherwise 0.0.
        """
        kspace = T.to_tensor(kspace)

        # check for max value
        if "max" in attrs:
            max_value = attrs["max"]
        else:
            max_value = 0.0

        # apply mask
        if not self.mask_func:
            masked_kspace = kspace
        else:
            # Seeding from the file name's character codes gives every file
            # a reproducible mask.
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        available = image.shape[-2]
        if self.test_mode:
            # Test mode ignores target/recon_size and always uses a square
            # crop sized by the image itself.
            crop_size = (available, available)
        else:
            # crop input to correct size
            if target is None:
                crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
            else:
                crop_size = (target.shape[-2], target.shape[-1])

            # check for FLAIR 203: image narrower than the requested crop
            if available < crop_size[1]:
                crop_size = (available, available)

        # crop, then take the absolute value
        image = fastmri.complex_abs(T.complex_center_crop(image, crop_size))

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target with the statistics of the input image
        if self.test_mode or target is None:
            target = torch.Tensor([0])
        else:
            target = T.center_crop(T.to_tensor(target), crop_size)
            target = T.normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)

        return image, target, mean, std, fname, slice_num, max_value