Example #1
def train_cmg_to_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the computation device for this run.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)
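    # A note on the two mask functions (hedged: based on the fastMRI-style
    # conventions these names suggest): RandomMaskFunc picks k-space columns at
    # random, UniformMaskFunc picks equispaced columns, and both keep a fully
    # sampled center band sized by center_fractions while subsampling the rest
    # to match the requested accelerations.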

    input_train_transform = PreProcessCMG(mask_func,
                                          args.challenge,
                                          device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          crop_center=args.crop_center)
    input_val_transform = PreProcessCMG(mask_func,
                                        args.challenge,
                                        device,
                                        augment_data=False,
                                        use_seed=True,
                                        crop_center=args.crop_center)

    output_train_transform = PostProcessCMG(challenge=args.challenge,
                                            residual_acs=args.residual_acs)
    output_val_transform = PostProcessCMG(challenge=args.challenge,
                                          residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        img_loss=LogSSIMLoss(filter_size=7).to(device)
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=nn.L1Loss()
    )

    # model = UNet(
    #     in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
    #     negative_slope=args.negative_slope, use_residual=args.use_residual, interp_mode=args.interp_mode,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    model = UNet(in_chans=30,
                 out_chans=30,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 use_residual=args.use_residual,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerIMG(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
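
The checkpoint/log directory boilerplate at the top of this function repeats verbatim in every training function in this section. A minimal sketch of a helper that could absorb it (make_run_dirs is a hypothetical name, not part of the repository; the run name itself must still come from initialize(), which inspects the method directory before the leaf exists):

from pathlib import Path

def make_run_dirs(root, train_method, run_name):
    """Create root/train_method/run_name and return the leaf directory."""
    path = Path(root) / train_method / run_name
    path.mkdir(parents=True, exist_ok=True)
    return path

With it, each function would call make_run_dirs(args.ckpt_root, args.train_method, run_name) once for checkpoints and once for logs.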
Example #2
def train_xnet(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get(
        'center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get(
        'center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))
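    # The .get() chains above make older configs keep working: if vars(args)
    # lacks 'center_fractions_train', the lookup falls through to the legacy
    # 'center_fractions' key (and likewise for the acceleration keys), while
    # newer configs that define separate train/val keys override the fallback.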

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessXNet(mask_func=train_mask_func,
                                           challenge=args.challenge,
                                           device=device,
                                           augment_data=args.augment_data,
                                           use_seed=False,
                                           crop_center=args.crop_center)
    input_val_transform = PreProcessXNet(mask_func=val_mask_func,
                                         challenge=args.challenge,
                                         device=device,
                                         augment_data=False,
                                         use_seed=True,
                                         crop_center=args.crop_center)

    output_train_transform = PostProcessXNet(challenge=args.challenge)
    output_val_transform = PostProcessXNet(challenge=args.challenge)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        phase_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        img_loss=LogSSIMLoss(filter_size=5).to(device=device),
        # img_loss=nn.L1Loss()
        x_loss=AlignmentLoss())

    data_chans = 1 if args.challenge == 'singlecoil' else 15  # One channel per coil; multicoil data has 15 coils.

    model = XNet(in_chans=data_chans,
                 out_chans=data_chans,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 dilation=args.dilation,
                 res_scale=args.res_scale).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = XNetModelTrainer(args, model, optimizer, train_loader,
                               val_loader, input_train_transform,
                               input_val_transform, output_train_transform,
                               output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.')
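
All trainers in this section share the same Adam plus MultiStepLR pairing. A self-contained sketch of the scheduler's behavior (standard PyTorch API; the milestones and gamma below are illustrative values, not taken from any config here):

import torch
from torch import nn, optim

param = nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=1e-3)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 4], gamma=0.1)

for epoch in range(6):
    print(epoch, optimizer.param_groups[0]['lr'])
    optimizer.step()   # one training epoch would run here
    scheduler.step()   # decays lr by gamma when the epoch count hits a milestone
# Prints lr = 1e-3 for epochs 0-1, 1e-4 for epochs 2-3, and 1e-5 for epochs 4-5.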
Example #3
def train_img_to_rss(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the computation device for this run.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get(
        'center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get(
        'center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessRSS(mask_func=train_mask_func,
                                          challenge=args.challenge,
                                          device=device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          fat_info=args.fat_info)
    input_val_transform = PreProcessRSS(mask_func=val_mask_func,
                                        challenge=args.challenge,
                                        device=device,
                                        augment_data=False,
                                        use_seed=True,
                                        fat_info=args.fat_info)

    output_train_transform = PostProcessRSS(challenge=args.challenge,
                                            residual_rss=args.residual_rss)
    output_val_transform = PostProcessRSS(challenge=args.challenge,
                                          residual_rss=args.residual_rss)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(rss_loss=SSIMLoss(filter_size=7).to(device=device))

    in_chans = 16 if args.fat_info else 15
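    # Hedged reading of the flag: fat_info appears to append one extra input
    # channel (a fat-information map) to the 15 multicoil images, hence
    # in_chans = 16 when it is set; the output stays a single RSS image.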
    model = UNet(in_chans=in_chans,
                 out_chans=1,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_res_groups=args.num_res_groups,
                 num_res_blocks_per_group=args.num_res_blocks_per_group,
                 growth_rate=args.growth_rate,
                 num_dense_layers=args.num_dense_layers,
                 use_dense_ca=args.use_dense_ca,
                 num_res_layers=args.num_res_layers,
                 res_scale=args.res_scale,
                 reduction=args.reduction,
                 thick_base=args.thick_base).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=args.milestones,
                                               gamma=args.lr_red_rate)
    trainer = ModelTrainerRSS(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model_concat()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
Example #4
def train_complex(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Select the computation device for this run.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires that all input dimensions be divisible by 2 ** num_pool_layers.
    divisor = 2**args.num_pool_layers
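    # Illustrative arithmetic: with num_pool_layers = 4 the divisor is 16, so a
    # 370-row slice would need padding up to 384 = 16 * ceil(370 / 16) before
    # it can pass through the pooling/upsampling path without shape mismatch.
    # Presumably the pre-processing transforms below consume divisor for
    # exactly this purpose.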

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func,
                                              weight_func,
                                              args.challenge,
                                              device,
                                              use_seed=False,
                                              divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func,
                                            weight_func,
                                            args.challenge,
                                            device,
                                            use_seed=True,
                                            divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(
            weighted=True, replace=args.replace)

    elif args.train_method == 'WK2C':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type,
                                           y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func,
                                             weight_func,
                                             args.challenge,
                                             device,
                                             use_seed=False,
                                             divisor=divisor)
        input_val_transform = PreProcessWK(mask_func,
                                           weight_func,
                                           args.challenge,
                                           device,
                                           use_seed=True,
                                           divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True,
                                                       replace=args.replace)
    else:
        raise ValueError(f'Invalid train method: {args.train_method}')

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    # model = UNetModel(
    #     in_chans=data_chans, out_chans=data_chans, chans=args.chans, num_pool_layers=args.num_pool_layers,
    #     num_groups=args.num_groups, use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
    #     use_sa=args.use_sa, sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation, use_cap=args.use_cap,
    #     use_cmp=args.use_cmp).to(device)

    model = UNetModelKSSE(in_chans=data_chans,
                          out_chans=data_chans,
                          chans=args.chans,
                          num_pool_layers=args.num_pool_layers,
                          num_groups=args.num_groups,
                          use_residual=args.use_residual,
                          pool_type=args.pool_type,
                          use_skip=args.use_skip,
                          min_ext_size=args.min_ext_size,
                          max_ext_size=args.max_ext_size,
                          ext_mode=args.ext_mode,
                          use_ca=args.use_ca,
                          reduction=args.reduction,
                          use_gap=args.use_gap,
                          use_gmp=args.use_gmp,
                          use_sa=args.use_sa,
                          sa_kernel_size=args.sa_kernel_size,
                          sa_dilation=args.sa_dilation,
                          use_cap=args.use_cap,
                          use_cmp=args.use_cmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(args, model, optimizer, train_loader,
                                  val_loader, input_train_transform,
                                  input_val_transform, output_transform,
                                  losses, scheduler)

    trainer.train_model()
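
The 30-channel layout used for the multicoil case comes from stacking the real and imaginary parts of 15 coils. A minimal sketch of that conversion (standard torch ops; the actual transform lives inside the pre-processing classes, so this is only illustrative):

import torch

kspace = torch.randn(15, 320, 320, dtype=torch.complex64)  # 15 complex coil images
stacked = torch.cat([kspace.real, kspace.imag], dim=0)     # shape: (30, 320, 320)
print(stacked.shape)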
Example #5
def main(args):
    from models.edsr_unet import UNet  # Imported here to reduce confusion with other UNet variants.
    from data.input_transforms import PreProcessIMG, Prefetch2Device
    from data.rss_inputs import PreProcessRSS
    from eval.input_test_transform import PreProcessTestIMG, PreProcessValIMG, PreProcessTestRSS
    from eval.output_test_transforms import PostProcessTestIMG, PostProcessTestRSS
    from train.subsample import RandomMaskFunc

    # Selecting device
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    print(f'Device {device} has been selected.')

    # data_chans = 1 if args.challenge == 'singlecoil' else 15
    model = UNet(in_chans=15,
                 out_chans=1,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 res_scale=args.res_scale,
                 use_residual=args.use_residual,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    dataset = CustomSliceData(root=args.data_root,
                              transform=Prefetch2Device(device),
                              challenge=args.challenge,
                              sample_rate=1,
                              start_slice=0,
                              use_gt=False)

    data_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers,
                             collate_fn=temp_collate_fn,
                             pin_memory=False)

    mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    # divisor = 2 ** args.num_pool_layers  # For UNet size fitting.

    # This is for the validation set, not the test set. The test set requires a different pre-processing function.
    if Path(args.data_root).name.endswith('val'):
        # pre_processing = PreProcessIMG(mask_func=mask_func, challenge=args.challenge, device=device,
        #                                augment_data=False, use_seed=True, crop_center=True)
        pre_processing = PreProcessRSS(mask_func=mask_func,
                                       challenge=args.challenge,
                                       device=device,
                                       augment_data=False,
                                       use_seed=True)
    elif Path(args.data_root).name.endswith('test_v2'):
        pre_processing = PreProcessTestRSS(challenge=args.challenge,
                                           device=device)
        # pre_processing = PreProcessTestIMG(challenge=args.challenge, device=device, crop_center=True)
    else:
        raise ValueError(
            'Invalid data root. If using the original test set, please change to test_v2.'
        )

    # post_processing = PostProcessTestIMG(challenge=args.challenge)
    post_processing = PostProcessTestRSS(challenge=args.challenge,
                                         residual_rss=args.residual_rss)

    # Single acceleration, single model version.
    evaluator = ModelEvaluator(model, args.checkpoint_path, args.challenge,
                               data_loader, pre_processing, post_processing,
                               args.data_root, args.out_dir, device)

    # evaluator = MultiAccelerationModelEvaluator(
    #     model=model, checkpoint_path_4=args.checkpoint_path_4, checkpoint_path_8=args.checkpoint_path_8,
    #     challenge=args.challenge, data_loader=data_loader, pre_processing=pre_processing,
    #     post_processing=post_processing, data_root=args.data_root, out_dir=args.out_dir, device=device
    # )

    evaluator.create_and_save_reconstructions()
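
The validation/test branch in main() keys off the directory name alone. A quick self-contained check of that logic (the paths are made up for illustration):

from pathlib import Path

print(Path('/datasets/multicoil_val').name.endswith('val'))          # True -> PreProcessRSS
print(Path('/datasets/multicoil_test_v2').name.endswith('test_v2'))  # True -> PreProcessTestRSS
print(Path('/datasets/multicoil_test').name.endswith('test_v2'))     # False -> raises in main()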
Example #6
def train_cmg_to_img_direct(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the computation device for this run.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # # The U-Net architecture requires that all inputs be divisible by some power of 2.
    # divisor = 2 ** args.num_pool_layers

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get(
        'center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get(
        'center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessCMGIMG(mask_func=train_mask_func,
                                             challenge=args.challenge,
                                             device=device,
                                             augment_data=args.augment_data,
                                             use_seed=False,
                                             crop_center=args.crop_center)
    input_val_transform = PreProcessCMGIMG(mask_func=val_mask_func,
                                           challenge=args.challenge,
                                           device=device,
                                           augment_data=False,
                                           use_seed=True,
                                           crop_center=args.crop_center)

    output_train_transform = PostProcessCMGIMG(challenge=args.challenge,
                                               output_mode='img')
    output_val_transform = PostProcessCMGIMG(challenge=args.challenge,
                                             output_mode='img')

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss()
        # img_loss=L1SSIMLoss(filter_size=7, l1_ratio=args.l1_ratio).to(device=device)
    )

    # data_chans = 1 if args.challenge == 'singlecoil' else 15
    model = UNet(in_chans=45,
                 out_chans=15,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 res_scale=args.res_scale,
                 use_residual=False,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerI2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
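
One detail worth noting about the save_dict_as_json(vars(args), ...) calls in these functions: by that point args already holds Path and torch.device objects, which json cannot serialize directly, so the helper presumably stringifies them. A minimal sketch of such a conversion (save_dict_as_json's real implementation is not shown in this section):

import json
from pathlib import Path
import torch

settings = {'ckpt_path': Path('checkpoints/run_01'),
            'device': torch.device('cpu'),
            'init_lr': 1e-4}
print(json.dumps(settings, default=str))  # default=str handles the non-JSON types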
Example #7
def train_cmg_and_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:  # Same as in the challenge
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                               weight_func=no_weight,
                                               challenge=args.challenge,
                                               device=device,
                                               use_seed=False)
    input_val_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                             weight_func=no_weight,
                                             challenge=args.challenge,
                                             device=device,
                                             use_seed=True)

    output_train_transform = PostProcessWSemiKCC(
        args.challenge, weighted=False, residual_acs=args.residual_acs)
    output_val_transform = PostProcessWSemiKCC(args.challenge,
                                               weighted=False,
                                               residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        cmg_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss())

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNet(in_chans=data_chans,
                 out_chans=data_chans,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 num_groups=args.num_groups,
                 negative_slope=args.negative_slope,
                 use_residual=args.use_residual,
                 interp_mode=args.interp_mode,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCI(args, model, optimizer, train_loader, val_loader,
                             input_train_transform, input_val_transform,
                             output_train_transform, output_val_transform,
                             losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.')
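
A final note on the interrupt handling shared by these trainers: except KeyboardInterrupt closes the writer on a manual stop, but leaves it open if any other exception escapes train_model(). A try/finally variant would close it in every case (a sketch; whether anything here relies on the writer staying open after a crash is not shown in this section):

try:
    trainer.train_model()
finally:
    trainer.writer.close()  # runs on normal exit, KeyboardInterrupt, or any other error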