def cli_main(args):
    """Seed everything, build the data module, model, and trainer from
    parsed args, then run either training or testing.

    Raises:
        ValueError: if ``args.mode`` is neither ``"train"`` nor ``"test"``.
    """
    pl.seed_everything(args.seed)

    # ------------
    # data
    # ------------
    # k-space subsampling mask applied to the input data
    mask_func = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )
    # random masks while training; seeded (fixed) masks for validation
    transforms = {
        "train": VarNetDataTransform(mask_func=mask_func, use_seed=False),
        "val": VarNetDataTransform(mask_func=mask_func),
        "test": VarNetDataTransform(),
    }

    # ptl data module - this handles data loaders
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=transforms["train"],
        val_transform=transforms["val"],
        test_transform=transforms["test"],
        combine_train_val=True,
        test_split=args.test_split,
        test_path=args.test_path,
        sample_rate=args.sample_rate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=(args.accelerator == "ddp"),
    )

    # ------------
    # model
    # ------------
    model = VarNetModule(
        num_cascades=args.num_cascades,
        pools=args.pools,
        chans=args.chans,
        sens_pools=args.sens_pools,
        sens_chans=args.sens_chans,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_gamma=args.lr_gamma,
        weight_decay=args.weight_decay,
    )

    # ------------
    # trainer
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    # ------------
    # run
    # ------------
    if args.mode == "train":
        trainer.fit(model, datamodule=data_module)
    elif args.mode == "test":
        trainer.test(model, datamodule=data_module)
    else:
        raise ValueError(f"unrecognized mode {args.mode}")
def test_varnet_trainer(fastmri_mock_dataset, backend, tmp_path, monkeypatch):
    """Smoke-test a single fast_dev_run ``fit`` of VarNet on the mocked
    knee dataset, with slice metadata retrieval stubbed out."""
    knee_path, _, metadata = fastmri_mock_dataset

    # bypass real file parsing: serve the pre-computed metadata by filename
    def retrieve_metadata_mock(_self, fname):
        return metadata[str(fname)]

    monkeypatch.setattr(SliceDataset, "_retrieve_metadata", retrieve_metadata_mock)

    params = build_varnet_args(knee_path, tmp_path, backend)
    params.fast_dev_run = True
    params.backend = backend

    mask = create_mask_for_mask_type(
        params.mask_type, params.center_fractions, params.accelerations
    )
    # random masks while training; seeded (fixed) masks for validation
    transform_by_split = {
        "train": VarNetDataTransform(mask_func=mask, use_seed=False),
        "val": VarNetDataTransform(mask_func=mask),
        "test": VarNetDataTransform(),
    }

    data_module = FastMriDataModule(
        data_path=params.data_path,
        challenge=params.challenge,
        train_transform=transform_by_split["train"],
        val_transform=transform_by_split["val"],
        test_transform=transform_by_split["test"],
        test_split=params.test_split,
        sample_rate=params.sample_rate,
        batch_size=params.batch_size,
        num_workers=params.num_workers,
        distributed_sampler=(params.accelerator == "ddp"),
        use_dataset_cache_file=False,
    )

    model = VarNetModule(
        num_cascades=params.num_cascades,
        pools=params.pools,
        chans=params.chans,
        sens_pools=params.sens_pools,
        sens_chans=params.sens_chans,
        lr=params.lr,
        lr_step_size=params.lr_step_size,
        lr_gamma=params.lr_gamma,
        weight_decay=params.weight_decay,
    )

    trainer = Trainer.from_argparse_args(params)
    trainer.fit(model, data_module)
def build_args():
    """Parse command-line arguments for VarNet training.

    Defaults come from an optional ``fastmri_dirs.yaml`` directory config.
    Also wires up checkpointing under ``default_root_dir / "checkpoints"``
    and auto-resumes from the newest checkpoint when one exists.

    Returns:
        argparse.Namespace: the fully-resolved arguments.
    """
    parser = ArgumentParser()

    # basic args
    path_config = pathlib.Path("../../fastmri_dirs.yaml")
    backend = "ddp"
    num_gpus = 2 if backend == "ddp" else 1
    batch_size = 1

    # set defaults based on optional directory config
    data_path = fetch_dir("knee_path", path_config)
    default_root_dir = fetch_dir("log_path", path_config) / "varnet" / "varnet_demo"

    # client arguments
    parser.add_argument(
        "--mode",
        default="train",
        choices=("train", "test"),
        type=str,
        help="Operation mode",
    )

    # data transform params
    parser.add_argument(
        "--mask_type",
        choices=("random", "equispaced"),
        default="equispaced",
        type=str,
        help="Type of k-space mask",
    )
    parser.add_argument(
        "--center_fractions",
        nargs="+",
        default=[0.08],
        type=float,
        help="Number of center lines to use in mask",
    )
    parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4],
        type=int,
        help="Acceleration rates to use for masks",
    )

    # data config
    parser = FastMriDataModule.add_data_specific_args(parser)
    parser.set_defaults(
        data_path=data_path,  # path to fastMRI data
        mask_type="equispaced",  # VarNet uses equispaced mask
        challenge="multicoil",  # only multicoil implemented for VarNet
        batch_size=batch_size,  # number of samples per batch
        test_path=None,  # path for test split, overwrites data_path
    )

    # module config
    parser = VarNetModule.add_model_specific_args(parser)
    parser.set_defaults(
        num_cascades=8,  # number of unrolled iterations
        pools=4,  # number of pooling layers for U-Net
        chans=18,  # number of top-level channels for U-Net
        sens_pools=4,  # number of pooling layers for sense est. U-Net
        sens_chans=8,  # number of top-level channels for sense est. U-Net
        lr=0.001,  # Adam learning rate
        lr_step_size=40,  # epoch at which to decrease learning rate
        lr_gamma=0.1,  # extent to which to decrease learning rate
        weight_decay=0.0,  # weight regularization strength
    )

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=num_gpus,  # number of gpus to use
        replace_sampler_ddp=False,  # this is necessary for volume dispatch during val
        accelerator=backend,  # what distributed version to use
        seed=42,  # random seed
        deterministic=True,  # makes things slower, but deterministic
        default_root_dir=default_root_dir,  # directory for logs and checkpoints
        max_epochs=50,  # max number of epochs
    )

    args = parser.parse_args()

    # configure checkpointing in checkpoint_dir
    checkpoint_dir = args.default_root_dir / "checkpoints"
    # race-free creation: avoids the exists()/mkdir() TOCTOU window of the
    # previous `if not checkpoint_dir.exists(): checkpoint_dir.mkdir(...)`
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    args.checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=args.default_root_dir / "checkpoints",
        save_top_k=1,  # ModelCheckpoint expects an int; bool True behaved as 1
        verbose=True,
        monitor="validation_loss",
        mode="min",
        prefix="",
    )

    # set default checkpoint if one exists in our checkpoint directory
    if args.resume_from_checkpoint is None:
        # newest-by-mtime checkpoint wins
        ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
        if ckpt_list:
            args.resume_from_checkpoint = str(ckpt_list[-1])

    return args
def build_varnet_args(data_path, logdir, backend):
    """Build an argparse Namespace of test-sized VarNet hyperparameters.

    Parses an empty argv list, so only the defaults set here (and by the
    data/module/trainer argument registrars) end up in the namespace.
    """
    arg_parser = ArgumentParser()

    gpu_count = 0
    per_batch = 1

    # data transform params
    arg_parser.add_argument(
        "--mask_type",
        choices=("random", "equispaced"),
        default="equispaced",
        type=str,
        help="Type of k-space mask",
    )
    arg_parser.add_argument(
        "--center_fractions",
        nargs="+",
        default=[0.08],
        type=float,
        help="Number of center lines to use in mask",
    )
    arg_parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4],
        type=int,
        help="Acceleration rates to use for masks",
    )

    # data config
    arg_parser = FastMriDataModule.add_data_specific_args(arg_parser)
    arg_parser.set_defaults(
        data_path=data_path,
        mask_type="equispaced",
        challenge="multicoil",
        batch_size=per_batch,
    )

    # module config: deliberately small model so tests run quickly
    arg_parser = VarNetModule.add_model_specific_args(arg_parser)
    arg_parser.set_defaults(
        num_cascades=4,
        pools=2,
        chans=8,
        sens_pools=2,
        sens_chans=4,
        lr=0.001,
        lr_step_size=40,
        lr_gamma=0.1,
        weight_decay=0.0,
    )

    # trainer config
    arg_parser = Trainer.add_argparse_args(arg_parser)
    arg_parser.set_defaults(
        gpus=gpu_count,
        default_root_dir=logdir,
        replace_sampler_ddp=(backend != "ddp"),
        accelerator=backend,
    )

    arg_parser.add_argument("--mode", default="train", type=str)

    return arg_parser.parse_args([])