def build_args(): # ------------------------ # TRAINING ARGUMENTS # ------------------------ path_config = pathlib.Path.cwd() / ".." / ".." / "fastmri_dirs.yaml" knee_path = fetch_dir("knee_path", path_config) logdir = fetch_dir("log_path", path_config) / "unet" / "unet_knee_sc_leaderboard" parent_parser = ArgumentParser(add_help=False) parser = UnetModule.add_model_specific_args(parent_parser) parser = Trainer.add_argparse_args(parser) num_gpus = 32 # this was the number of GPUs for training backend = "ddp" batch_size = 1 if backend == "ddp" else num_gpus # module config config config = dict( in_chans=1, out_chans=1, chans=256, num_pool_layers=4, drop_prob=0.0, mask_type="random", center_fractions=[0.08, 0.04], accelerations=[4, 8], lr=0.001, lr_step_size=40, lr_gamma=0.1, weight_decay=0.0, data_path=knee_path, challenge="singlecoil", exp_dir=logdir, exp_name="unet_leaderboard", test_split="test", batch_size=batch_size, ) parser.set_defaults(**config) # trainer config parser.set_defaults( default_root_dir=logdir, replace_sampler_ddp=(backend != "ddp"), distributed_backend=backend, gpus=num_gpus, max_epochs=50, seed=42, deterministic=True, ) parser.add_argument("--mode", default="train", type=str) args = parser.parse_args() return args
def build_args(): # ------------------------ # TRAINING ARGUMENTS # ------------------------ path_config = pathlib.Path.cwd() / ".." / ".." / "fastmri_dirs.yaml" knee_path = fetch_dir("knee_path", path_config) brain_path = fetch_dir("brain_path", path_config) logdir = fetch_dir("log_path", path_config) / "unet" / "overfit-on-batch" parent_parser = ArgumentParser(add_help=False) parser = UnetModule.add_model_specific_args(parent_parser) parser = Trainer.add_argparse_args(parser) num_gpus = 1 backend = "ddp" batch_size = 1 if backend == "ddp" else num_gpus # module config config = dict( in_chans=4, out_chans=4, chans=32, num_pool_layers=4, drop_prob=0.0, mask_type="equispaced", center_fractions=[0.08, 0.04], accelerations=[4, 8], lr=0.001, lr_step_size=40, lr_gamma=0.1, weight_decay=0.0, data_path=brain_path, challenge="multicoil", exp_dir=logdir, exp_name="unet_demo", test_split="test", batch_size=batch_size, ) parser.set_defaults(**config) # trainer config parser.set_defaults(gpus=num_gpus, default_root_dir=logdir, replace_sampler_ddp=(backend != "ddp"), distributed_backend=backend, seed=42, deterministic=True, fast_dev_run=True) parser.add_argument("--mode", default="train", type=str) args = parser.parse_args() return args
def main(args): """Main training routine.""" # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ seed_everything(args.seed) model = UnetModule(**vars(args)) # ------------------------ # 2 INIT TRAINER # ------------------------ trainer = Trainer.from_argparse_args(args) # ------------------------ # 3 START TRAINING OR TEST # ------------------------ if args.mode == "train": trainer.fit(model) elif args.mode == "test": assert args.resume_from_checkpoint is not None trainer.test(model) else: raise ValueError(f"unrecognized mode {args.mode}")