Code example #1
def run_inference(challenge, state_dict_file, data_path, output_path, device):
    model = Unet(in_chans=1,
                 out_chans=1,
                 chans=256,
                 num_pool_layers=4,
                 drop_prob=0.0)
    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = UNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge],
                           MODEL_FNAMES[challenge])

        state_dict_file = MODEL_FNAMES[challenge]

    model.load_state_dict(torch.load(state_dict_file, map_location=device))
    model = model.eval()

    # data loader setup
    if "_mc" in challenge:
        data_transform = T.UnetDataTransform(which_challenge="multicoil")
    else:
        data_transform = T.UnetDataTransform(which_challenge="singlecoil")

    if "_mc" in challenge:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="multicoil",
        )
    else:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="singlecoil",
        )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
Code example #2
    def _create_data_loader(self,
                            data_transform,
                            data_partition,
                            sample_rate=None):
        sample_rate = sample_rate or self.sample_rate
        dataset = SliceDataset(
            root=self.data_path / f"{self.challenge}_{data_partition}",
            transform=data_transform,
            sample_rate=sample_rate,
            challenge=self.challenge,
            mode=data_partition,
        )

        is_train = data_partition == "train"

        # ensure that entire volumes go to the same GPU in the ddp setting
        sampler = None
        if self.use_ddp:
            if is_train:
                sampler = DistributedSampler(dataset)
            else:
                sampler = VolumeSampler(dataset)

        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=is_train,
            sampler=sampler,
        )

        return dataloader
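
As a usage sketch, in a PyTorch Lightning module this helper would typically back the standard dataloader hooks. The hook names below are Lightning's; the transform attributes are assumed to exist on the module (they do in code example #3).

# Hedged sketch: wiring _create_data_loader into Lightning's dataloader hooks.
class DataLoaderHooksMixin:
    def train_dataloader(self):
        return self._create_data_loader(self.train_transform, data_partition="train")

    def val_dataloader(self):
        return self._create_data_loader(self.val_transform, data_partition="val")

    def test_dataloader(self):
        return self._create_data_loader(self.test_transform, data_partition="test")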
Code example #3
    def prepare_data(self):
        # call dataset for each split one time to make sure the cache is set up on the
        # rank 0 ddp process. if not using cache, don't do this
        if self.use_dataset_cache_file:
            if self.test_path is not None:
                test_path = self.test_path
            else:
                test_path = self.data_path / f"{self.challenge}_test"
            data_paths = [
                self.data_path / f"{self.challenge}_train",
                self.data_path / f"{self.challenge}_val",
                test_path,
            ]
            data_transforms = [
                self.train_transform,
                self.val_transform,
                self.test_transform,
            ]
            for i, (data_path, data_transform) in enumerate(
                    zip(data_paths, data_transforms)):
                sample_rate = self.sample_rate if i == 0 else 1.0
                volume_sample_rate = self.volume_sample_rate if i == 0 else None
                _ = SliceDataset(
                    root=data_path,
                    transform=data_transform,
                    sample_rate=sample_rate,
                    volume_sample_rate=volume_sample_rate,
                    challenge=self.challenge,
                    use_dataset_cache=self.use_dataset_cache_file,
                )

    def _create_data_loader(
        self,
        data_transform: Callable,
        data_partition: str,
        sample_rate: Optional[float] = None,
    ) -> torch.utils.data.DataLoader:
        if data_partition == "train":
            is_train = True
            sample_rate = self.sample_rate if sample_rate is None else sample_rate
        else:
            is_train = False
            sample_rate = 1.0

        # if desired, combine train and val together for the train split
        dataset: Union[SliceDataset, CombinedSliceDataset]
        if is_train and self.combine_train_val:
            data_paths = [
                self.data_path / f"{self.challenge}_train",
                self.data_path / f"{self.challenge}_val",
            ]
            data_transforms = [data_transform, data_transform]
            challenges = [self.challenge, self.challenge]
            sample_rates = [sample_rate, sample_rate]
            dataset = CombinedSliceDataset(
                roots=data_paths,
                transforms=data_transforms,
                challenges=challenges,
                sample_rates=sample_rates,
                use_dataset_cache=self.use_dataset_cache_file,
            )
        else:
            if data_partition in ("test",
                                  "challenge") and self.test_path is not None:
                data_path = self.test_path
            else:
                data_path = self.data_path / f"{self.challenge}_{data_partition}"

            dataset = SliceDataset(
                root=data_path,
                transform=data_transform,
                sample_rate=sample_rate,
                challenge=self.challenge,
                use_dataset_cache=self.use_dataset_cache_file,
            )
        # ensure that entire volumes go to the same GPU in the ddp setting
        sampler = None
        if self.distributed_sampler:
            if is_train:
                sampler = torch.utils.data.DistributedSampler(dataset)
            else:
                sampler = fastmri.data.VolumeSampler(dataset)

        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=worker_init_fn,
            sampler=sampler,
        )
        return dataloader
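
The loader above passes a worker_init_fn that is not shown in this excerpt. A plausible minimal version, assuming its job is to give each DataLoader worker its own NumPy seed (the real fastMRI implementation also reseeds the transform's mask function), looks like this:

# Hedged sketch of a worker_init_fn: reseed NumPy per worker so that
# random transforms differ across DataLoader workers.
import numpy as np
import torch

def worker_init_fn(worker_id: int) -> None:
    info = torch.utils.data.get_worker_info()
    # torch assigns each worker a distinct base seed; fold it into NumPy's range
    np.random.seed(info.seed % (2 ** 32))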
Code example #5
def run_inference(checkpoint, data_path, output_path):
    varnet = VarNet()
    load_state_dict = torch.load(checkpoint)["state_dict"]
    state_dict = {}
    for k, v in load_state_dict.items():
        if "varnet" in k:
            state_dict[k[len("varnet."):]] = v

    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    data_transform = DataTransform()

    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge="multicoil",
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        crop_size = crop_size[0]  # always have a batch size of 1 for varnet
        fname = fname[0]  # always have batch size of 1 for varnet

        with torch.no_grad():
            try:
                # try the GPU first; fall back to the CPU if CUDA is
                # unavailable or out of memory
                device = torch.device("cuda")

                output = run_model(masked_kspace, mask, varnet, fname, device)
            except RuntimeError:
                print("running on cpu")
                device = torch.device("cpu")

                output = run_model(masked_kspace, mask, varnet, fname, device)

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
Code example #6
def run_inference(challenge, state_dict_file, data_path, output_path,
                  mask_func, use_sens_net, device):
    model = VarNet(num_cascades=12, pools=4, chans=18,
                   use_sens_net=use_sens_net, sens_pools=4, sens_chans=8)
    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = VARNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge], MODEL_FNAMES[challenge])

        state_dict_file = MODEL_FNAMES[challenge]

    model.load_state_dict(torch.load(state_dict_file, map_location=device))
    model = model.eval()

    # data loader setup
    if mask_func is None:
        data_transform = T.VarNetDataTransform()
    else:
        data_transform = T.VarNetDataTransform(mask_func=mask_func)

    dataset = SliceDataset(
        root=data_path, transform=data_transform, challenge="multicoil"
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_varnet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time - start_time}")
Code example #7
def run_inference(challenge, state_dict_file, data_path, output_path, device,
                  mask, in_chans, out_chans, chans):
    if args.unet_module == "unet":
        model = Unet(in_chans=in_chans,
                     out_chans=out_chans,
                     chans=chans,
                     num_pool_layers=4,
                     drop_prob=0.0)
    elif args.unet_module == "nestedunet":
        model = NestedUnet(in_chans=in_chans,
                           out_chans=out_chans,
                           chans=chans,
                           num_pool_layers=4,
                           drop_prob=0.0)

    pretrained_dict = torch.load(state_dict_file, map_location=device)
    model_dict = model.state_dict()
    if args.fine_tuned_model:
        if 'state_dict' in pretrained_dict.keys():
            model_dict = {
                k: pretrained_dict["state_dict"][f"unet.{k}"]
                for k, _ in model_dict.items()
            }  # load from .ckpt
        elif 'unet' in pretrained_dict.keys():
            model_dict = {
                k: pretrained_dict["unet." + k]
                for k, v in model_dict.items()
            }  # load from .torch
        else:
            model_dict = {
                k: pretrained_dict[k]
                for k, v in model_dict.items()
            }  # load from .pt
    else:
        if args.unet_module == "unet":
            model_dict = {
                k: pretrained_dict["classy_state_dict"]["base_model"]["model"]
                ["trunk"]["_feature_blocks.unetblock." + k]
                for k, _ in model_dict.items()
            }
        elif args.unet_module == "nestedunet":
            model_dict = {
                k: pretrained_dict["classy_state_dict"]["base_model"]["model"]
                ["trunk"]["_feature_blocks.nublock." + k]
                for k, v in model_dict.items()
            }

    model.load_state_dict(model_dict)
    model = model.eval()

    # data loader setup
    if "_mc" in challenge:
        data_transform = T.UnetDataTransform(which_challenge="multicoil",
                                             mask_func=mask)
    else:
        data_transform = T.UnetDataTransform(which_challenge="singlecoil",
                                             mask_func=mask)

    if "_mc" in challenge:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="multicoil",
        )
    else:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="singlecoil",
        )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
Code example #8
File: run_bart.py  Project: tmsincomb/fastMRI-1
        help="Number of processes. Set to 0 to disable multiprocessing.",
    )

    return parser


if __name__ == "__main__":
    args = create_arg_parser().parse_args()

    if args.split in ("train", "val"):
        mask = create_mask_for_mask_type(
            args.mask_type,
            args.center_fractions,
            args.accelerations,
        )
    else:
        mask = None
        args.reg_wt = None

    # need this global for multiprocessing
    dataset = SliceDataset(
        root=args.data_path / f"{args.challenge}_{args.split}",
        transform=DataTransform(split=args.split,
                                mask_func=mask,
                                reg_wt=args.reg_wt),
        challenge=args.challenge,
        sample_rate=args.sample_rate,
    )

    run_bart(args)
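
run_bart itself is outside this excerpt. The comment about the global explains the design: multiprocessing workers look the dataset up by index through the inherited module-level global instead of pickling it into each task. A hypothetical sketch of that pattern (the worker body here is illustrative, not the project's BART call):

# Hypothetical sketch of the multiprocessing pattern (not the real run_bart):
import multiprocessing

def reconstruct_slice(idx):
    # workers reach the module-level `dataset` global directly
    sample = dataset[idx]
    ...  # a real worker would run the BART reconstruction here
    return idx

def run_bart(args):
    if args.num_procs == 0:
        return [reconstruct_slice(i) for i in range(len(dataset))]
    with multiprocessing.Pool(args.num_procs) as pool:
        return pool.map(reconstruct_slice, range(len(dataset)))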