Example #1
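These seven examples come from the training scripts of a single MRI reconstruction codebase (the naming follows fastMRI challenge conventions) and share a common preamble. A minimal sketch of the imports they assume; the project-local module paths are hypothetical placeholders, not the repository's actual layout:

from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

# Project-local helpers; these module paths are illustrative assumptions.
# from data.mask_functions import RandomMaskFunc, UniformMaskFunc
# from utils.run_utils import initialize, get_logger, save_dict_as_json
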
def train_img_to_rss(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the compute device (assignment inside a conditional stays in scope below).
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Placed here for backward compatibility and convenience.
    arguments = vars(args)
    args.center_fractions_train = arguments.get('center_fractions_train',
                                                arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val',
                                              arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessRSS(mask_func=train_mask_func,
                                          challenge=args.challenge,
                                          device=device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          fat_info=args.fat_info)
    input_val_transform = PreProcessRSS(mask_func=val_mask_func,
                                        challenge=args.challenge,
                                        device=device,
                                        augment_data=False,
                                        use_seed=True,
                                        fat_info=args.fat_info)

    output_train_transform = PostProcessRSS(challenge=args.challenge,
                                            residual_rss=args.residual_rss)
    output_val_transform = PostProcessRSS(challenge=args.challenge,
                                          residual_rss=args.residual_rss)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(rss_loss=SSIMLoss(filter_size=7).to(device=device))

    in_chans = 16 if args.fat_info else 15
    model = UNet(in_chans=in_chans,
                 out_chans=1,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_res_groups=args.num_res_groups,
                 num_res_blocks_per_group=args.num_res_blocks_per_group,
                 growth_rate=args.growth_rate,
                 num_dense_layers=args.num_dense_layers,
                 use_dense_ca=args.use_dense_ca,
                 num_res_layers=args.num_res_layers,
                 res_scale=args.res_scale,
                 reduction=args.reduction,
                 thick_base=args.thick_base).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=args.milestones,
                                               gamma=args.lr_red_rate)
    trainer = ModelTrainerRSS(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model_concat()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
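
Every example repeats the same six-step checkpoint/log directory setup. A minimal refactoring sketch, assuming nothing depends on the intermediate mkdir calls (make_run_dir and its usage are hypothetical, not part of the codebase):

from pathlib import Path

def make_run_dir(root, *parts):
    # Create root/part1/part2/... (with intermediates) and return the leaf path.
    path = Path(root).joinpath(*parts)
    path.mkdir(parents=True, exist_ok=True)
    return path

# Hypothetical usage replacing the six mkdir calls above:
# run_number, run_name = initialize(make_run_dir(args.ckpt_root, args.train_method))
# ckpt_path = make_run_dir(args.ckpt_root, args.train_method, run_name)
# log_path = make_run_dir(args.log_root, args.train_method, run_name)
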
Example #2
def train_cmg_to_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the compute device (assignment inside a conditional stays in scope below).
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessCMG(mask_func,
                                          args.challenge,
                                          device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          crop_center=args.crop_center)
    input_val_transform = PreProcessCMG(mask_func,
                                        args.challenge,
                                        device,
                                        augment_data=False,
                                        use_seed=True,
                                        crop_center=args.crop_center)

    output_train_transform = PostProcessCMG(challenge=args.challenge,
                                            residual_acs=args.residual_acs)
    output_val_transform = PostProcessCMG(challenge=args.challenge,
                                          residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        img_loss=LogSSIMLoss(filter_size=7).to(device)
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=nn.L1Loss()
    )

    # model = UNet(
    #     in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
    #     negative_slope=args.negative_slope, use_residual=args.use_residual, interp_mode=args.interp_mode,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    model = UNet(in_chans=30,
                 out_chans=30,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 use_residual=args.use_residual,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerIMG(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
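
The commented-out lines in the losses dict above suggest the image loss was swapped by hand between runs. A small flag-driven alternative, sketched on the assumption that the three losses are interchangeable here (build_img_loss and its name strings are hypothetical):

def build_img_loss(name, device):
    # Map a config string to one of the loss choices seen in these examples.
    if name == 'logssim':
        return LogSSIMLoss(filter_size=7).to(device)
    if name == 'ssim':
        return SSIMLoss(filter_size=7).to(device)
    if name == 'l1':
        return nn.L1Loss()
    raise ValueError(f'Unknown image loss: {name}')

# Hypothetical usage:
# losses = dict(img_loss=build_img_loss(args.img_loss, device))
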
Example #3
def train_complex(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Select the compute device (assignment inside a conditional stays in scope below).
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires that input dimensions be divisible by 2**num_pool_layers.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func,
                                              weight_func,
                                              args.challenge,
                                              device,
                                              use_seed=False,
                                              divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func,
                                            weight_func,
                                            args.challenge,
                                            device,
                                            use_seed=True,
                                            divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(
            weighted=True, replace=args.replace)

    elif args.train_method == 'WK2C':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type,
                                           y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func,
                                             weight_func,
                                             args.challenge,
                                             device,
                                             use_seed=False,
                                             divisor=divisor)
        input_val_transform = PreProcessWK(mask_func,
                                           weight_func,
                                           args.challenge,
                                           device,
                                           use_seed=True,
                                           divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True,
                                                       replace=args.replace)
    else:
        raise NotImplementedError('Invalid train method!')

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # 2 channels (real/imag) per coil; multicoil has 15 coils.

    # model = UNetModel(
    #     in_chans=data_chans, out_chans=data_chans, chans=args.chans, num_pool_layers=args.num_pool_layers,
    #     num_groups=args.num_groups, use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
    #     use_sa=args.use_sa, sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation, use_cap=args.use_cap,
    #     use_cmp=args.use_cmp).to(device)

    model = UNetModelKSSE(in_chans=data_chans,
                          out_chans=data_chans,
                          chans=args.chans,
                          num_pool_layers=args.num_pool_layers,
                          num_groups=args.num_groups,
                          use_residual=args.use_residual,
                          pool_type=args.pool_type,
                          use_skip=args.use_skip,
                          min_ext_size=args.min_ext_size,
                          max_ext_size=args.max_ext_size,
                          ext_mode=args.ext_mode,
                          use_ca=args.use_ca,
                          reduction=args.reduction,
                          use_gap=args.use_gap,
                          use_gmp=args.use_gmp,
                          use_sa=args.use_sa,
                          sa_kernel_size=args.sa_kernel_size,
                          sa_dilation=args.sa_dilation,
                          use_cap=args.use_cap,
                          use_cmp=args.use_cmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(args, model, optimizer, train_loader,
                                  val_loader, input_train_transform,
                                  input_val_transform, output_transform,
                                  losses, scheduler)

    trainer.train_model()
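
The divisor = 2**args.num_pool_layers line reflects the U-Net constraint that each pooling stage halves the spatial dimensions. A minimal sketch of how a transform might enforce it by padding (pad_to_divisor is a hypothetical helper, not the project's preprocessing code):

import torch.nn.functional as F

def pad_to_divisor(tensor, divisor):
    # Zero-pad the last two dims of a (..., H, W) tensor up to multiples of divisor.
    h, w = tensor.shape[-2:]
    pad_h = (divisor - h % divisor) % divisor
    pad_w = (divisor - w % divisor) % divisor
    # F.pad pads the trailing dims first: (left, right, top, bottom).
    return F.pad(tensor, (0, pad_w, 0, pad_h))
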
Example #4
def train_cmg_to_img_direct(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the compute device (assignment inside a conditional stays in scope below).
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # # The U-Net architecture requires that all inputs be divisible by some power of 2.
    # divisor = 2 ** args.num_pool_layers

    # Placed here for backward compatibility and convenience.
    arguments = vars(args)
    args.center_fractions_train = arguments.get('center_fractions_train',
                                                arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val',
                                              arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessCMGIMG(mask_func=train_mask_func,
                                             challenge=args.challenge,
                                             device=device,
                                             augment_data=args.augment_data,
                                             use_seed=False,
                                             crop_center=args.crop_center)
    input_val_transform = PreProcessCMGIMG(mask_func=val_mask_func,
                                           challenge=args.challenge,
                                           device=device,
                                           augment_data=False,
                                           use_seed=True,
                                           crop_center=args.crop_center)

    output_train_transform = PostProcessCMGIMG(challenge=args.challenge,
                                               output_mode='img')
    output_val_transform = PostProcessCMGIMG(challenge=args.challenge,
                                             output_mode='img')

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss()
        # img_loss=L1SSIMLoss(filter_size=7, l1_ratio=args.l1_ratio).to(device=device)
    )

    # data_chans = 1 if args.challenge == 'singlecoil' else 15
    model = UNet(in_chans=45,
                 out_chans=15,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 res_scale=args.res_scale,
                 use_residual=False,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerI2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
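
Examples #1, #4, and #6 share the same backward-compatibility fallback: per-split sampling settings default to the legacy single setting. A helper sketch capturing that pattern (resolve_train_val is hypothetical):

def resolve_train_val(args, key):
    # Fall back to the legacy single setting when {key}_train / {key}_val are absent.
    d = vars(args)
    setattr(args, f'{key}_train', d.get(f'{key}_train', d.get(key)))
    setattr(args, f'{key}_val', d.get(f'{key}_val', d.get(key)))

# Hypothetical usage:
# resolve_train_val(args, 'center_fractions')
# resolve_train_val(args, 'accelerations')
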
Example #5
def train_k2i(args):
    # Maybe move this to args later.
    train_method = 'W2I'  # Weighted k-space to real-valued image.

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Select the compute device (assignment inside a conditional stays in scope below).
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires that input dimensions be divisible by 2**num_pool_layers.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # This is optimized for SSD storage.
    # Sending to device should be inside the input transform for optimal performance on HDD.
    data_prefetch = Prefetch2Device(device)

    input_train_transform = WeightedPreProcessK(mask_func,
                                                args.challenge,
                                                device,
                                                use_seed=False,
                                                divisor=divisor)
    input_val_transform = WeightedPreProcessK(mask_func,
                                              args.challenge,
                                              device,
                                              use_seed=True,
                                              divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(
        img_loss=nn.L1Loss(reduction='mean')
        # img_loss=L1CSSIM7(reduction='mean', alpha=args.alpha)
    )

    output_transform = WeightedReplacePostProcessK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # 2 channels (real/imag) per coil; multicoil has 15 coils.

    model = UNetSkipGN(in_chans=data_chans,
                       out_chans=data_chans,
                       chans=args.chans,
                       num_pool_layers=args.num_pool_layers,
                       num_groups=args.num_groups,
                       pool_type=args.pool_type,
                       use_skip=args.use_skip,
                       use_att=args.use_att,
                       reduction=args.reduction,
                       use_gap=args.use_gap,
                       use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerK2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_transform, losses, scheduler)

    trainer.train_model()
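
The Prefetch2Device transform used above moves freshly loaded tensors to the compute device, which the comment notes is the right trade-off for SSD storage. A sketch of what such a transform might look like; this implementation is an assumption and the real class may differ:

class Prefetch2DeviceSketch:
    # Move loaded tensors to the target device as early as possible.
    def __init__(self, device):
        self.device = device

    def __call__(self, *tensors):
        return tuple(t.to(self.device, non_blocking=True) for t in tensors)
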
Example #6
def train_complex_model(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Placed here for backward compatibility and convenience.
    arguments = vars(args)
    args.center_fractions_train = arguments.get('center_fractions_train',
                                                arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val',
                                              arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    divisor = 2**args.num_pool_layers

    # weight_func = SemiDistanceWeight(weight_type=args.weight_type)
    weight_func = None

    input_train_transform = PreProcessComplexWSK(
        train_mask_func,
        weight_func,
        args.challenge,
        device,
        augment_data=args.augment_data,
        use_seed=False,
        crop_center=args.crop_center,
        crop_ud=args.crop_ud,
        divisor=divisor)

    input_val_transform = PreProcessComplexWSK(val_mask_func,
                                               weight_func,
                                               args.challenge,
                                               device,
                                               augment_data=False,
                                               use_seed=True,
                                               crop_center=args.crop_center,
                                               crop_ud=args.crop_ud,
                                               divisor=divisor)

    output_train_transform = PostProcessComplexWSK(challenge=args.challenge,
                                                   replace=args.replace,
                                                   weighted=False)
    output_val_transform = PostProcessComplexWSK(challenge=args.challenge,
                                                 replace=args.replace,
                                                 weighted=False)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(img_loss=nn.MSELoss(reduction='mean'))

    data_chans = 1 if args.challenge == 'singlecoil' else 15  # One channel per coil; the complex model handles real/imag natively.

    model = ComplexUNet(in_chans=data_chans,
                        out_chans=data_chans,
                        chans=args.chans,
                        num_pool_layers=args.num_pool_layers).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)
    scheduler = None
    trainer = ModelTrainerIMG(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing TensorBoard writer and flushing remaining '
                       'outputs due to KeyboardInterrupt.')
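
Several of these functions close the TensorBoard writer only on KeyboardInterrupt, so any other exception exits with the writer still open. A try/finally sketch that flushes in every case, assuming trainer.writer exists as in the examples:

try:
    trainer.train_model()
except KeyboardInterrupt:
    logger.warning('Training interrupted by user.')
finally:
    # Flush and close pending TensorBoard events no matter how training ends.
    trainer.writer.close()
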
Example #7
def train_cmg_and_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:  # Same as in the challenge
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                               weight_func=no_weight,
                                               challenge=args.challenge,
                                               device=device,
                                               use_seed=False)
    input_val_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                             weight_func=no_weight,
                                             challenge=args.challenge,
                                             device=device,
                                             use_seed=True)

    output_train_transform = PostProcessWSemiKCC(
        args.challenge, weighted=False, residual_acs=args.residual_acs)
    output_val_transform = PostProcessWSemiKCC(args.challenge,
                                               weighted=False,
                                               residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        cmg_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss())

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # 2 channels (real/imag) per coil; multicoil has 15 coils.

    model = UNet(in_chans=data_chans,
                 out_chans=data_chans,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 num_groups=args.num_groups,
                 negative_slope=args.negative_slope,
                 use_residual=args.use_residual,
                 interp_mode=args.interp_mode,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCI(args, model, optimizer, train_loader, val_loader,
                             input_train_transform, input_val_transform,
                             output_train_transform, output_val_transform,
                             losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing TensorBoard writer and flushing remaining '
                       'outputs due to KeyboardInterrupt.')
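
A sketch of how one of these functions might be invoked. The attribute names come from the code above, but every value is an illustrative assumption, and the data loaders and trainer will require further attributes (data paths, batch size, etc.) not shown here:

from argparse import Namespace

args = Namespace(
    ckpt_root='checkpoints', log_root='logs', train_method='C2I',
    gpu=0, challenge='multicoil', random_sampling=True,
    center_fractions=[0.08], accelerations=[4], residual_acs=True,
    chans=32, num_pool_layers=4, num_depth_blocks=2, num_groups=8,
    negative_slope=0.01, use_residual=True, interp_mode='nearest',
    use_ca=True, reduction=16, use_gap=True, use_gmp=False,
    init_lr=1e-3, lr_red_epochs=[20, 35], lr_red_rate=0.2,
)
train_cmg_and_img(args)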