def train_k2i(args):
    """Set up and run training for the 'W2I' method (weighted k-space to real-valued image).

    Creates the checkpoint directory ``<ckpt_root>/W2I/<run_name>`` and log directory
    ``<log_root>/W2I/<run_name>``, picks the device, records run metadata on ``args``,
    dumps ``vars(args)`` to JSON, then builds the data pipeline, a ``UNetSkipGN`` model,
    an Adam optimizer with a ``MultiStepLR`` schedule, and runs
    ``ModelTrainerK2I.train_model()``.

    Args:
        args: Parsed configuration namespace. Mutated in place: ``run_number``,
            ``run_name``, ``ckpt_path``, ``log_path`` and ``device`` are added.
    """
    # Maybe move this to args later.
    train_method = 'W2I'  # Weighted K-space to real-valued image.

    # Checkpoint directory. parents=True allows ckpt_root to be a nested path whose
    # parent directories do not exist yet (a plain mkdir raises FileNotFoundError).
    ckpt_path = Path(args.ckpt_root) / train_method
    ckpt_path.mkdir(parents=True, exist_ok=True)
    # initialize() receives the method directory, so it is created before the call.
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root) / train_method / run_name
    log_path.mkdir(parents=True, exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Store peripheral variables on args to reduce clutter and keep the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The UNet architecture requires all inputs to be divisible by some power of 2.
    divisor = 2 ** args.num_pool_layers

    if args.random_sampling:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # This is optimized for SSD storage.
    # Sending to device should be inside the input transform for optimal performance on HDD.
    data_prefetch = Prefetch2Device(device)

    input_train_transform = WeightedPreProcessK(
        mask_func, args.challenge, device, use_seed=False, divisor=divisor)
    input_val_transform = WeightedPreProcessK(
        mask_func, args.challenge, device, use_seed=True, divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(
        img_loss=nn.L1Loss(reduction='mean'),
        # Alternative tried: img_loss=L1CSSIM7(reduction='mean', alpha=args.alpha)
    )

    output_transform = WeightedReplacePostProcessK()

    # Multicoil has 15 coils with 2 channels (real/imag) each.
    data_chans = 2 if args.challenge == 'singlecoil' else 30

    model = UNetSkipGN(
        in_chans=data_chans, out_chans=data_chans, chans=args.chans,
        num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
        pool_type=args.pool_type, use_skip=args.use_skip, use_att=args.use_att,
        reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerK2I(
        args, model, optimizer, train_loader, val_loader, input_train_transform,
        input_val_transform, output_transform, losses, scheduler)

    trainer.train_model()
def train_complex(args):
    """Set up and run training for the complex-valued k-space methods.

    ``args.train_method`` selects the pipeline:

    * ``'WS2C'`` -- semi-k-space learning: ``SemiDistanceWeight`` weighting with
      ``PreProcessWSK`` inputs and ``WeightedReplacePostProcessSemiK`` outputs.
    * ``'WK2C'`` -- k-space learning: ``TiltedDistanceWeight`` weighting with
      ``PreProcessWK`` inputs and ``WeightedReplacePostProcessK`` outputs.

    Creates ``<ckpt_root>/<train_method>/<run_name>`` and
    ``<log_root>/<train_method>/<run_name>``, picks the device, records run metadata
    on ``args``, dumps ``vars(args)`` to JSON, then builds a ``UNetModelKSSE`` model,
    an Adam optimizer with a ``MultiStepLR`` schedule, and runs
    ``ModelTrainerCOMPLEX.train_model()``.

    Args:
        args: Parsed configuration namespace. Mutated in place: ``run_number``,
            ``run_name``, ``ckpt_path``, ``log_path`` and ``device`` are added.

    Raises:
        NotImplementedError: If ``args.train_method`` is neither ``'WS2C'`` nor
            ``'WK2C'``. This is checked before any directory or file is created, so an
            invalid configuration does not leave an empty run behind.
    """
    # Fail fast: validating here (instead of after the setup below) avoids creating
    # run directories and an args JSON for a run that can never start.
    if args.train_method not in ('WS2C', 'WK2C'):
        raise NotImplementedError('Invalid train method!')

    # Checkpoint directory. parents=True allows ckpt_root to be a nested path whose
    # parent directories do not exist yet (a plain mkdir raises FileNotFoundError).
    ckpt_path = Path(args.ckpt_root) / args.train_method
    ckpt_path.mkdir(parents=True, exist_ok=True)
    # initialize() receives the method directory, so it is created before the call.
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root) / args.train_method / run_name
    log_path.mkdir(parents=True, exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Store peripheral variables on args to reduce clutter and keep the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The UNet architecture requires all inputs to be divisible by some power of 2.
    divisor = 2 ** args.num_pool_layers

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(
            mask_func, weight_func, args.challenge, device, use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWSK(
            mask_func, weight_func, args.challenge, device, use_seed=True, divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(weighted=True, replace=args.replace)
    else:  # 'WK2C' (guaranteed by the check at the top): k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type, y_scale=args.y_scale)
        input_train_transform = PreProcessWK(
            mask_func, weight_func, args.challenge, device, use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWK(
            mask_func, weight_func, args.challenge, device, use_seed=True, divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True, replace=args.replace)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    # Multicoil has 15 coils with 2 channels (real/imag) each.
    data_chans = 2 if args.challenge == 'singlecoil' else 30

    model = UNetModelKSSE(
        in_chans=data_chans, out_chans=data_chans, chans=args.chans,
        num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
        use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
        min_ext_size=args.min_ext_size, max_ext_size=args.max_ext_size,
        ext_mode=args.ext_mode, use_ca=args.use_ca, reduction=args.reduction,
        use_gap=args.use_gap, use_gmp=args.use_gmp, use_sa=args.use_sa,
        sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation,
        use_cap=args.use_cap, use_cmp=args.use_cmp,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(
        args, model, optimizer, train_loader, val_loader, input_train_transform,
        input_val_transform, output_transform, losses, scheduler)

    trainer.train_model()
def train_img(args):
    """Set up and run training for the 'K2CI' method.

    Creates the checkpoint directory ``<ckpt_root>/K2CI/<run_name>`` and log directory
    ``<log_root>/K2CI/<run_name>``, picks the device, records run metadata on ``args``,
    dumps ``vars(args)`` to JSON, then builds the data pipeline, a ``UnetASE`` model,
    an Adam optimizer with a ``StepLR`` schedule, and runs
    ``ModelTrainerK2CI.train_model()``.

    Losses: ``cmg_loss`` (MSE) and ``img_loss`` (``CSSIM`` with a 7x7 filter).

    Args:
        args: Parsed configuration namespace. Mutated in place: ``run_number``,
            ``run_name``, ``ckpt_path``, ``log_path`` and ``device`` are added.
    """
    # Maybe move this to args later.
    train_method = 'K2CI'

    # Checkpoint directory. parents=True allows ckpt_root to be a nested path whose
    # parent directories do not exist yet (a plain mkdir raises FileNotFoundError).
    ckpt_path = Path(args.ckpt_root) / train_method
    ckpt_path.mkdir(parents=True, exist_ok=True)
    # initialize() receives the method directory, so it is created before the call.
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root) / train_method / run_name
    log_path.mkdir(parents=True, exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Store peripheral variables on args to reduce clutter and keep the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The UNet architecture requires all inputs to be divisible by some power of 2.
    divisor = 2 ** args.num_pool_layers

    # NOTE(review): unlike the other entry points, args.random_sampling is not
    # consulted here; MaskFunc is always used. Confirm this is intended.
    mask_func = MaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    # `device` is the same object as args.device (assigned above).
    input_train_transform = TrainPreProcessK(
        mask_func, args.challenge, device, use_seed=False, divisor=divisor)
    input_val_transform = TrainPreProcessK(
        mask_func, args.challenge, device, use_seed=True, divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(
        cmg_loss=nn.MSELoss(reduction='mean'),
        # Alternative tried: img_loss=L1CSSIM7(reduction='mean', alpha=0.5)
        img_loss=CSSIM(filter_size=7, reduction='mean'),
    )

    output_transform = OutputReplaceTransformK()

    # Multicoil has 15 coils with 2 channels (real/imag) each.
    data_chans = 2 if args.challenge == 'singlecoil' else 30

    model = UnetASE(
        in_chans=data_chans, out_chans=data_chans, ext_chans=args.chans, chans=args.chans,
        num_pool_layers=args.num_pool_layers, min_ext_size=args.min_ext_size,
        max_ext_size=args.max_ext_size, use_ext_bias=args.use_ext_bias, use_att=False,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    # NOTE(review): this uses StepLR with args.lr_red_epoch, while the other entry
    # points use MultiStepLR with args.lr_red_epochs. Confirm the args parser defines
    # lr_red_epoch for this method.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_red_epoch, gamma=args.lr_red_rate)

    trainer = ModelTrainerK2CI(
        args, model, optimizer, train_loader, val_loader, input_train_transform,
        input_val_transform, output_transform, losses, scheduler)

    trainer.train_model()