def __init__(self) -> None:
    data_path = Path.cwd() / "data"
    if data_path.is_dir():
        shutil.rmtree(str(data_path))
    data_path.mkdir(exist_ok=False, parents=True)
    _, _, metadata = create_temp_data(data_path)

    def retrieve_metadata_mock(a: Any, fname: Any) -> Any:
        return metadata[str(fname)]

    # This is a bit flaky: the patch should be undone afterwards, but there is
    # no obvious place to do so.
    MonkeyPatch().setattr(SliceDataset, "_retrieve_metadata", retrieve_metadata_mock)
    mask = create_mask_for_mask_type(
        mask_type_str="equispaced", center_fractions=[0.08], accelerations=[4]
    )
    # use random masks for train transform, fixed masks for val transform
    train_transform = VarNetDataTransform(mask_func=mask, use_seed=False)
    val_transform = VarNetDataTransform(mask_func=mask)
    test_transform = VarNetDataTransform()
    FastMriDataModule.__init__(
        self,
        data_path=data_path / "knee_data",
        challenge="multicoil",
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
    )
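# A sketch of one way to make the monkeypatch above undoable (an assumption,
# not from the original code): keep the MonkeyPatch instance on the data module
# and undo it in the LightningDataModule.teardown hook. The helper name
# _patch_metadata is hypothetical.
def _patch_metadata(self, retrieve_metadata_mock) -> None:
    self._metadata_patch = MonkeyPatch()
    self._metadata_patch.setattr(
        SliceDataset, "_retrieve_metadata", retrieve_metadata_mock
    )

def teardown(self, stage=None) -> None:
    # undo the metadata monkeypatch when the data module is torn down
    self._metadata_patch.undo()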
def cli_main(args):
    pl.seed_everything(args.seed)

    # ------------
    # data
    # ------------
    # this creates a k-space mask for transforming input data
    mask = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )
    # use random masks for train transform, fixed masks for val transform
    train_transform = UnetDataTransform(args.challenge, mask_func=mask, use_seed=False)
    val_transform = UnetDataTransform(args.challenge, mask_func=mask)
    test_transform = UnetDataTransform(args.challenge)
    # ptl data module - this handles data loaders
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=args.test_split,
        test_path=args.test_path,
        sample_rate=args.sample_rate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")),
        proportion=args.proportion,
    )

    # ------------
    # model
    # ------------
    model = UnetModule(
        in_chans=args.in_chans,
        out_chans=args.out_chans,
        chans=args.chans,
        num_pool_layers=args.num_pool_layers,
        drop_prob=args.drop_prob,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_gamma=args.lr_gamma,
        weight_decay=args.weight_decay,
    )

    # ------------
    # trainer
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    # ------------
    # run
    # ------------
    if args.mode == "train":
        trainer.fit(model, datamodule=data_module)
    elif args.mode == "test":
        trainer.test(model, datamodule=data_module)
    else:
        raise ValueError(f"unrecognized mode {args.mode}")
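# A minimal entry-point sketch for cli_main above, assuming a build_args()
# helper like the variants later in this section; run_cli is a hypothetical
# name, not from the original script:
def run_cli():
    args = build_args()
    cli_main(args)


if __name__ == "__main__":
    run_cli()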
def __init__(self, in_chans, out_chans, dropout, decoder_channels, lr,
             lr_step_size, lr_gamma, weight_decay, data_path, batch_size,
             mask_type, center_fractions, accelerations, optim_eps):
    super().__init__()
    self.save_hyperparameters()
    self.in_chans = in_chans
    self.out_chans = out_chans
    self.decoder_channels = decoder_channels
    self.lr = lr
    self.lr_step_size = lr_step_size
    self.lr_gamma = lr_gamma
    self.weight_decay = weight_decay
    self.optim_eps = optim_eps
    self.net = ENet(
        in_channels=in_chans,
        out_channels=out_chans,
        decoder_channels=decoder_channels,
        dropout=dropout,
    )
    mask = create_mask_for_mask_type(mask_type, center_fractions, accelerations)
    # use random masks for train transform, fixed masks for val transform
    train_transform = UnetDataTransform('singlecoil', mask_func=mask, use_seed=False)
    val_transform = UnetDataTransform('singlecoil', mask_func=mask)
    test_transform = UnetDataTransform('singlecoil')
    self.data_module = FastMriDataModule(
        data_path=pathlib.Path(data_path),
        challenge='singlecoil',
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split='test',
        test_path=None,
        sample_rate=1.0,
        batch_size=batch_size,
        num_workers=4,
        distributed_sampler=False,
    )
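# A minimal sketch (an assumption, not from the original module) of how the
# LightningModule above could expose the loaders of its embedded data module:
def train_dataloader(self):
    return self.data_module.train_dataloader()

def val_dataloader(self):
    return self.data_module.val_dataloader()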
def test_unet_trainer(fastmri_mock_dataset, backend, tmp_path, monkeypatch):
    knee_path, _, metadata = fastmri_mock_dataset

    def retrieve_metadata_mock(a, fname):
        return metadata[str(fname)]

    monkeypatch.setattr(SliceDataset, "_retrieve_metadata", retrieve_metadata_mock)
    params = build_unet_args(knee_path, tmp_path, backend)
    params.fast_dev_run = True
    params.backend = backend
    mask = create_mask_for_mask_type(
        params.mask_type, params.center_fractions, params.accelerations
    )
    train_transform = UnetDataTransform(params.challenge, mask_func=mask, use_seed=False)
    val_transform = UnetDataTransform(params.challenge, mask_func=mask)
    test_transform = UnetDataTransform(params.challenge)
    data_module = FastMriDataModule(
        data_path=params.data_path,
        challenge=params.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=params.test_split,
        sample_rate=params.sample_rate,
        batch_size=params.batch_size,
        num_workers=params.num_workers,
        distributed_sampler=(params.accelerator == "ddp"),
        use_dataset_cache_file=False,
    )
    model = UnetModule(
        in_chans=params.in_chans,
        out_chans=params.out_chans,
        chans=params.chans,
        num_pool_layers=params.num_pool_layers,
        drop_prob=params.drop_prob,
        lr=params.lr,
        lr_step_size=params.lr_step_size,
        lr_gamma=params.lr_gamma,
        weight_decay=params.weight_decay,
    )
    trainer = Trainer.from_argparse_args(params)
    trainer.fit(model, data_module)
def get_dataloaders_fastmri(mask_type='random', center_fractions=[0.08],
                            accelerations=[4], challenge='singlecoil',
                            batch_size=8, num_workers=4, distributed_bool=False,
                            dataset_dir=dataset_dir, mri_dir='fastmri/knee/',
                            worker_init_fn=None, include_test=False, **kwargs):
    data_path = Path(os.path.join(dataset_dir, mri_dir))
    mask = create_mask_for_mask_type(
        mask_type_str=mask_type,
        center_fractions=center_fractions,
        accelerations=accelerations,
    )
    # use random masks for train transform, fixed masks for val transform
    train_transform = UnetDataTransform(challenge, mask_func=mask, use_seed=False)
    val_transform = UnetDataTransform(challenge, mask_func=mask)
    test_transform = UnetDataTransform(challenge)
    # ptl data module - this handles data loaders
    data_module = FastMriDataModule(
        data_path=data_path,
        challenge=challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        batch_size=batch_size,
        num_workers=num_workers,
        distributed_sampler=distributed_bool,
    )
    dataloaders = {
        'train': data_module.train_dataloader(),
        'validation': data_module.val_dataloader(),
    }
    if include_test:
        dataloaders['test'] = data_module.test_dataloader()
    return dataloaders
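# Example usage of get_dataloaders_fastmri above -- a minimal sketch; the
# equispaced mask and batch size are illustrative choices, and the default
# dataset_dir / mri_dir layout is assumed to exist:
dataloaders = get_dataloaders_fastmri(
    mask_type='equispaced',
    accelerations=[4],
    batch_size=4,
)
for batch in dataloaders['train']:
    break  # pull a single batch to sanity-check shapes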
def get_fastmri_data_module(azure_dataset_id: str,
                            local_dataset: Optional[Path],
                            sample_rate: Optional[float],
                            test_path: str) -> LightningDataModule:
    """
    Creates a LightningDataModule that consumes data from the FastMRI challenge.
    The type of challenge (single/multicoil) is determined from the name of the
    dataset in Azure blob storage. The mask type is set to equispaced, with 4x
    acceleration.

    :param azure_dataset_id: The name of the dataset (folder name in blob storage).
    :param local_dataset: The local folder at which the dataset has been mounted
        or downloaded.
    :param sample_rate: Fraction of slices of the training data split to use.
        Set to a value <1.0 for rapid prototyping.
    :param test_path: The name of the folder inside the dataset that contains
        the test data.
    :return: A LightningDataModule object.
    """
    if not azure_dataset_id:
        raise ValueError("The azure_dataset_id argument must be provided.")
    if not local_dataset:
        raise ValueError("The local_dataset argument must be provided.")
    for challenge in ["multicoil", "singlecoil"]:
        if challenge in azure_dataset_id:
            break
    else:
        raise ValueError(
            f"Unable to determine the value for the challenge field for this "
            f"dataset: {azure_dataset_id}"
        )
    mask = create_mask_for_mask_type(
        mask_type_str="equispaced", center_fractions=[0.08], accelerations=[4]
    )
    # use random masks for train transform, fixed masks for val transform
    train_transform = VarNetDataTransform(mask_func=mask, use_seed=False)
    val_transform = VarNetDataTransform(mask_func=mask)
    test_transform = VarNetDataTransform()
    return FastMriDataModule(
        data_path=local_dataset,
        test_path=local_dataset / test_path,
        challenge=challenge,
        sample_rate=sample_rate,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
    )
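# Example call for get_fastmri_data_module above (a sketch; the dataset name
# and local path are hypothetical placeholders):
data_module = get_fastmri_data_module(
    azure_dataset_id="knee_multicoil",
    local_dataset=Path("/datasets/knee_multicoil"),
    sample_rate=0.1,  # use 10% of training slices for rapid prototyping
    test_path="multicoil_test",
)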
def build_args(): parser = ArgumentParser() # basic args path_config = pathlib.Path("../../fastmri_dirs.yaml") backend = "ddp" num_gpus = 2 if backend == "ddp" else 1 batch_size = 1 # set defaults based on optional directory config data_path = fetch_dir("knee_path", path_config) default_root_dir = fetch_dir("log_path", path_config) / "varnet" / "varnet_demo" # client arguments parser.add_argument( "--mode", default="train", choices=("train", "test"), type=str, help="Operation mode", ) # data transform params parser.add_argument( "--mask_type", choices=("random", "equispaced"), default="equispaced", type=str, help="Type of k-space mask", ) parser.add_argument( "--center_fractions", nargs="+", default=[0.08], type=float, help="Number of center lines to use in mask", ) parser.add_argument( "--accelerations", nargs="+", default=[4], type=int, help="Acceleration rates to use for masks", ) # data config parser = FastMriDataModule.add_data_specific_args(parser) parser.set_defaults( data_path=data_path, # path to fastMRI data mask_type="equispaced", # VarNet uses equispaced mask challenge="multicoil", # only multicoil implemented for VarNet batch_size=batch_size, # number of samples per batch test_path=None, # path for test split, overwrites data_path ) # module config parser = VarNetModule.add_model_specific_args(parser) parser.set_defaults( num_cascades=8, # number of unrolled iterations pools=4, # number of pooling layers for U-Net chans=18, # number of top-level channels for U-Net sens_pools=4, # number of pooling layers for sense est. U-Net sens_chans=8, # number of top-level channels for sense est. U-Net lr=0.001, # Adam learning rate lr_step_size=40, # epoch at which to decrease learning rate lr_gamma=0.1, # extent to which to decrease learning rate weight_decay=0.0, # weight regularization strength ) # trainer config parser = pl.Trainer.add_argparse_args(parser) parser.set_defaults( gpus=num_gpus, # number of gpus to use replace_sampler_ddp= False, # this is necessary for volume dispatch during val accelerator=backend, # what distributed version to use seed=42, # random seed deterministic=True, # makes things slower, but deterministic default_root_dir=default_root_dir, # directory for logs and checkpoints max_epochs=50, # max number of epochs ) args = parser.parse_args() # configure checkpointing in checkpoint_dir checkpoint_dir = args.default_root_dir / "checkpoints" if not checkpoint_dir.exists(): checkpoint_dir.mkdir(parents=True) args.checkpoint_callback = pl.callbacks.ModelCheckpoint( filepath=args.default_root_dir / "checkpoints", save_top_k=True, verbose=True, monitor="validation_loss", mode="min", prefix="", ) # set default checkpoint if one exists in our checkpoint directory if args.resume_from_checkpoint is None: ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) if ckpt_list: args.resume_from_checkpoint = str(ckpt_list[-1]) return args
def build_varnet_args(data_path, logdir, backend):
    parser = ArgumentParser()

    num_gpus = 0
    batch_size = 1

    # data transform params
    parser.add_argument(
        "--mask_type",
        choices=("random", "equispaced"),
        default="equispaced",
        type=str,
        help="Type of k-space mask",
    )
    parser.add_argument(
        "--center_fractions",
        nargs="+",
        default=[0.08],
        type=float,
        help="Number of center lines to use in mask",
    )
    parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4],
        type=int,
        help="Acceleration rates to use for masks",
    )

    # data config
    parser = FastMriDataModule.add_data_specific_args(parser)
    parser.set_defaults(
        data_path=data_path,
        mask_type="equispaced",
        challenge="multicoil",
        batch_size=batch_size,
    )

    # module config
    parser = VarNetModule.add_model_specific_args(parser)
    parser.set_defaults(
        num_cascades=4,
        pools=2,
        chans=8,
        sens_pools=2,
        sens_chans=4,
        lr=0.001,
        lr_step_size=40,
        lr_gamma=0.1,
        weight_decay=0.0,
    )

    # trainer config
    parser = Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=num_gpus,
        default_root_dir=logdir,
        replace_sampler_ddp=(backend != "ddp"),
        accelerator=backend,
    )

    parser.add_argument("--mode", default="train", type=str)

    args = parser.parse_args([])

    return args
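# Example usage of build_varnet_args above in a quick smoke test (a sketch
# mirroring test_unet_trainer earlier in this section; the paths are
# placeholders and fast_dev_run is an assumption):
params = build_varnet_args(Path("data/knee"), Path("logs"), "ddp_cpu")
params.fast_dev_run = True
trainer = Trainer.from_argparse_args(params)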
def build_args(): parser = ArgumentParser() # basic args path_config = pathlib.Path("../../fastmri_dirs.yaml") num_gpus = 0 backend = "ddp_cpu" batch_size = 1 if backend == "ddp_cpu" else num_gpus # set defaults based on optional directory config data_path = fetch_dir("knee_path", path_config) default_root_dir = fetch_dir("log_path", path_config) / "unet" / "unet_demo" # client arguments parser.add_argument( "--mode", default="train", choices=("train", "test"), type=str, help="Operation mode", ) # data transform params parser.add_argument( "--mask_type", choices=("random", "equispaced"), default="random", type=str, help="Type of k-space mask", ) parser.add_argument( "--center_fractions", nargs="+", default=[0.08], type=float, help="Number of center lines to use in mask", ) parser.add_argument( "--proportion", default=0.1, type=float, help="Proportion of label data", ) parser.add_argument( "--accelerations", nargs="+", default=[4], type=int, help="Acceleration rates to use for masks", ) # data config with path to fastMRI data and batch size parser = FastMriDataModule.add_data_specific_args(parser) parser.set_defaults(data_path=data_path, batch_size=batch_size, test_path=None) # module config parser = UnetModule.add_model_specific_args(parser) parser.set_defaults( in_chans=1, # number of input channels to U-Net out_chans=1, # number of output chanenls to U-Net chans=32, # number of top-level U-Net channels num_pool_layers=4, # number of U-Net pooling layers drop_prob=0.0, # dropout probability lr=0.001, # RMSProp learning rate lr_step_size=40, # epoch at which to decrease learning rate lr_gamma=0.1, # extent to which to decrease learning rate weight_decay=0.0, # weight decay regularization strength ) # trainer config parser = pl.Trainer.add_argparse_args(parser) parser.set_defaults( gpus=num_gpus, # number of gpus to use replace_sampler_ddp= False, # this is necessary for volume dispatch during val accelerator=backend, # what distributed version to use seed=42, # random seed deterministic=True, # makes things slower, but deterministic default_root_dir=default_root_dir, # directory for logs and checkpoints max_epochs=50, # max number of epochs ) args = parser.parse_args() # configure checkpointing in checkpoint_dir checkpoint_dir = args.default_root_dir / "checkpoints" if not checkpoint_dir.exists(): checkpoint_dir.mkdir(parents=True) args.checkpoint_callback = pl.callbacks.ModelCheckpoint( dirpath=args.default_root_dir / "checkpoints", save_top_k=True, verbose=True, monitor="validation_loss", mode="min", prefix="", ) # set default checkpoint if one exists in our checkpoint directory if args.resume_from_checkpoint is None: ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) if ckpt_list: args.resume_from_checkpoint = str(ckpt_list[-1]) return args
def build_args(): parser = ArgumentParser() # basic args path_config = pathlib.Path("../../fastmri_dirs.yaml") batch_size = 1 if backend == "ddp" else num_gpus # set defaults based on optional directory config data_path = fetch_dir("knee_path", path_config) default_root_dir = fetch_dir("log_path", path_config) / "nnret" / "nnret_demo" # client arguments parser.add_argument( "--mode", default="train", choices=("train", "test"), type=str, help="Operation mode", ) # data transform params parser.add_argument( "--mask_type", choices=("random", "equispaced"), default="random", type=str, help="Type of k-space mask", ) parser.add_argument( "--center_fractions", nargs="+", default=[0.08], type=float, help="Number of center lines to use in mask", ) parser.add_argument( "--accelerations", nargs="+", default=[4], type=int, help="Acceleration rates to use for masks", ) # data config with path to fastMRI data and batch size parser = FastMriDataModule.add_data_specific_args(parser) parser.set_defaults(data_path=data_path, batch_size=batch_size, test_path=None) # module config parser = NnRetModule.add_model_specific_args(parser) parser.set_defaults( in_chans=1, # number of input channels to NNRET out_chans=1, # number of output chanenls to NNRET chans=32, # number of top-level NNRET channels num_pool_layers=4, # number of NNRET pooling layers drop_prob=0.0, # dropout probability lr=0.001, # RMSProp learning rate lr_step_size=40, # epoch at which to decrease learning rate lr_gamma=0.1, # extent to which to decrease learning rate weight_decay=0.0, # weight decay regularization strength ) # trainer config parser = pl.Trainer.add_argparse_args(parser) parser.set_defaults( gpus=num_gpus, # number of gpus to use replace_sampler_ddp=False, # this is necessary for volume dispatch during val accelerator=backend, # what distributed version to use seed=42, # random seed deterministic=True, # makes things slower, but deterministic default_root_dir=default_root_dir, # directory for logs and checkpoints max_epochs=1, # max number of epochs ) args = parser.parse_args() return args
def cli_main(args): pl.seed_everything(args.seed) # ------------ # data # ------------ # this creates a k-space mask for transforming input data mask = create_mask_for_mask_type(args.mask_type, args.center_fractions, args.accelerations) # use random masks for train transform, fixed masks for val transform train_transform = UnetDataTransform(args.challenge, mask_func=mask, use_seed=False) val_transform = UnetDataTransform(args.challenge, mask_func=mask) test_transform = UnetDataTransform(args.challenge) # ptl data module - this handles data loaders data_module = FastMriDataModule( data_path=args.data_path, challenge=args.challenge, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, test_split=args.test_split, test_path=args.test_path, sample_rate=args.sample_rate, batch_size=args.batch_size, num_workers=args.num_workers, distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")), ) # ------------ # model # ------------ model = None if args.unet_module == "unet": model = UnetModule( in_chans=args.in_chans, out_chans=args.out_chans, chans=int(args.chans), num_pool_layers=args.num_pool_layers, drop_prob=args.drop_prob, lr=args.lr, lr_step_size=args.lr_step_size, lr_gamma=args.lr_gamma, weight_decay=args.weight_decay, optimizer=args.optmizer, ) elif args.unet_module == "nestedunet": model = NestedUnetModule( in_chans=args.in_chans, out_chans=args.out_chans, chans=args.chans, num_pool_layers=args.num_pool_layers, drop_prob=args.drop_prob, lr=args.lr, lr_step_size=args.lr_step_size, lr_gamma=args.lr_gamma, weight_decay=args.weight_decay, optimizer=args.optmizer, ) if args.device == "cuda" and not torch.cuda.is_available(): raise ValueError( "The requested cuda device isn't available please set --device cpu" ) pretrained_dict = torch.load(args.state_dict_file, map_location=args.device) model_dict = model.unet.state_dict() if args.unet_module == "unet": model_dict = { k: pretrained_dict["classy_state_dict"]["base_model"]["model"] ["trunk"]["_feature_blocks.unetblock." + k] for k, _ in model_dict.items() } elif args.unet_module == "nestedunet": model_dict = { k: pretrained_dict["classy_state_dict"]["base_model"]["model"] ["trunk"]["_feature_blocks.nublock." + k] for k, v in model_dict.items() } model.unet.load_state_dict(model_dict) # ------------ # trainer # ------------ trainer = pl.Trainer.from_argparse_args(args) # ------------ # run # ------------ output_filename = f"fine_tuned_{args.unet_module}.torch" output_model_filepath = f"{args.output_path}/{output_filename}" if args.mode == "train": trainer.fit(model, datamodule=data_module) print(f"Saving model: {output_model_filepath}") torch.save(model.state_dict(), output_model_filepath) print("DONE!") elif args.mode == "test": trainer.test(model, datamodule=data_module) else: raise ValueError(f"unrecognized mode {args.mode}")
def build_args():
    parser = ArgumentParser()

    batch_size = 1

    # client arguments
    parser.add_argument(
        "--mode",
        default="train",
        choices=("train", "test"),
        type=str,
        help="Operation mode",
    )

    # unet module arguments
    parser.add_argument(
        "--unet_module",
        default="unet",
        choices=("unet", "nestedunet"),
        type=str,
        help="Unet module to run with",
    )

    # data transform params
    parser.add_argument(
        "--mask_type",
        choices=("random", "equispaced"),
        default="random",
        type=str,
        help="Type of k-space mask",
    )
    parser.add_argument(
        "--center_fractions",
        nargs="+",
        default=[0.08],
        type=float,
        help="Number of center lines to use in mask",
    )
    parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4],
        type=int,
        help="Acceleration rates to use for masks",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        type=str,
        help="Device to run on",
    )
    parser.add_argument(
        "--state_dict_file",
        default=None,
        type=Path,
        help="Path to saved state_dict (will download if not provided)",
    )
    parser.add_argument(
        "--output_path",
        type=Path,
        default=Path("./fine_tuning"),  # directory for logs and checkpoints
        help="Path for saving reconstructions",
    )

    # unet specific
    parser.add_argument(
        "--in_chans",
        default=1,
        type=int,
        help="number of input channels to U-Net",
    )
    parser.add_argument(
        "--out_chans",
        default=1,
        type=int,
        help="number of output channels to U-Net",
    )
    parser.add_argument(
        "--chans",
        default=32,
        type=int,
        help="number of top-level U-Net channels",
    )

    # RMSProp parameters
    parser.add_argument(
        "--opt_drop_prob",
        default=0.0,
        type=float,
        help="dropout probability",
    )
    parser.add_argument(
        "--opt_lr",
        default=0.001,
        type=float,
        help="RMSProp learning rate",
    )
    parser.add_argument(
        "--opt_lr_step_size",
        default=10,
        type=int,
        help="epoch at which to decrease learning rate",
    )
    parser.add_argument(
        "--opt_lr_gamma",
        default=0.1,
        type=float,
        help="extent to which to decrease learning rate",
    )
    parser.add_argument(
        "--opt_weight_decay",
        default=0.0,
        type=float,
        help="weight decay regularization strength",
    )
    parser.add_argument(
        "--opt_optimizer",
        choices=("RMSprop", "Adam"),
        default="RMSprop",
        type=str,
        help="optimizer (RMSprop, Adam)",
    )

    # data config with path to fastMRI data and batch size
    parser = FastMriDataModule.add_data_specific_args(parser)
    parser.set_defaults(
        data_path="/home/ec2-user/mri", batch_size=batch_size, test_path=None
    )

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=0,  # number of gpus to use
        replace_sampler_ddp=False,  # this is necessary for volume dispatch during val
        seed=42,  # random seed
        deterministic=True,  # makes things slower, but deterministic
        max_epochs=50,  # max number of epochs
        unet_module="unet",  # "unet" or "nestedunet"
    )

    args = parser.parse_args()

    # module config
    if args.unet_module == "unet":
        parser = UnetModule.add_model_specific_args(parser)
        parser.set_defaults(
            num_pool_layers=4,  # number of U-Net pooling layers
            drop_prob=args.opt_drop_prob,  # dropout probability
            lr=args.opt_lr,  # RMSProp learning rate
            lr_step_size=args.opt_lr_step_size,  # epoch at which to decrease learning rate
            lr_gamma=args.opt_lr_gamma,  # extent to which to decrease learning rate
            weight_decay=args.opt_weight_decay,  # weight decay regularization strength
            optimizer=args.opt_optimizer,  # optimizer (RMSprop, Adam)
            accelerator="ddp_cpu" if args.device == "cpu" else "ddp",
        )
    elif args.unet_module == "nestedunet":
        parser = NestedUnetModule.add_model_specific_args(parser)
        parser.set_defaults(
            num_pool_layers=4,  # number of U-Net pooling layers
            drop_prob=args.opt_drop_prob,  # dropout probability
            lr=args.opt_lr,  # RMSProp learning rate
            lr_step_size=args.opt_lr_step_size,  # epoch at which to decrease learning rate
            lr_gamma=args.opt_lr_gamma,  # extent to which to decrease learning rate
            weight_decay=args.opt_weight_decay,  # weight decay regularization strength
            optimizer=args.opt_optimizer,  # optimizer (RMSprop, Adam)
            accelerator="ddp_cpu" if args.device == "cpu" else "ddp",
        )

    # re-parse so the module-specific defaults above are applied
    args = parser.parse_args()

    # configure checkpointing in checkpoint_dir
    checkpoint_dir = args.output_path / "checkpoints"
    if not checkpoint_dir.exists():
        checkpoint_dir.mkdir(parents=True)

    args.checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=args.output_path / "checkpoints",
        save_top_k=True,
        verbose=True,
        monitor="validation_loss",
        mode="min",
        prefix="",
    )

    # set default checkpoint if one exists in our checkpoint directory
    if args.resume_from_checkpoint is None:
        ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
        if ckpt_list:
            args.resume_from_checkpoint = str(ckpt_list[-1])

    return args