Example #1
0
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Write zero-filled reconstructions for every file in ``data_dir``.

    Args:
        data_dir: Directory (``pathlib.Path``) containing input HDF5 files.
        out_dir: Directory where the reconstructions are saved.
        which_challenge: ``"singlecoil"`` or ``"multicoil"``; multicoil data
            is coil-combined with root-sum-of-squares.
    """
    reconstructions = {}

    for fname in data_dir.iterdir():
        with h5py.File(fname, "r") as hf:
            header = ismrmrd.xsd.CreateFromDocument(hf["ismrmrd_header"][()])
            enc = header.encoding[0]
            kspace = transforms.to_tensor(hf["kspace"][()])

            # target image (width, height) from the ISMRMRD header
            crop_size = (enc.reconSpace.matrixSize.x, enc.reconSpace.matrixSize.y)

            # zero-filled solution: inverse FFT of the (masked) k-space
            image = fastmri.ifft2c(kspace)

            # FLAIR 203 volumes are narrower than the nominal recon size
            if image.shape[-2] < crop_size[1]:
                crop_size = (image.shape[-2], image.shape[-2])

            image = transforms.complex_center_crop(image, crop_size)
            image = fastmri.complex_abs(image)

            # coil combination for multicoil data
            if which_challenge == "multicoil":
                image = fastmri.rss(image, dim=1)

            reconstructions[fname.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Create zero-filled reconstructions for all ``*.h5`` files in ``data_dir``.

    Args:
        data_dir: Directory (``pathlib.Path``) with input HDF5 files.
        out_dir: Directory where the reconstructions are written.
        which_challenge: ``"singlecoil"`` or ``"multicoil"``; multicoil data
            is coil-combined with root-sum-of-squares.
    """
    reconstructions = {}

    for fname in tqdm(list(data_dir.glob("*.h5"))):
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])
            kspace = transforms.to_tensor(hf["kspace"][()])

            # target image (width, height) from the ISMRMRD XML header
            size_path = ["encoding", "encodedSpace", "matrixSize"]
            crop_size = tuple(
                int(et_query(et_root, size_path + [dim])) for dim in ("x", "y")
            )

            # zero-filled solution: inverse FFT of the (masked) k-space
            image = fastmri.ifft2c(kspace)

            # FLAIR 203 volumes are narrower than the nominal recon size
            if image.shape[-2] < crop_size[1]:
                crop_size = (image.shape[-2], image.shape[-2])

            image = transforms.complex_center_crop(image, crop_size)
            image = fastmri.complex_abs(image)

            # coil combination for multicoil data
            if which_challenge == "multicoil":
                image = fastmri.rss(image, dim=1)

            reconstructions[fname.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
Example #3
0
def run_inference(challenge, state_dict_file, data_path, output_path, device):
    """Run pretrained U-Net inference over a dataset and save reconstructions.

    Args:
        challenge: Challenge name; contains "_mc" for multicoil data.
        state_dict_file: Path to model weights, or None to download them.
        data_path: Root directory of the HDF5 data.
        output_path: Directory under which ``reconstructions`` is written.
        device: Torch device used for inference.
    """
    model = Unet(in_chans=1,
                 out_chans=1,
                 chans=256,
                 num_pool_layers=4,
                 drop_prob=0.0)

    # download the state_dict if we don't have it
    if state_dict_file is None:
        model_fname = MODEL_FNAMES[challenge]
        if not Path(model_fname).exists():
            download_model(UNET_FOLDER + model_fname, model_fname)
        state_dict_file = model_fname

    model.load_state_dict(torch.load(state_dict_file))
    model = model.eval()

    # data loader setup: multicoil vs singlecoil is encoded in the challenge name
    which_challenge = "multicoil" if "_mc" in challenge else "singlecoil"
    data_transform = T.UnetDataTransform(which_challenge=which_challenge)
    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge=which_challenge,
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)
        outputs[fname].append((slice_num, output))

    # stack each file's slices in slice order, then save
    for fname, slice_outputs in outputs.items():
        outputs[fname] = np.stack([out for _, out in sorted(slice_outputs)])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
Example #4
0
def save_outputs(outputs, output_path):
    """Saves reconstruction outputs to output_path."""
    # group predictions by file name as (slice_num, prediction) pairs
    grouped = defaultdict(list)
    for fname, slice_num, pred in outputs:
        grouped[fname].append((slice_num, pred))

    # sort each file's slices by slice number and stack into one volume
    reconstructions = {}
    for fname, slice_preds in grouped.items():
        reconstructions[fname] = np.stack(
            [pred for _, pred in sorted(slice_preds)]
        )

    fastmri.save_reconstructions(reconstructions, output_path)
def run_inference(checkpoint, data_path, output_path):
    """Run VarNet inference over a multicoil dataset and save reconstructions.

    Args:
        checkpoint: Path to a Lightning checkpoint containing the weights.
        data_path: Root directory of the multicoil HDF5 data.
        output_path: Directory under which ``reconstructions`` is written.
    """
    varnet = VarNet()

    # the checkpoint prefixes model weights with "varnet." — strip it
    prefix = "varnet."
    raw_state = torch.load(checkpoint)["state_dict"]
    state_dict = {
        key[len(prefix):]: value
        for key, value in raw_state.items()
        if "varnet" in key
    }
    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    dataset = SliceDataset(
        root=data_path,
        transform=DataTransform(),
        challenge="multicoil",
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        # always have a batch size of 1 for varnet, so unwrap the batch dim
        crop_size = crop_size[0]
        fname = fname[0]

        with torch.no_grad():
            try:
                # prefer the GPU; fall back to CPU on RuntimeError (e.g. OOM)
                output = run_model(
                    masked_kspace, mask, varnet, fname, torch.device("cuda")
                )
            except RuntimeError:
                print("running on cpu")
                output = run_model(
                    masked_kspace, mask, varnet, fname, torch.device("cpu")
                )

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    # stack each file's slices in slice order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
Example #6
0
    def test_epoch_end(self, test_logs):
        """Group test-step outputs by file, stack slices in order, and save
        the volumes under ``<exp_dir>/<exp_name>/bicubic``."""
        volumes = defaultdict(list)

        for log in test_logs:
            fnames, slice_nums = log["fname"], log["slice"]
            for idx, (fname, slice_num) in enumerate(zip(fnames, slice_nums)):
                volumes[fname].append((slice_num, log["output"][idx]))

        # order each file's slices and stack them into one array per volume
        stacked = {
            fname: np.stack([out for _, out in sorted(pairs)])
            for fname, pairs in volumes.items()
        }

        fastmri.save_reconstructions(stacked,
                                     self.exp_dir / self.exp_name / "bicubic")

        return dict()
Example #7
0
def run_inference(challenge, state_dict_file, data_path, output_path, mask_func, use_sens_net, device):
    """Run pretrained VarNet inference and save the reconstructions.

    Args:
        challenge: Challenge name used to pick the pretrained weight file.
        state_dict_file: Path to model weights, or None to download them.
        data_path: Root directory of the multicoil HDF5 data.
        output_path: Directory under which ``reconstructions`` is written.
        mask_func: Optional mask function for the data transform.
        use_sens_net: Whether the VarNet estimates sensitivity maps itself.
        device: Torch device used for inference.
    """
    model = VarNet(
        num_cascades=12,
        pools=4,
        chans=18,
        use_sens_net=use_sens_net,
        sens_pools=4,
        sens_chans=8,
    )

    # download the state_dict if we don't have it
    if state_dict_file is None:
        model_fname = MODEL_FNAMES[challenge]
        if not Path(model_fname).exists():
            download_model(VARNET_FOLDER + model_fname, model_fname)
        state_dict_file = model_fname

    model.load_state_dict(torch.load(state_dict_file))
    model = model.eval()

    # data loader setup; let the transform default when no mask is given
    if mask_func is None:
        data_transform = T.VarNetDataTransform()
    else:
        data_transform = T.VarNetDataTransform(mask_func=mask_func)

    dataset = SliceDataset(
        root=data_path, transform=data_transform, challenge="multicoil"
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_varnet_model(batch, model, device)
        outputs[fname].append((slice_num, output))

    # stack each file's slices in slice order, then save
    for fname, slice_outputs in outputs.items():
        outputs[fname] = np.stack([out for _, out in sorted(slice_outputs)])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time - start_time}")
    def test_epoch_end(self, test_logs):
        """Aggregate test-step outputs into volumes and save reconstructions.

        Uses dicts keyed by slice index for aggregation so duplicate slices
        produced in ddp mode overwrite rather than duplicate.
        """
        slices_by_file = defaultdict(dict)
        for log in test_logs:
            fnames, slice_nums, preds = log["fname"], log["slice"], log["output"]
            for i, (fname, slice_num) in enumerate(zip(fnames, slice_nums)):
                slices_by_file[fname][int(slice_num.cpu())] = preds[i]

        # stack all the slices for each file in slice order
        outputs = {
            fname: np.stack([pred for _, pred in sorted(slice_map.items())])
            for fname, slice_map in slices_by_file.items()
        }

        # pull the default_root_dir if we have a trainer, otherwise save to cwd
        if hasattr(self, "trainer"):
            base_dir = pathlib.Path(self.trainer.default_root_dir)
        else:
            base_dir = pathlib.Path.cwd()
        save_path = base_dir / "reconstructions"
        self.print(f"Saving reconstructions to {save_path}")

        fastmri.save_reconstructions(outputs, save_path)
Example #9
0
    def validation_epoch_end(self, val_logs):
        """Aggregate per-slice validation results into per-file volumes,
        compute NMSE/SSIM/PSNR, save the inputs, and return logged metrics.

        Args:
            val_logs: List of per-step dicts with keys "output", "target",
                "input", "fname", "slice", "val_loss", and "device".

        Returns:
            Dict with a "log" entry of averaged metrics plus the averaged
            metrics themselves.
        """
        #assert val_logs[0]["output_im"].ndim == 3
        device = val_logs[0]["device"]

        # run the visualizations
        self._visualize_val(
            val_outputs=[x["output"].numpy() for x in val_logs],
            val_targets=[x["target"].numpy() for x in val_logs],
            val_inputs=[x["input"].numpy() for x in val_logs],
        )

        # aggregate losses and group (slice, tensor) pairs by file id
        losses = []
        outputs = defaultdict(list)
        targets = defaultdict(list)
        inputs = defaultdict(list)

        for val_log in val_logs:
            losses.append(val_log["val_loss"])
            for i, (fname, slice_ind) in enumerate(
                    zip(val_log["fname"], val_log["slice"])):
                # need to check for duplicate slices (possible in ddp mode)
                if slice_ind not in [s for (s, _) in outputs[int(fname)]]:
                    outputs[int(fname)].append(
                        (int(slice_ind), val_log["output"][i]))
                    targets[int(fname)].append(
                        (int(slice_ind), val_log["target"][i]))
                    inputs[int(fname)].append(
                        (int(slice_ind), val_log["input"][i]))

        # handle aggregation for distributed case with pytorch_lightning metrics
        metrics = dict(val_loss=0, nmse=0, ssim=0, psnr=0)
        for fname in outputs:
            # stack slices in slice order to form one volume per file
            output = torch.stack([out for _, out in sorted(outputs[fname])
                                  ]).numpy()
            target = torch.stack([tgt for _, tgt in sorted(targets[fname])
                                  ]).numpy()
            input = torch.stack([inn
                                 for _, inn in sorted(inputs[fname])]).numpy()

            # accumulate per-volume metrics; averaged over volumes below
            metrics["nmse"] = metrics["nmse"] + evaluate.nmse(target, output)
            metrics["ssim"] = metrics["ssim"] + evaluate.ssim(target, output)
            metrics["psnr"] = metrics["psnr"] + evaluate.psnr(target, output)

        # currently ddp reduction requires everything on CUDA, so we'll do this manually
        metrics["nmse"] = self.NMSE(torch.tensor(metrics["nmse"]).to(device))
        metrics["ssim"] = self.SSIM(torch.tensor(metrics["ssim"]).to(device))
        metrics["psnr"] = self.PSNR(torch.tensor(metrics["psnr"]).to(device))
        metrics["val_loss"] = self.ValLoss(
            torch.sum(torch.stack(losses)).to(device))

        num_examples = torch.tensor(len(outputs)).to(device)
        tot_examples = self.TotExamples(num_examples)

        # average accumulated metrics over the total number of volumes
        log_metrics = {
            f"metrics/{metric}": values / tot_examples
            for metric, values in metrics.items()
        }
        metrics = {
            metric: values / tot_examples
            for metric, values in metrics.items()
        }
        print(tot_examples, device, metrics)

        # NOTE(review): `inputs` values are still lists of (slice, tensor)
        # tuples here — unlike the stacked `output`/`target` volumes above.
        # Confirm fastmri.save_reconstructions accepts this format.
        fastmri.save_reconstructions(inputs,
                                     self.exp_dir / self.exp_name / "bicubic")

        return dict(log=log_metrics, **metrics)
Example #10
0
def run_inference(challenge, state_dict_file, data_path, output_path, device,
                  mask, in_chans, out_chans, chans):
    """Run U-Net (or NestedUnet) inference and save the reconstructions.

    Args:
        challenge: Challenge name; contains "_mc" for multicoil data.
        state_dict_file: Path to pretrained weights (.ckpt/.torch/.pt or a
            classy-vision style checkpoint).
        data_path: Root directory of the HDF5 data.
        output_path: Directory under which ``reconstructions`` is written.
        device: Torch device used for loading weights and inference.
        mask: Mask function passed to the data transform.
        in_chans: Number of model input channels.
        out_chans: Number of model output channels.
        chans: Base channel count of the model.

    Raises:
        ValueError: If ``args.unet_module`` is not "unet" or "nestedunet".

    NOTE(review): reads the module-level ``args`` namespace
    (``args.unet_module``, ``args.fine_tuned_model``) — consider passing
    those values explicitly.
    """
    # build the requested architecture; fail fast on unknown module names
    # (previously an unknown value left `model` unbound -> opaque NameError)
    if args.unet_module == "unet":
        model = Unet(in_chans=in_chans,
                     out_chans=out_chans,
                     chans=chans,
                     num_pool_layers=4,
                     drop_prob=0.0)
    elif args.unet_module == "nestedunet":
        model = NestedUnet(in_chans=in_chans,
                           out_chans=out_chans,
                           chans=chans,
                           num_pool_layers=4,
                           drop_prob=0.0)
    else:
        raise ValueError(f"Unknown unet_module: {args.unet_module!r}")

    pretrained_dict = torch.load(state_dict_file, map_location=device)
    model_dict = model.state_dict()
    if args.fine_tuned_model:
        # fine-tuned checkpoints come in three layouts with different key prefixes
        if "state_dict" in pretrained_dict:
            model_dict = {
                k: pretrained_dict["state_dict"][f"unet.{k}"]
                for k in model_dict
            }  # load from .ckpt
        elif "unet" in pretrained_dict:
            model_dict = {
                k: pretrained_dict["unet." + k]
                for k in model_dict
            }  # load from .torch
        else:
            model_dict = {
                k: pretrained_dict[k]
                for k in model_dict
            }  # load from .pt
    else:
        # self-supervised (classy-vision) checkpoints nest the trunk weights
        trunk = pretrained_dict["classy_state_dict"]["base_model"]["model"]["trunk"]
        if args.unet_module == "unet":
            model_dict = {
                k: trunk["_feature_blocks.unetblock." + k]
                for k in model_dict
            }
        else:  # nestedunet (validated above)
            model_dict = {
                k: trunk["_feature_blocks.nublock." + k]
                for k in model_dict
            }

    model.load_state_dict(model_dict)
    model = model.eval()

    # data loader setup: multicoil vs singlecoil is encoded in the challenge name
    which_challenge = "multicoil" if "_mc" in challenge else "singlecoil"
    data_transform = T.UnetDataTransform(which_challenge=which_challenge,
                                         mask_func=mask)
    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge=which_challenge,
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)
        outputs[fname].append((slice_num, output))

    # save outputs: stack each file's slices in slice order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")