예제 #1
0
 def __init__(self, in_chans, out_chans, dropout, decoder_channels, lr,
              lr_step_size, lr_gamma, weight_decay, data_path, batch_size,
              mask_type, center_fractions, accelerations, optim_eps):
     """Configure the ENet reconstruction model and its fastMRI data module.

     Args:
         in_chans: number of input channels for the ENet.
         out_chans: number of output channels for the ENet.
         dropout: dropout probability forwarded to the ENet.
         decoder_channels: decoder channel configuration for the ENet.
         lr: learning rate (stored for use by the optimizer setup).
         lr_step_size: LR scheduler step size (stored for later use).
         lr_gamma: LR scheduler decay factor (stored for later use).
         weight_decay: optimizer weight decay (stored for later use).
         data_path: root directory of the fastMRI singlecoil dataset.
         batch_size: dataloader batch size.
         mask_type: k-space mask type for ``create_mask_for_mask_type``.
         center_fractions: fractions of low-frequency k-space columns kept.
         accelerations: undersampling acceleration factors.
         optim_eps: epsilon for the optimizer (stored for later use).
     """
     super().__init__()
     # Record all constructor args so Lightning checkpoints can restore them.
     self.save_hyperparameters()
     self.in_chans = in_chans
     self.out_chans = out_chans
     self.decoder_channels = decoder_channels
     self.lr = lr
     self.lr_step_size = lr_step_size
     self.lr_gamma = lr_gamma
     self.weight_decay = weight_decay
     self.optim_eps = optim_eps
     self.net = ENet(in_channels=in_chans,
                     out_channels=out_chans,
                     decoder_channels=decoder_channels,
                     dropout=dropout)
     # k-space subsampling mask shared by the train/val transforms
     mask = create_mask_for_mask_type(mask_type, center_fractions,
                                      accelerations)
     # random masks for training, fixed (seeded) masks for validation
     train_transform = UnetDataTransform('singlecoil',
                                         mask_func=mask,
                                         use_seed=False)
     val_transform = UnetDataTransform('singlecoil', mask_func=mask)
     test_transform = UnetDataTransform('singlecoil')
     self.data_module = FastMriDataModule(data_path=pathlib.Path(data_path),
                                          challenge='singlecoil',
                                          train_transform=train_transform,
                                          val_transform=val_transform,
                                          test_transform=test_transform,
                                          test_split='test',
                                          test_path=None,
                                          sample_rate=1.0,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          distributed_sampler=False)
예제 #2
0
def cli_main(args):
    """Build the fastMRI data module and U-Net from ``args``, then run.

    ``args.mode`` selects between training and testing; any other value
    raises ``ValueError``.
    """
    pl.seed_everything(args.seed)

    # ------------
    # data
    # ------------
    # k-space subsampling mask shared by the train/val transforms
    subsample_mask = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )
    # random masks for training, fixed (seeded) masks for validation
    train_transform = UnetDataTransform(
        args.challenge, mask_func=subsample_mask, use_seed=False
    )
    val_transform = UnetDataTransform(args.challenge, mask_func=subsample_mask)
    test_transform = UnetDataTransform(args.challenge)

    # Lightning data module — owns all three dataloaders.
    use_distributed_sampler = args.accelerator in ("ddp", "ddp_cpu")
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=args.test_split,
        test_path=args.test_path,
        sample_rate=args.sample_rate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=use_distributed_sampler,
        proportion=args.proportion,
    )

    # ------------
    # model
    # ------------
    model = UnetModule(
        in_chans=args.in_chans,
        out_chans=args.out_chans,
        chans=args.chans,
        num_pool_layers=args.num_pool_layers,
        drop_prob=args.drop_prob,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_gamma=args.lr_gamma,
        weight_decay=args.weight_decay,
    )

    # ------------
    # trainer / run
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    if args.mode == "train":
        trainer.fit(model, datamodule=data_module)
    elif args.mode == "test":
        trainer.test(model, datamodule=data_module)
    else:
        raise ValueError(f"unrecognized mode {args.mode}")
    def __init__(self) -> None:
        """Create a temporary fastMRI knee dataset on disk and initialise the
        parent ``FastMriDataModule`` over it (multicoil challenge, equispaced
        4x mask)."""
        # Recreate ./data from scratch so stale fixtures never leak in.
        data_path = Path.cwd() / "data"
        if data_path.is_dir():
            shutil.rmtree(str(data_path))
        data_path.mkdir(exist_ok=False, parents=True)
        _, _, metadata = create_temp_data(data_path)

        def retrieve_metadata_mock(a: Any, fname: Any) -> Any:
            # Serve the pre-generated metadata instead of reading real files.
            return metadata[str(fname)]

        # That's a bit flaky, we should be un-doing that after, but there's no obvious place of doing so.
        MonkeyPatch().setattr(SliceDataset, "_retrieve_metadata", retrieve_metadata_mock)

        mask = create_mask_for_mask_type(mask_type_str="equispaced",
                                         center_fractions=[0.08],
                                         accelerations=[4])
        # use random masks for train transform, fixed masks for val transform
        train_transform = VarNetDataTransform(mask_func=mask, use_seed=False)
        val_transform = VarNetDataTransform(mask_func=mask)
        test_transform = VarNetDataTransform()

        # Parent initialised explicitly (no cooperative super() chain here).
        FastMriDataModule.__init__(self,
                                   data_path=data_path / "knee_data",
                                   challenge="multicoil",
                                   train_transform=train_transform,
                                   val_transform=val_transform,
                                   test_transform=test_transform)
예제 #4
0
def test_mask_types(mask_type):
    """Check ``transforms.apply_mask`` matches the mask function's own output
    and does not consume global NumPy randomness."""
    shapes = ((4, 32, 32, 2), (2, 64, 32, 2), (1, 33, 24, 2))
    fraction_accel_pairs = tuple(
        zip(([0.08], [0.04], [0.04, 0.08]), ([4], [8], [4, 8]))
    )
    rng_state_before = np.random.get_state()

    for shape in shapes:
        for center_fractions, accelerations in fraction_accel_pairs:
            mask_func = create_mask_for_mask_type(
                mask_type, center_fractions, accelerations
            )
            expected_mask, expected_num_low_frequencies = mask_func(
                shape, seed=123
            )
            x = create_input(shape)
            output, mask, num_low_frequencies = transforms.apply_mask(
                x, mask_func, seed=123)

            # Seeded masking must leave the global RNG state untouched.
            assert (rng_state_before[1] == np.random.get_state()[1]).all()
            assert output.shape == x.shape
            assert mask.shape == expected_mask.shape
            assert np.all(expected_mask.numpy() == mask.numpy())
            # Every k-space entry zeroed by the mask is zero in the output.
            assert np.all(
                np.where(mask.numpy() == 0, 0, output.numpy())
                == output.numpy()
            )
            assert num_low_frequencies == expected_num_low_frequencies
예제 #5
0
 def val_data_transform(self):
     """Return a validation DataTransform using the configured k-space mask."""
     mask_func = create_mask_for_mask_type(self.mask_type,
                                           self.center_fractions,
                                           self.accelerations)
     # Validation uses the default seeded (reproducible) masking.
     return DataTransform(self.challenge, mask_func)
예제 #6
0
    def train_data_transform(self):
        """Return a training DataTransform with per-sample random masking."""
        mask_func = create_mask_for_mask_type(self.mask_type,
                                              self.center_fractions,
                                              self.accelerations)
        # use_seed=False -> a fresh random mask is drawn for every sample.
        return DataTransform(self.challenge, mask_func, use_seed=False)
예제 #7
0
def test_unet_trainer(fastmri_mock_dataset, backend, tmp_path, monkeypatch):
    """Smoke-test the U-Net training pipeline on a mocked fastMRI dataset.

    Builds the full data-module + model + Trainer stack for the given
    distributed ``backend`` and runs a single ``fast_dev_run`` fit.
    """
    knee_path, _, metadata = fastmri_mock_dataset

    def retrieve_metadata_mock(a, fname):
        # Serve pre-computed metadata so no real HDF5 headers are parsed.
        return metadata[str(fname)]

    monkeypatch.setattr(SliceDataset, "_retrieve_metadata",
                        retrieve_metadata_mock)

    params = build_unet_args(knee_path, tmp_path, backend)
    params.fast_dev_run = True  # one batch of train/val only
    params.backend = backend

    # k-space subsampling mask shared by the train/val transforms
    mask = create_mask_for_mask_type(params.mask_type, params.center_fractions,
                                     params.accelerations)
    # random masks for training, fixed (seeded) masks for validation
    train_transform = UnetDataTransform(params.challenge,
                                        mask_func=mask,
                                        use_seed=False)
    val_transform = UnetDataTransform(params.challenge, mask_func=mask)
    test_transform = UnetDataTransform(params.challenge)
    data_module = FastMriDataModule(
        data_path=params.data_path,
        challenge=params.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=params.test_split,
        sample_rate=params.sample_rate,
        batch_size=params.batch_size,
        num_workers=params.num_workers,
        distributed_sampler=(params.accelerator == "ddp"),
        # no on-disk dataset cache: the mock data is rebuilt per test run
        use_dataset_cache_file=False,
    )

    model = UnetModule(
        in_chans=params.in_chans,
        out_chans=params.out_chans,
        chans=params.chans,
        num_pool_layers=params.num_pool_layers,
        drop_prob=params.drop_prob,
        lr=params.lr,
        lr_step_size=params.lr_step_size,
        lr_gamma=params.lr_gamma,
        weight_decay=params.weight_decay,
    )

    trainer = Trainer.from_argparse_args(params)

    trainer.fit(model, data_module)
예제 #8
0
def get_dataloaders_fastmri(mask_type='random',
                            center_fractions=None,
                            accelerations=None,
                            challenge='singlecoil',
                            batch_size=8,
                            num_workers=4,
                            distributed_bool=False,
                            dataset_dir=dataset_dir,
                            mri_dir='fastmri/knee/',
                            worker_init_fn=None,
                            include_test=False,
                            **kwargs):
    """Build fastMRI dataloaders backed by a ``FastMriDataModule``.

    Args:
        mask_type: k-space mask type for ``create_mask_for_mask_type``.
        center_fractions: fractions of low-frequency k-space columns to keep
            (defaults to ``[0.08]``).
        accelerations: undersampling acceleration factors (defaults to ``[4]``).
        challenge: 'singlecoil' or 'multicoil'.
        batch_size: dataloader batch size.
        num_workers: dataloader worker count.
        distributed_bool: whether the data module uses a distributed sampler.
        dataset_dir: root directory containing the datasets.
        mri_dir: subdirectory of ``dataset_dir`` holding the fastMRI data.
        worker_init_fn: unused here — NOTE(review): presumably meant to be
            forwarded to the dataloaders; confirm before removing.
        include_test: if True, also build the test dataloader.
        **kwargs: ignored.

    Returns:
        dict mapping 'train' and 'validation' (and 'test' when
        ``include_test``) to their dataloaders.
    """
    # Mutable list defaults ([0.08], [4]) are shared across calls in Python;
    # use None sentinels and fill in the documented defaults per call.
    if center_fractions is None:
        center_fractions = [0.08]
    if accelerations is None:
        accelerations = [4]

    data_path = Path(os.path.join(dataset_dir, mri_dir))

    mask = create_mask_for_mask_type(mask_type_str=mask_type,
                                     center_fractions=center_fractions,
                                     accelerations=accelerations)

    # use random masks for train transform, fixed masks for val transform
    train_transform = UnetDataTransform(challenge, mask_func=mask, use_seed=False)
    val_transform = UnetDataTransform(challenge, mask_func=mask)
    test_transform = UnetDataTransform(challenge)

    # ptl data module - this handles data loaders
    data_module = FastMriDataModule(
        data_path=data_path,
        challenge=challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        batch_size=batch_size,
        num_workers=num_workers,
        distributed_sampler=distributed_bool,
    )

    dataloaders = {'train': data_module.train_dataloader(),
                   'validation': data_module.val_dataloader()}
    if include_test:
        dataloaders['test'] = data_module.test_dataloader()

    return dataloaders
예제 #9
0
def init_model(args):
    """Build the fastMRI singlecoil dataloaders and choose a compute device.

    Despite the name, no network is constructed here: the function selects
    CUDA when available, initialises a single-process 'gloo' process group,
    and wires up the train/val/test dataloaders.

    Returns:
        (train_loader, val_loader, test_loader, torch.device)
    """
    # Prefer CUDA when available.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f'There are {torch.cuda.device_count()} GPU(s) available.')
        print('Device name:', torch.cuda.get_device_name(0))
    else:
        print('No GPU available, using the CPU instead.')
        device = torch.device("cpu")

    # this creates a k-space mask for transforming input data
    mask = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )
    # use random masks for train transform, fixed masks for val transform
    train_transform = UnetDataTransform('singlecoil', mask_func=mask, use_seed=False)
    val_transform = UnetDataTransform('singlecoil', mask_func=mask)
    test_transform = UnetDataTransform('singlecoil')
    # Initialize Process Group (single process, file-based rendezvous)
    dist.init_process_group('gloo', init_method='file:///tmp/somefile', rank=0, world_size=1)
    # define the data loaders
    batch_size = args.batch_size
    # create object for data module
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge='singlecoil',
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split='test',
        # assumes args.data_path is a str — TODO confirm (Path would not
        # support the '+' concatenation here)
        test_path=args.data_path + '/singlecoil_test',
        sample_rate=1,
        batch_size=batch_size,
        # may can use multiple workers here with linux?
        num_workers=0,
        # distributed_sampler expects a bool; the original passed the string
        # "ddp", which was only truthy by accident. Pass True explicitly
        # (same truthiness, clear intent).
        distributed_sampler=True,
    )
    # save data to dataloader
    dataloader_tr = data_module.train_dataloader()
    dataloader_val = data_module.val_dataloader()
    dataloader_test = data_module.test_dataloader()

    return dataloader_tr, dataloader_val, dataloader_test, device
예제 #10
0
def get_fastmri_data_module(azure_dataset_id: str,
                            local_dataset: Optional[Path],
                            sample_rate: Optional[float],
                            test_path: str) -> LightningDataModule:
    """
    Builds a LightningDataModule over the FastMRI challenge data. Whether the
    single- or multi-coil challenge applies is inferred from the dataset name
    in Azure blob storage; the k-space mask is equispaced with 4x acceleration.
    :param azure_dataset_id: The name of the dataset (folder name in blob storage).
    :param local_dataset: The local folder at which the dataset has been mounted or downloaded.
    :param sample_rate: Fraction of slices of the training data split to use. Set to a value <1.0 for rapid prototyping.
    :param test_path: The name of the folder inside the dataset that contains the test data.
    :return: A LightningDataModule object.
    """
    if not azure_dataset_id:
        raise ValueError("The azure_dataset_id argument must be provided.")
    if not local_dataset:
        raise ValueError("The local_dataset argument must be provided.")

    # First matching substring wins: "multicoil" is checked before "singlecoil".
    challenge = next(
        (c for c in ("multicoil", "singlecoil") if c in azure_dataset_id),
        None,
    )
    if challenge is None:
        raise ValueError(
            f"Unable to determine the value for the challenge field for this "
            f"dataset: {azure_dataset_id}")

    mask_func = create_mask_for_mask_type(mask_type_str="equispaced",
                                          center_fractions=[0.08],
                                          accelerations=[4])
    # Training gets freshly randomized masks; validation masks are seeded.
    train_transform = VarNetDataTransform(mask_func=mask_func, use_seed=False)
    val_transform = VarNetDataTransform(mask_func=mask_func)
    test_transform = VarNetDataTransform()

    return FastMriDataModule(data_path=local_dataset,
                             test_path=local_dataset / test_path,
                             challenge=challenge,
                             sample_rate=sample_rate,
                             train_transform=train_transform,
                             val_transform=val_transform,
                             test_transform=test_transform)
예제 #11
0
    parser.add_argument(
        "--out_chans",
        default=1,
        type=int,
        # fixed user-facing typo: "chanenls" -> "channels"
        help="number of output channels to U-Net",
    )
    parser.add_argument(
        "--chans",
        default=32,
        type=int,
        help="number of top-level U-Net channels",
    )
    # unet module arguments
    parser.add_argument(
        "--unet_module",
        default="unet",
        choices=("unet", "nestedunet"),
        type=str,
        help="Unet module to run with",
    )

    args = parser.parse_args()

    # Build the k-space subsampling mask from the CLI options.
    mask = create_mask_for_mask_type(args.mask_type, args.center_fractions,
                                     args.accelerations)

    run_inference(args.challenge,
                  args.state_dict_file, args.data_path, args.output_path,
                  torch.device(args.device), mask, args.in_chans,
                  args.out_chans, args.chans)
예제 #12
0
def cli_main(args):
    """Fine-tune a (Nested)U-Net whose trunk is initialised from a
    pretrained self-supervised checkpoint, then train or test it on
    fastMRI data according to ``args.mode``."""
    pl.seed_everything(args.seed)

    # ------------
    # data
    # ------------
    # this creates a k-space mask for transforming input data
    mask = create_mask_for_mask_type(args.mask_type, args.center_fractions,
                                     args.accelerations)
    # use random masks for train transform, fixed masks for val transform
    train_transform = UnetDataTransform(args.challenge,
                                        mask_func=mask,
                                        use_seed=False)
    val_transform = UnetDataTransform(args.challenge, mask_func=mask)
    test_transform = UnetDataTransform(args.challenge)
    # ptl data module - this handles data loaders
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=args.test_split,
        test_path=args.test_path,
        sample_rate=args.sample_rate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")),
    )

    # ------------
    # model
    # ------------
    # NOTE(review): "args.optmizer" below looks like a typo for "optimizer",
    # but it must match the argparse flag defined elsewhere — verify there
    # before renaming.
    model = None
    if args.unet_module == "unet":
        model = UnetModule(
            in_chans=args.in_chans,
            out_chans=args.out_chans,
            chans=int(args.chans),
            num_pool_layers=args.num_pool_layers,
            drop_prob=args.drop_prob,
            lr=args.lr,
            lr_step_size=args.lr_step_size,
            lr_gamma=args.lr_gamma,
            weight_decay=args.weight_decay,
            optimizer=args.optmizer,
        )
    elif args.unet_module == "nestedunet":
        model = NestedUnetModule(
            in_chans=args.in_chans,
            out_chans=args.out_chans,
            chans=args.chans,
            num_pool_layers=args.num_pool_layers,
            drop_prob=args.drop_prob,
            lr=args.lr,
            lr_step_size=args.lr_step_size,
            lr_gamma=args.lr_gamma,
            weight_decay=args.weight_decay,
            optimizer=args.optmizer,
        )

    if args.device == "cuda" and not torch.cuda.is_available():
        raise ValueError(
            "The requested cuda device isn't available please set --device cpu"
        )

    # Load the pretrained checkpoint and re-key its trunk weights from the
    # "classy_state_dict" layout to this model's state-dict layout.
    pretrained_dict = torch.load(args.state_dict_file,
                                 map_location=args.device)
    model_dict = model.unet.state_dict()
    if args.unet_module == "unet":
        model_dict = {
            k: pretrained_dict["classy_state_dict"]["base_model"]["model"]
            ["trunk"]["_feature_blocks.unetblock." + k]
            for k, _ in model_dict.items()
        }
    elif args.unet_module == "nestedunet":
        model_dict = {
            k: pretrained_dict["classy_state_dict"]["base_model"]["model"]
            ["trunk"]["_feature_blocks.nublock." + k]
            for k, v in model_dict.items()
        }

    model.unet.load_state_dict(model_dict)

    # ------------
    # trainer
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    # ------------
    # run
    # ------------
    output_filename = f"fine_tuned_{args.unet_module}.torch"
    output_model_filepath = f"{args.output_path}/{output_filename}"
    if args.mode == "train":
        trainer.fit(model, datamodule=data_module)
        print(f"Saving model: {output_model_filepath}")
        torch.save(model.state_dict(), output_model_filepath)
        print("DONE!")
    elif args.mode == "test":
        trainer.test(model, datamodule=data_module)
    else:
        raise ValueError(f"unrecognized mode {args.mode}")
예제 #13
0
device = 'cuda'

# NOTE(review): `kspace` and `slice` (which shadows the builtin) must come
# from an earlier cell of this notebook-style fragment — confirm upstream.

# crop_size was originally assigned AFTER T.center_crop used it, which
# raises NameError when the statements run top-to-bottom; define it first.
crop_size = (320, 320)

# Ground truth: inverse FFT of the fully-sampled k-space slice, normalized
# and root-sum-of-squares combined over the leading (coil) dimension.
target = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kspace[slice], axes=(-2, -1)), axes=(-2, -1)), axes=(-2, -1))
target = target / np.max(np.abs(target))
target = np.sqrt(np.sum(T.center_crop(target, crop_size) ** 2, 0))

# Random 4x undersampling mask with 8% of center lines retained.
mask_func = create_mask_for_mask_type(mask_type_str="random", center_fractions=[0.08], accelerations=[4])

_kspace = T.to_tensor(kspace)[slice]
masked_kspace, mask = T.apply_mask(_kspace, mask_func)

# Zero-filled linear reconstruction: rebuild the complex k-space from the
# (real, imag) channels, then inverse FFT.
linear_recon = masked_kspace[..., 0] + 1j * masked_kspace[..., 1]
linear_recon = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(linear_recon, axes=(-2, -1)), axes=(-2, -1)),
                               axes=(-2, -1))
예제 #14
0
    def test_data_transform(self):
        """Return a test-time DataTransform built from the configured mask."""
        mask_func = create_mask_for_mask_type(self.mask_type,
                                              self.center_fractions,
                                              self.accelerations)
        return DataTransform(mask_func)
예제 #15
0
        default="random",
        type=str,
        help="Type of k-space mask",
    )
    parser.add_argument(
        "--center_fractions",
        nargs="+",
        # values are fractions of k-space width (e.g. 0.08), despite the
        # help text's wording
        default=[0.08],
        type=float,
        help="Number of center lines to use in mask",
    )
    parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4],
        type=int,
        help="Acceleration rates to use for masks",
    )

    args = parser.parse_args()

    # A subsampling mask is built only for the 'val' set; for other sets
    # None is passed — presumably that data is already undersampled, TODO
    # confirm against run_inference.
    run_inference(
        args.challenge,
        args.state_dict_file,
        args.data_path,
        args.output_path,
        create_mask_for_mask_type(args.mask_type, args.center_fractions,
                                  args.accelerations) if args.set == 'val' else None,
        args.use_sens_net,
        torch.device(args.device),
    )