コード例 #1
0
def create_prefetch_datasets(args):
    """Build the training and validation ``CustomSliceData`` datasets.

    Both datasets share one ``Prefetch2Device`` transform targeting
    ``args.device``.

    Before the datasets are built, the split-specific options
    ``sample_rate_train``, ``sample_rate_val``, ``start_slice_train`` and
    ``start_slice_val`` are resolved onto ``args``. Each one keeps its existing
    value if the attribute is present. Otherwise it falls back to the shared
    ``sample_rate`` / ``start_slice`` attribute, or ``None`` if that is absent
    too. This keeps older argument sets that only define the shared options
    working.

    Note: ``args`` is mutated (the four split-specific attributes are
    assigned).

    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    """
    shared_transform = Prefetch2Device(device=args.device)

    # (split-specific attribute, shared fallback attribute), resolved in order.
    fallbacks = (
        ('sample_rate_train', 'sample_rate'),
        ('sample_rate_val', 'sample_rate'),
        ('start_slice_train', 'start_slice'),
        ('start_slice_val', 'start_slice'),
    )
    namespace = vars(args)  # Live view of the attributes on ``args``.
    for key, shared_key in fallbacks:
        setattr(args, key, namespace.get(key, namespace.get(shared_key)))

    data_root = Path(args.data_root)

    def build(folder, sample_rate, start_slice):
        # One dataset per split directory under data_root.
        return CustomSliceData(root=data_root / folder,
                               transform=shared_transform,
                               challenge=args.challenge,
                               sample_rate=sample_rate,
                               start_slice=start_slice,
                               use_gt=args.use_gt)

    train_dataset = build(f'{args.challenge}_train',
                          args.sample_rate_train, args.start_slice_train)
    val_dataset = build(f'{args.challenge}_val',
                        args.sample_rate_val, args.start_slice_val)
    return train_dataset, val_dataset
コード例 #2
0
def create_custom_datasets(args, transform=None):
    """Build training and validation ``CustomSliceData`` datasets.

    Args:
        args: Namespace providing ``device``, ``data_root``, ``challenge``,
            ``sample_rate`` and ``start_slice``. Both splits use the same
            ``sample_rate`` and ``start_slice`` values.
        transform: Per-slice transform shared by both datasets. When ``None``,
            a ``Prefetch2Device`` targeting ``args.device`` is used.

    Returns:
        tuple: ``(train_dataset, val_dataset)``. Both are created with
        ``use_gt=False``.
    """
    if transform is None:
        transform = Prefetch2Device(device=args.device)

    root = Path(args.data_root)
    # Keyword arguments identical for both splits; only the directory differs.
    shared_kwargs = dict(transform=transform,
                         challenge=args.challenge,
                         sample_rate=args.sample_rate,
                         start_slice=args.start_slice,
                         use_gt=False)

    train_dataset = CustomSliceData(root=root / f'{args.challenge}_train',
                                    **shared_kwargs)
    val_dataset = CustomSliceData(root=root / f'{args.challenge}_val',
                                  **shared_kwargs)
    return train_dataset, val_dataset
コード例 #3
0
def train_complex(args):
    """Train a complex-valued k-space UNet (``UNetModelKSSE``).

    The function performs these steps in order:

    1. Create the checkpoint and log directories for a new run.
    2. Pick the device.
    3. Save ``vars(args)`` as JSON.
    4. Build the pre- and post-processing transforms for ``args.train_method``.
    5. Run ``ModelTrainerCOMPLEX``.

    Supported ``args.train_method`` values:
        * ``'WS2C'``: semi-k-space learning (``PreProcessWSK`` with
          ``SemiDistanceWeight`` and ``WeightedReplacePostProcessSemiK``).
        * ``'WK2C'``: k-space learning (``PreProcessWK`` with
          ``TiltedDistanceWeight`` and ``WeightedReplacePostProcessK``).

    Side effects: adds ``run_number``, ``run_name``, ``ckpt_path``,
    ``log_path`` and ``device`` to ``args``.

    Raises:
        NotImplementedError: If ``args.train_method`` is unsupported. This is
            checked before any directory or file is created, so an invalid
            method leaves nothing behind on disk.
    """
    supported_methods = ('WS2C', 'WK2C')
    if args.train_method not in supported_methods:
        raise NotImplementedError('Invalid train method!')

    # Checkpoint layout: <ckpt_root>/<train_method>/<run_name>.
    # ``parents=True`` also creates ckpt_root when it is a nested path whose
    # parent directories do not exist yet.
    method_ckpt_path = Path(args.ckpt_root) / args.train_method
    method_ckpt_path.mkdir(parents=True, exist_ok=True)

    # ``initialize`` receives the method-level checkpoint directory and
    # returns the run number and name.
    run_number, run_name = initialize(method_ckpt_path)

    ckpt_path = method_ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    # Log layout mirrors the checkpoints: <log_root>/<train_method>/<run_name>.
    log_path = Path(args.log_root) / args.train_method / run_name
    log_path.mkdir(parents=True, exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Use the requested GPU when CUDA is available; otherwise fall back to CPU.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Store run metadata on args so it is saved below and reaches the trainer.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms are per-slice. The UNet requires input sizes divisible
    # by 2 ** num_pool_layers.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    # Training transforms use use_seed=False; validation transforms use
    # use_seed=True.
    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func,
                                              weight_func,
                                              args.challenge,
                                              device,
                                              use_seed=False,
                                              divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func,
                                            weight_func,
                                            args.challenge,
                                            device,
                                            use_seed=True,
                                            divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(
            weighted=True, replace=args.replace)

    elif args.train_method == 'WK2C':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type,
                                           y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func,
                                             weight_func,
                                             args.challenge,
                                             device,
                                             use_seed=False,
                                             divisor=divisor)
        input_val_transform = PreProcessWK(mask_func,
                                           weight_func,
                                           args.challenge,
                                           device,
                                           use_seed=True,
                                           divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True,
                                                       replace=args.replace)
    else:
        # Defensive: unreachable while supported_methods matches the branches.
        raise NotImplementedError('Invalid train method!')

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNetModelKSSE(in_chans=data_chans,
                          out_chans=data_chans,
                          chans=args.chans,
                          num_pool_layers=args.num_pool_layers,
                          num_groups=args.num_groups,
                          use_residual=args.use_residual,
                          pool_type=args.pool_type,
                          use_skip=args.use_skip,
                          min_ext_size=args.min_ext_size,
                          max_ext_size=args.max_ext_size,
                          ext_mode=args.ext_mode,
                          use_ca=args.use_ca,
                          reduction=args.reduction,
                          use_gap=args.use_gap,
                          use_gmp=args.use_gmp,
                          use_sa=args.use_sa,
                          sa_kernel_size=args.sa_kernel_size,
                          sa_dilation=args.sa_dilation,
                          use_cap=args.use_cap,
                          use_cmp=args.use_cmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(args, model, optimizer, train_loader,
                                  val_loader, input_train_transform,
                                  input_val_transform, output_transform,
                                  losses, scheduler)

    trainer.train_model()
コード例 #4
0
def main(args):
    """Reconstruct a validation or test split with a trained UNet and save it.

    The split is chosen from the last component of ``args.data_root``:

    * A name ending in ``'val'`` is pre-processed with ``PreProcessRSS``
      (``RandomMaskFunc`` undersampling, ``use_seed=True``, no augmentation).
    * A name ending in ``'test_v2'`` is pre-processed with
      ``PreProcessTestRSS``.
    * Any other name raises ``NotImplementedError``.

    Reconstructions are post-processed with ``PostProcessTestRSS`` and written
    by ``ModelEvaluator`` using the weights from ``args.checkpoint_path``.
    """
    # Function-local imports keep the chosen model/transform variants explicit.
    # PreProcessIMG, PreProcessTestIMG, PreProcessValIMG and PostProcessTestIMG
    # are imported but not used by the current code path.
    from models.edsr_unet import UNet
    from data.input_transforms import PreProcessIMG, Prefetch2Device
    from data.rss_inputs import PreProcessRSS
    from eval.input_test_transform import PreProcessTestIMG, PreProcessValIMG, PreProcessTestRSS
    from eval.output_test_transforms import PostProcessTestIMG, PostProcessTestRSS
    from train.subsample import RandomMaskFunc

    use_gpu = args.gpu is not None and torch.cuda.is_available()
    device = torch.device(f'cuda:{args.gpu}' if use_gpu else 'cpu')
    print(f'Device {device} has been selected.')

    # Input channels are fixed at 15 regardless of args.challenge
    # (presumably one per coil of multicoil data - confirm); one output channel.
    model = UNet(
        in_chans=15,
        out_chans=1,
        chans=args.chans,
        num_pool_layers=args.num_pool_layers,
        num_depth_blocks=args.num_depth_blocks,
        res_scale=args.res_scale,
        use_residual=args.use_residual,
        use_ca=args.use_ca,
        reduction=args.reduction,
        use_gap=args.use_gap,
        use_gmp=args.use_gmp,
    ).to(device)

    # Use every slice of every volume, one slice per batch, in file order.
    dataset = CustomSliceData(
        root=args.data_root,
        transform=Prefetch2Device(device),
        challenge=args.challenge,
        sample_rate=1,
        start_slice=0,
        use_gt=False,
    )
    data_loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=temp_collate_fn,
        pin_memory=False,
    )

    mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)

    split_name = Path(args.data_root).name
    if split_name.endswith('val'):
        # Validation data: undersample with mask_func.
        pre_processing = PreProcessRSS(mask_func=mask_func,
                                       challenge=args.challenge,
                                       device=device,
                                       augment_data=False,
                                       use_seed=True)
    elif split_name.endswith('test_v2'):
        # Test data: no mask function is passed.
        pre_processing = PreProcessTestRSS(challenge=args.challenge,
                                           device=device)
    else:
        raise NotImplementedError(
            'Invalid data root. If using the original test set, please change to test_v2.'
        )

    post_processing = PostProcessTestRSS(challenge=args.challenge,
                                         residual_rss=args.residual_rss)

    # Single acceleration, single checkpoint.
    evaluator = ModelEvaluator(
        model, args.checkpoint_path, args.challenge, data_loader,
        pre_processing, post_processing, args.data_root, args.out_dir, device,
    )
    evaluator.create_and_save_reconstructions()
コード例 #5
0
def train_k2i(args):
    """Train ``UNetSkipGN`` with the ``'W2I'`` method.

    The method name is hard-coded as ``'W2I'``; the original author describes
    it as weighted k-space to real-valued image. The function performs these
    steps in order:

    1. Create the checkpoint and log directories for a new run.
    2. Pick the device.
    3. Save ``vars(args)`` as JSON.
    4. Build ``WeightedPreProcessK`` input transforms and a
       ``WeightedReplacePostProcessK`` output transform.
    5. Run ``ModelTrainerK2I`` with an L1 image loss.

    Side effects: adds ``run_number``, ``run_name``, ``ckpt_path``,
    ``log_path`` and ``device`` to ``args``.
    """
    # Maybe move this to args later.
    train_method = 'W2I'  # Weighted K-space to real-valued image.

    # Checkpoint layout: <ckpt_root>/<train_method>/<run_name>.
    # ``parents=True`` also creates ckpt_root when it is a nested path whose
    # parent directories do not exist yet.
    method_ckpt_path = Path(args.ckpt_root) / train_method
    method_ckpt_path.mkdir(parents=True, exist_ok=True)

    # ``initialize`` receives the method-level checkpoint directory and
    # returns the run number and name.
    run_number, run_name = initialize(method_ckpt_path)

    ckpt_path = method_ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    # Log layout mirrors the checkpoints: <log_root>/<train_method>/<run_name>.
    log_path = Path(args.log_root) / train_method / run_name
    log_path.mkdir(parents=True, exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Use the requested GPU when CUDA is available; otherwise fall back to CPU.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Store run metadata on args so it is saved below and reaches the trainer.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms are per-slice. The UNet requires input sizes divisible
    # by 2 ** num_pool_layers.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # This is optimized for SSD storage.
    # Sending to device should be inside the input transform for optimal performance on HDD.
    data_prefetch = Prefetch2Device(device)

    # Training transforms use use_seed=False; validation transforms use
    # use_seed=True.
    input_train_transform = WeightedPreProcessK(mask_func,
                                                args.challenge,
                                                device,
                                                use_seed=False,
                                                divisor=divisor)
    input_val_transform = WeightedPreProcessK(mask_func,
                                              args.challenge,
                                              device,
                                              use_seed=True,
                                              divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(img_loss=nn.L1Loss(reduction='mean'))

    output_transform = WeightedReplacePostProcessK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNetSkipGN(in_chans=data_chans,
                       out_chans=data_chans,
                       chans=args.chans,
                       num_pool_layers=args.num_pool_layers,
                       num_groups=args.num_groups,
                       pool_type=args.pool_type,
                       use_skip=args.use_skip,
                       use_att=args.use_att,
                       reduction=args.reduction,
                       use_gap=args.use_gap,
                       use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerK2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_transform, losses, scheduler)

    trainer.train_model()
コード例 #6
0
def train_img(args):
    """Train ``UnetASE`` with the ``'K2CI'`` method.

    The function performs these steps in order:

    1. Create the checkpoint and log directories for a new run.
    2. Pick the device.
    3. Save ``vars(args)`` as JSON.
    4. Build ``TrainPreProcessK`` input transforms and an
       ``OutputReplaceTransformK`` output transform.
    5. Run ``ModelTrainerK2CI`` with an MSE loss (``cmg_loss``) and a CSSIM
       loss (``img_loss``).

    The learning rate is reduced by ``args.lr_red_rate`` every
    ``args.lr_red_epoch`` epochs (``StepLR``).

    Side effects: adds ``run_number``, ``run_name``, ``ckpt_path``,
    ``log_path`` and ``device`` to ``args``.
    """
    # Maybe move this to args later.
    train_method = 'K2CI'

    # Checkpoint layout: <ckpt_root>/<train_method>/<run_name>.
    # ``parents=True`` also creates ckpt_root when it is a nested path whose
    # parent directories do not exist yet.
    method_ckpt_path = Path(args.ckpt_root) / train_method
    method_ckpt_path.mkdir(parents=True, exist_ok=True)

    # ``initialize`` receives the method-level checkpoint directory and
    # returns the run number and name.
    run_number, run_name = initialize(method_ckpt_path)

    ckpt_path = method_ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    # Log layout mirrors the checkpoints: <log_root>/<train_method>/<run_name>.
    log_path = Path(args.log_root) / train_method / run_name
    log_path.mkdir(parents=True, exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Use the requested GPU when CUDA is available; otherwise fall back to CPU.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Store run metadata on args so it is saved below and reaches the trainer.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms are per-slice. The UNet requires input sizes divisible
    # by 2 ** num_pool_layers.
    divisor = 2**args.num_pool_layers

    mask_func = MaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    # Training transforms use use_seed=False; validation transforms use
    # use_seed=True. ``device`` is the same object stored in ``args.device``.
    input_train_transform = TrainPreProcessK(mask_func,
                                             args.challenge,
                                             device,
                                             use_seed=False,
                                             divisor=divisor)
    input_val_transform = TrainPreProcessK(mask_func,
                                           args.challenge,
                                           device,
                                           use_seed=True,
                                           divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(
        cmg_loss=nn.MSELoss(reduction='mean'),
        img_loss=CSSIM(filter_size=7, reduction='mean'))

    output_transform = OutputReplaceTransformK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UnetASE(in_chans=data_chans,
                    out_chans=data_chans,
                    ext_chans=args.chans,
                    chans=args.chans,
                    num_pool_layers=args.num_pool_layers,
                    min_ext_size=args.min_ext_size,
                    max_ext_size=args.max_ext_size,
                    use_ext_bias=args.use_ext_bias,
                    use_att=False).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.lr_red_epoch,
                                          gamma=args.lr_red_rate)

    trainer = ModelTrainerK2CI(args, model, optimizer, train_loader,
                               val_loader, input_train_transform,
                               input_val_transform, output_transform, losses,
                               scheduler)

    trainer.train_model()