def train_xnet(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get(
        'center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get(
        'center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))
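    # With only the shared keys given (e.g. center_fractions=[0.08],
    # accelerations=[4]), the same settings apply to both splits; the
    # *_train/*_val keys override them per split when present.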

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)
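    # Under the fastMRI mask conventions (an assumption about these helpers),
    # center_fractions is the fraction of fully sampled low-frequency k-space
    # columns and accelerations is the overall undersampling factor, e.g.
    # center_fractions=[0.08] with accelerations=[4] for 4x acceleration.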

    input_train_transform = PreProcessXNet(mask_func=train_mask_func,
                                           challenge=args.challenge,
                                           device=device,
                                           augment_data=args.augment_data,
                                           use_seed=False,
                                           crop_center=args.crop_center)
    input_val_transform = PreProcessXNet(mask_func=val_mask_func,
                                         challenge=args.challenge,
                                         device=device,
                                         augment_data=False,
                                         use_seed=True,
                                         crop_center=args.crop_center)

    output_train_transform = PostProcessXNet(challenge=args.challenge)
    output_val_transform = PostProcessXNet(challenge=args.challenge)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        phase_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        img_loss=LogSSIMLoss(filter_size=5).to(device=device),
        # img_loss=nn.L1Loss()
        x_loss=AlignmentLoss())

    data_chans = 1 if args.challenge == 'singlecoil' else 15  # Multi-coil data has 15 coils; single-coil has 1.

    model = XNet(in_chans=data_chans,
                 out_chans=data_chans,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 dilation=args.dilation,
                 res_scale=args.res_scale).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = XNetModelTrainer(args, model, optimizer, train_loader,
                               val_loader, input_train_transform,
                               input_val_transform, output_train_transform,
                               output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning(
            'Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.'
        )
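A minimal invocation sketch for train_xnet. The attribute names below are exactly the ones the function reads; the values are illustrative assumptions, not settings from the original project, and create_prefetch_data_loaders(args) will additionally need dataset fields (paths, batch size, etc.) that are not shown here.

from argparse import Namespace

args = Namespace(
    ckpt_root='checkpoints', log_root='logs', train_method='XNet',
    gpu=0, challenge='multicoil', random_sampling=True,
    center_fractions=[0.08, 0.04], accelerations=[4, 8],
    augment_data=True, crop_center=True,
    chans=32, num_pool_layers=4, num_depth_blocks=2,
    dilation=1, res_scale=0.1,
    init_lr=1e-4, lr_red_epochs=[20, 40], lr_red_rate=0.1,
)
train_xnet(args)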
Example #2
def train_cmg_to_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the computation device at runtime.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessCMG(mask_func,
                                          args.challenge,
                                          device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          crop_center=args.crop_center)
    input_val_transform = PreProcessCMG(mask_func,
                                        args.challenge,
                                        device,
                                        augment_data=False,
                                        use_seed=True,
                                        crop_center=args.crop_center)

    output_train_transform = PostProcessCMG(challenge=args.challenge,
                                            residual_acs=args.residual_acs)
    output_val_transform = PostProcessCMG(challenge=args.challenge,
                                          residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        img_loss=LogSSIMLoss(filter_size=7).to(device)
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=nn.L1Loss()
    )

    # model = UNet(
    #     in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
    #     negative_slope=args.negative_slope, use_residual=args.use_residual, interp_mode=args.interp_mode,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    model = UNet(in_chans=30,
                 out_chans=30,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 use_residual=args.use_residual,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerIMG(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
Example #3
def train_img_to_rss(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the computation device at runtime.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get(
        'center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get(
        'center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessRSS(mask_func=train_mask_func,
                                          challenge=args.challenge,
                                          device=device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          fat_info=args.fat_info)
    input_val_transform = PreProcessRSS(mask_func=val_mask_func,
                                        challenge=args.challenge,
                                        device=device,
                                        augment_data=False,
                                        use_seed=True,
                                        fat_info=args.fat_info)

    output_train_transform = PostProcessRSS(challenge=args.challenge,
                                            residual_rss=args.residual_rss)
    output_val_transform = PostProcessRSS(challenge=args.challenge,
                                          residual_rss=args.residual_rss)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(rss_loss=SSIMLoss(filter_size=7).to(device=device))

    # args.fat_info presumably adds one extra input channel (fat information)
    # on top of the 15 coil channels.
    in_chans = 16 if args.fat_info else 15
    model = UNet(in_chans=in_chans,
                 out_chans=1,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_res_groups=args.num_res_groups,
                 num_res_blocks_per_group=args.num_res_blocks_per_group,
                 growth_rate=args.growth_rate,
                 num_dense_layers=args.num_dense_layers,
                 use_dense_ca=args.use_dense_ca,
                 num_res_layers=args.num_res_layers,
                 res_scale=args.res_scale,
                 reduction=args.reduction,
                 thick_base=args.thick_base).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=args.milestones,
                                               gamma=args.lr_red_rate)
    trainer = ModelTrainerRSS(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model_concat()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
Example #4
def train_cmg_to_img_direct(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Select the computation device at runtime.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # # The UNet architecture requires that all input dimensions be divisible by some power of 2.
    # divisor = 2 ** args.num_pool_layers

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get(
        'center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get(
        'center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessCMGIMG(mask_func=train_mask_func,
                                             challenge=args.challenge,
                                             device=device,
                                             augment_data=args.augment_data,
                                             use_seed=False,
                                             crop_center=args.crop_center)
    input_val_transform = PreProcessCMGIMG(mask_func=val_mask_func,
                                           challenge=args.challenge,
                                           device=device,
                                           augment_data=False,
                                           use_seed=True,
                                           crop_center=args.crop_center)

    output_train_transform = PostProcessCMGIMG(challenge=args.challenge,
                                               output_mode='img')
    output_val_transform = PostProcessCMGIMG(challenge=args.challenge,
                                             output_mode='img')

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss()
        # img_loss=L1SSIMLoss(filter_size=7, l1_ratio=args.l1_ratio).to(device=device)
    )

    # data_chans = 1 if args.challenge == 'singlecoil' else 15
    model = UNet(in_chans=45,
                 out_chans=15,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 res_scale=args.res_scale,
                 use_residual=False,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerI2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
Example #5
def train_cmg_and_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:  # Same as in the challenge
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                               weight_func=no_weight,
                                               challenge=args.challenge,
                                               device=device,
                                               use_seed=False)
    input_val_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                             weight_func=no_weight,
                                             challenge=args.challenge,
                                             device=device,
                                             use_seed=True)

    output_train_transform = PostProcessWSemiKCC(
        args.challenge, weighted=False, residual_acs=args.residual_acs)
    output_val_transform = PostProcessWSemiKCC(args.challenge,
                                               weighted=False,
                                               residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        cmg_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss())

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNet(in_chans=data_chans,
                 out_chans=data_chans,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 num_groups=args.num_groups,
                 negative_slope=args.negative_slope,
                 use_residual=args.use_residual,
                 interp_mode=args.interp_mode,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCI(args, model, optimizer, train_loader, val_loader,
                             input_train_transform, input_val_transform,
                             output_train_transform, output_val_transform,
                             losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning(
            'Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.'
        )
Example #6
def train_cmg_and_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Select the computation device at runtime.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are applied on a per-slice basis.
    # The UNet architecture requires that all input dimensions be divisible by some power of 2.
    divisor = 2 ** args.num_pool_layers
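    # e.g. num_pool_layers=4 gives divisor=16: a 320x320 slice downsamples
    # cleanly through all four pooling stages, whereas a 321-wide one would not.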

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    if args.train_method == 'WSemi2CI':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func, weight_func, args.challenge, device,
                                              use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func, weight_func, args.challenge, device,
                                            use_seed=True, divisor=divisor)

        output_train_transform = PostProcessWSemiK(weighted=True, replace=False, residual_acs=args.residual_acs)
        output_val_transform = PostProcessWSemiK(weighted=True, replace=args.replace, residual_acs=args.residual_acs)

    elif args.train_method == 'WK2CI':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type, y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func, weight_func, args.challenge, device,
                                             use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWK(mask_func, weight_func, args.challenge, device,
                                           use_seed=True, divisor=divisor)

        output_train_transform = PostProcessWK(weighted=True, replace=False, residual_acs=args.residual_acs)
        output_val_transform = PostProcessWK(weighted=True, replace=args.replace, residual_acs=args.residual_acs)
    else:
        raise NotImplementedError('Invalid train method!')

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        cmg_loss=nn.MSELoss(reduction='mean'),
        img_loss=nn.L1Loss(reduction='mean')  # Change to SSIM later.
    )

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNet(in_chans=data_chans, out_chans=data_chans, chans=args.chans, num_pool_layers=args.num_pool_layers,
                 num_groups=args.num_groups, negative_slope=args.negative_slope, use_residual=args.use_residual,
                 interp_mode=args.interp_mode, use_ca=args.use_ca, reduction=args.reduction,
                 use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerCI(args, model, optimizer, train_loader, val_loader, input_train_transform,
                             input_val_transform, output_train_transform, output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.')
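For context, a top-level entry point could dispatch on args.train_method to the training functions above. The sketch below is hypothetical: only 'WSemi2CI' and 'WK2CI' appear as method names in the original code, and the other keys are made up for illustration.

def main(args):
    # Hypothetical dispatch table; the function names are the ones defined above.
    trainers = {
        'XNet': train_xnet,                     # hypothetical key
        'C2I': train_cmg_to_img,                # hypothetical key
        'I2R': train_img_to_rss,                # hypothetical key
        'C2I_direct': train_cmg_to_img_direct,  # hypothetical key
        'WSemi2CI': train_cmg_and_img,          # checked inside train_cmg_and_img
        'WK2CI': train_cmg_and_img,             # checked inside train_cmg_and_img
    }
    trainers[args.train_method](args)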