Example #1
def train_k2i(args):

    # Maybe move this to args later.
    train_method = 'W2I'  # Weighted K-space to real-valued image.

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Select the device at runtime: use the requested GPU when CUDA is available, otherwise fall back to CPU.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires that all input dimensions be divisible by 2**num_pool_layers.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # This prefetching scheme is optimized for SSD storage; on an HDD,
    # sending data to the device inside the input transform performs better.
    data_prefetch = Prefetch2Device(device)

    input_train_transform = WeightedPreProcessK(mask_func,
                                                args.challenge,
                                                device,
                                                use_seed=False,
                                                divisor=divisor)
    input_val_transform = WeightedPreProcessK(mask_func,
                                              args.challenge,
                                              device,
                                              use_seed=True,
                                              divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(img_loss=nn.L1Loss(reduction='mean')
                  # img_loss=L1CSSIM7(reduction='mean', alpha=args.alpha)
                  )

    output_transform = WeightedReplacePostProcessK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils, each with 2 channels (real/imag).

    model = UNetSkipGN(in_chans=data_chans,
                       out_chans=data_chans,
                       chans=args.chans,
                       num_pool_layers=args.num_pool_layers,
                       num_groups=args.num_groups,
                       pool_type=args.pool_type,
                       use_skip=args.use_skip,
                       use_att=args.use_att,
                       reduction=args.reduction,
                       use_gap=args.use_gap,
                       use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerK2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_transform, losses, scheduler)

    trainer.train_model()
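
The checkpoint and log directories above are created with the same three-step mkdir chain, once for each root. A possible consolidation, shown only as a sketch (make_run_dirs is illustrative, not part of the project's API):

from pathlib import Path

def make_run_dirs(root: str, train_method: str, run_name: str) -> Path:
    # Equivalent to the chained mkdir(exist_ok=True) calls above.
    path = Path(root) / train_method / run_name
    path.mkdir(parents=True, exist_ok=True)
    return path

Note that run_name comes from initialize(ckpt_path), which expects the train-method directory to already exist, so such a helper could only be applied after that call.
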
Example #2
def train_complex(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Select the device at runtime: use the requested GPU when CUDA is available, otherwise fall back to CPU.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires that all input dimensions be divisible by 2**num_pool_layers.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func,
                                              weight_func,
                                              args.challenge,
                                              device,
                                              use_seed=False,
                                              divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func,
                                            weight_func,
                                            args.challenge,
                                            device,
                                            use_seed=True,
                                            divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(
            weighted=True, replace=args.replace)

    elif args.train_method == 'WK2C':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type,
                                           y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func,
                                             weight_func,
                                             args.challenge,
                                             device,
                                             use_seed=False,
                                             divisor=divisor)
        input_val_transform = PreProcessWK(mask_func,
                                           weight_func,
                                           args.challenge,
                                           device,
                                           use_seed=True,
                                           divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True,
                                                       replace=args.replace)
    else:
        raise NotImplementedError('Invalid train method!')

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils, each with 2 channels (real/imag).

    # model = UNetModel(
    #     in_chans=data_chans, out_chans=data_chans, chans=args.chans, num_pool_layers=args.num_pool_layers,
    #     num_groups=args.num_groups, use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
    #     use_sa=args.use_sa, sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation, use_cap=args.use_cap,
    #     use_cmp=args.use_cmp).to(device)

    model = UNetModelKSSE(in_chans=data_chans,
                          out_chans=data_chans,
                          chans=args.chans,
                          num_pool_layers=args.num_pool_layers,
                          num_groups=args.num_groups,
                          use_residual=args.use_residual,
                          pool_type=args.pool_type,
                          use_skip=args.use_skip,
                          min_ext_size=args.min_ext_size,
                          max_ext_size=args.max_ext_size,
                          ext_mode=args.ext_mode,
                          use_ca=args.use_ca,
                          reduction=args.reduction,
                          use_gap=args.use_gap,
                          use_gmp=args.use_gmp,
                          use_sa=args.use_sa,
                          sa_kernel_size=args.sa_kernel_size,
                          sa_dilation=args.sa_dilation,
                          use_cap=args.use_cap,
                          use_cmp=args.use_cmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(args, model, optimizer, train_loader,
                                  val_loader, input_train_transform,
                                  input_val_transform, output_transform,
                                  losses, scheduler)

    trainer.train_model()
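
The GPU/CPU selection block is identical in all three examples. A self-contained sketch of the same logic (select_device is an illustrative name, not from the repository):

import torch

def select_device(gpu):
    # Use the requested GPU when one is given and CUDA is available.
    if gpu is not None and torch.cuda.is_available():
        return torch.device(f'cuda:{gpu}')
    return torch.device('cpu')

print(select_device(None))  # gpu=None -> cpu
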
Example #3
def train_img(args):

    # Maybe move this to args later.
    train_method = 'K2CI'

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Select the device at runtime: use the requested GPU when CUDA is available, otherwise fall back to CPU.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires that all input dimensions be divisible by 2**num_pool_layers.
    divisor = 2**args.num_pool_layers

    mask_func = MaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    input_train_transform = TrainPreProcessK(mask_func,
                                             args.challenge,
                                             args.device,
                                             use_seed=False,
                                             divisor=divisor)
    input_val_transform = TrainPreProcessK(mask_func,
                                           args.challenge,
                                           args.device,
                                           use_seed=True,
                                           divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(
        cmg_loss=nn.MSELoss(reduction='mean'),
        # img_loss=L1CSSIM7(reduction='mean', alpha=0.5)
        img_loss=CSSIM(filter_size=7, reduction='mean'))

    output_transform = OutputReplaceTransformK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils, each with 2 channels (real/imag).

    model = UnetASE(in_chans=data_chans,
                    out_chans=data_chans,
                    ext_chans=args.chans,
                    chans=args.chans,
                    num_pool_layers=args.num_pool_layers,
                    min_ext_size=args.min_ext_size,
                    max_ext_size=args.max_ext_size,
                    use_ext_bias=args.use_ext_bias,
                    use_att=False).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.lr_red_epoch,
                                          gamma=args.lr_red_rate)
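
    # Note: unlike Examples #1 and #2, which use MultiStepLR with a list of
    # milestone epochs (args.lr_red_epochs), this function uses StepLR, which
    # decays the learning rate by args.lr_red_rate every args.lr_red_epoch epochs.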

    trainer = ModelTrainerK2CI(args, model, optimizer, train_loader,
                               val_loader, input_train_transform,
                               input_val_transform, output_transform, losses,
                               scheduler)

    trainer.train_model()
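
All three examples compute divisor = 2**args.num_pool_layers because each pooling stage of the U-Net halves the spatial dimensions, so input height and width must be multiples of that divisor. The actual padding happens inside the preprocessing transforms; a minimal, self-contained sketch of the arithmetic (pad_to_divisor is hypothetical):

import torch
import torch.nn.functional as F

def pad_to_divisor(slice_2d: torch.Tensor, divisor: int) -> torch.Tensor:
    # Zero-pad the last two dimensions up to the nearest multiple of divisor.
    h, w = slice_2d.shape[-2:]
    pad_h = (divisor - h % divisor) % divisor
    pad_w = (divisor - w % divisor) % divisor
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(slice_2d, (0, pad_w, 0, pad_h))

# With num_pool_layers=4 the divisor is 16: a 370x320 slice pads to 384x320.
assert pad_to_divisor(torch.zeros(1, 370, 320), 2 ** 4).shape[-2:] == (384, 320)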