    def __init__(self, args, model, optimizer, train_loader, val_loader,
                 input_train_transform, input_val_transform, output_transform, losses, scheduler=None):

        # Allow multiple processes to access tensors on the GPU ('spawn' is required for CUDA tensors).
        # Checking the current start method first prevents an error when successive runs share one session.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_transform, nn.Module), '`output_transform` must be a Pytorch Module.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with multiple outputs.
        losses = nn.ModuleDict(losses)
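        # nn.ModuleDict registers each loss as a submodule, so the losses move
        # with .to(device) and are included in state_dict() (see the short
        # demonstration after this example).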

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')
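        # `metric_scheduler` marks schedulers that need a metric at step time:
        # ReduceLROnPlateau is called as scheduler.step(metric), while ordinary
        # epoch-based schedulers are called as scheduler.step().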

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(len(val_loader.dataset) // (args.max_images * args.batch_size))
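        # Hypothetical numbers: with 800 validation slices, batch_size=4 and
        # max_images=10, images go to TensorBoard every 800 // (10 * 4) = 20
        # steps, i.e. roughly 10 times per epoch.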

        self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_transform = output_transform
        self.losses = losses
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.use_slice_metrics = args.use_slice_metrics
        self.writer = SummaryWriter(str(args.log_path))
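
# A minimal, self-contained sketch (not from the original code) of why `losses`
# is wrapped in nn.ModuleDict above: registered losses are tracked as
# submodules. The toy tensors below are illustrative only.
import torch
from torch import nn

demo_losses = nn.ModuleDict(dict(img_loss=nn.L1Loss(), cmg_loss=nn.MSELoss()))
print(list(demo_losses.keys()))  # ['img_loss', 'cmg_loss']
print(demo_losses['img_loss'](torch.zeros(2, 3), torch.ones(2, 3)).item())  # 1.0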
Example #2
def train_img_to_rss(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get('center_fractions_train',
                                                arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val',
                                              arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))
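    # Hypothetical example: an older JSON config that only defines
    # 'center_fractions' and 'accelerations' still works, since the *_train
    # and *_val variants fall back to those shared values.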

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)
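    # Both mask functions follow the fastMRI convention: a fully sampled
    # central region of k-space (center_fractions) plus subsampled outer
    # columns chosen to reach the target acceleration (accelerations).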

    input_train_transform = PreProcessRSS(mask_func=train_mask_func,
                                          challenge=args.challenge,
                                          device=device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          fat_info=args.fat_info)
    input_val_transform = PreProcessRSS(mask_func=val_mask_func,
                                        challenge=args.challenge,
                                        device=device,
                                        augment_data=False,
                                        use_seed=True,
                                        fat_info=args.fat_info)

    output_train_transform = PostProcessRSS(challenge=args.challenge,
                                            residual_rss=args.residual_rss)
    output_val_transform = PostProcessRSS(challenge=args.challenge,
                                          residual_rss=args.residual_rss)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(rss_loss=SSIMLoss(filter_size=7).to(device=device))

    in_chans = 16 if args.fat_info else 15
    model = UNet(in_chans=in_chans,
                 out_chans=1,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_res_groups=args.num_res_groups,
                 num_res_blocks_per_group=args.num_res_blocks_per_group,
                 growth_rate=args.growth_rate,
                 num_dense_layers=args.num_dense_layers,
                 use_dense_ca=args.use_dense_ca,
                 num_res_layers=args.num_res_layers,
                 res_scale=args.res_scale,
                 reduction=args.reduction,
                 thick_base=args.thick_base).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=args.milestones,
                                               gamma=args.lr_red_rate)
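    # MultiStepLR multiplies the LR by gamma at each milestone epoch. As a
    # purely illustrative example: init_lr=1e-4, milestones=[30, 40] and
    # lr_red_rate=0.1 give 1e-4 until epoch 30, then 1e-5, then 1e-6.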
    trainer = ModelTrainerRSS(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model_concat()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
Example #3
def train_cmg_to_img_direct(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # # UNET architecture requires that all inputs be divisible by some power of 2.
    # divisor = 2 ** args.num_pool_layers

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get('center_fractions_train',
                                                arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val',
                                              arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessCMGIMG(mask_func=train_mask_func,
                                             challenge=args.challenge,
                                             device=device,
                                             augment_data=args.augment_data,
                                             use_seed=False,
                                             crop_center=args.crop_center)
    input_val_transform = PreProcessCMGIMG(mask_func=val_mask_func,
                                           challenge=args.challenge,
                                           device=device,
                                           augment_data=False,
                                           use_seed=True,
                                           crop_center=args.crop_center)

    output_train_transform = PostProcessCMGIMG(challenge=args.challenge,
                                               output_mode='img')
    output_val_transform = PostProcessCMGIMG(challenge=args.challenge,
                                             output_mode='img')

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss()
        # img_loss=L1SSIMLoss(filter_size=7, l1_ratio=args.l1_ratio).to(device=device)
    )

    # data_chans = 1 if args.challenge == 'singlecoil' else 15
    model = UNet(in_chans=45,
                 out_chans=15,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 res_scale=args.res_scale,
                 use_residual=False,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerI2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
Example #4
def train_complex(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # UNET architecture requires that all inputs be divisible by some power of 2.
    divisor = 2**args.num_pool_layers
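    # e.g. num_pool_layers=4 gives divisor=16: four rounds of 2x pooling need
    # both spatial dimensions to be multiples of 16.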

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func,
                                              weight_func,
                                              args.challenge,
                                              device,
                                              use_seed=False,
                                              divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func,
                                            weight_func,
                                            args.challenge,
                                            device,
                                            use_seed=True,
                                            divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(
            weighted=True, replace=args.replace)

    elif args.train_method == 'WK2C':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type,
                                           y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func,
                                             weight_func,
                                             args.challenge,
                                             device,
                                             use_seed=False,
                                             divisor=divisor)
        input_val_transform = PreProcessWK(mask_func,
                                           weight_func,
                                           args.challenge,
                                           device,
                                           use_seed=True,
                                           divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True,
                                                       replace=args.replace)
    else:
        raise NotImplementedError('Invalid train method!')

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    # model = UNetModel(
    #     in_chans=data_chans, out_chans=data_chans, chans=args.chans, num_pool_layers=args.num_pool_layers,
    #     num_groups=args.num_groups, use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
    #     use_sa=args.use_sa, sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation, use_cap=args.use_cap,
    #     use_cmp=args.use_cmp).to(device)

    model = UNetModelKSSE(in_chans=data_chans,
                          out_chans=data_chans,
                          chans=args.chans,
                          num_pool_layers=args.num_pool_layers,
                          num_groups=args.num_groups,
                          use_residual=args.use_residual,
                          pool_type=args.pool_type,
                          use_skip=args.use_skip,
                          min_ext_size=args.min_ext_size,
                          max_ext_size=args.max_ext_size,
                          ext_mode=args.ext_mode,
                          use_ca=args.use_ca,
                          reduction=args.reduction,
                          use_gap=args.use_gap,
                          use_gmp=args.use_gmp,
                          use_sa=args.use_sa,
                          sa_kernel_size=args.sa_kernel_size,
                          sa_dilation=args.sa_dilation,
                          use_cap=args.use_cap,
                          use_cmp=args.use_cmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(args, model, optimizer, train_loader,
                                  val_loader, input_train_transform,
                                  input_val_transform, output_transform,
                                  losses, scheduler)

    trainer.train_model()
Example #5
def train_cmg_to_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessCMG(mask_func,
                                          args.challenge,
                                          device,
                                          augment_data=args.augment_data,
                                          use_seed=False,
                                          crop_center=args.crop_center)
    input_val_transform = PreProcessCMG(mask_func,
                                        args.challenge,
                                        device,
                                        augment_data=False,
                                        use_seed=True,
                                        crop_center=args.crop_center)

    output_train_transform = PostProcessCMG(challenge=args.challenge,
                                            residual_acs=args.residual_acs)
    output_val_transform = PostProcessCMG(challenge=args.challenge,
                                          residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        img_loss=LogSSIMLoss(filter_size=7).to(device)
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=nn.L1Loss()
    )

    # model = UNet(
    #     in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
    #     negative_slope=args.negative_slope, use_residual=args.use_residual, interp_mode=args.interp_mode,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    model = UNet(in_chans=30,
                 out_chans=30,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 use_residual=args.use_residual,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerIMG(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform,
                              losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
Example #6
def main(argv):  # argv are non-flag parameters
    del argv
    start = time()
    tf.print('Tensorflow Engaged')
    run_number, run_name = initialize(FLAGS.ckpt_dir)

    log_path = Path(FLAGS.log_dir)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=False)
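    # exist_ok=False presumably makes a rerun with a duplicate run name fail
    # fast rather than silently mixing logs from two runs.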
    logger = get_logger(__name__)

    data_path = Path(FLAGS.data_dir)
    train_path = data_path / f'{FLAGS.challenge}_train'
    val_path = data_path / f'{FLAGS.challenge}_val'

    train_dataset = HDF5Sequence(data_dir=train_path,
                                 batch_size=FLAGS.batch_size,
                                 training=True,
                                 as_tensors=False)
    val_dataset = HDF5Sequence(data_dir=val_path,
                               batch_size=FLAGS.batch_size,
                               training=False,
                               as_tensors=False)

    model = make_unet_model(scope='UNET', input_shape=(320, 320, 1))
    # multi_model = tf.keras.utils.multi_gpu_model(model, gpus=2)
    multi_model = model

    tf.keras.utils.plot_model(model,
                              to_file=log_path / f'model_{run_number:02d}.png',
                              show_shapes=True)
    model.summary()

    ckpt_path = Path(FLAGS.ckpt_dir) / run_name
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / '{epoch:04d}.ckpt'
    ckpt_path = str(ckpt_path)

    optimizer = tf.keras.optimizers.Adam(lr=FLAGS.lr)

    checkpoint = ModelCheckpoint(filepath=ckpt_path,
                                 verbose=1,
                                 save_best_only=True)
    visualizer = TensorBoard(log_dir=str(log_path), batch_size=FLAGS.batch_size)

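    # model.to_json() returns a JSON string; parsing it back gives a plain
    # dict that save_dict_as_json can write alongside the FLAGS.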
    model_config = json.loads(model.to_json())
    # Save FLAGS to a json file next to the tensorboard data
    save_dict_as_json(dict_data=FLAGS.flag_values_dict(),
                      log_dir=str(log_path),
                      save_name=f'{run_name}_FLAGS')
    save_dict_as_json(dict_data=model_config,
                      log_dir=str(log_path),
                      save_name=f'{run_name}_config')

    multi_model.compile(
        optimizer=optimizer,
        loss='mae',
        metrics=[batch_ssim, batch_psnr, batch_msssim, batch_nmse])
    multi_model.fit_generator(train_dataset,
                              epochs=FLAGS.num_epochs,
                              verbose=1,
                              callbacks=[checkpoint, visualizer],
                              validation_data=val_dataset,
                              workers=4,
                              use_multiprocessing=True)

    finish = int(time() - start)
    logger.info(
        f'Finished Training model. Time: {finish // 3600:02d}hrs {(finish // 60) % 60}min {finish % 60}s'
    )
    return 0
Example #7
def main():

    batch_size = 8
    num_workers = 4
    init_lr = 2E-4
    gpu = 1  # Set to None for CPU mode.
    num_epochs = 500
    verbose = False
    save_best_only = True

    data_root = '/home/veritas/PycharmProjects/PA1/data'
    log_root = '/home/veritas/PycharmProjects/PA1/logs'
    ckpt_root = '/home/veritas/PycharmProjects/PA1/checkpoints'

    ckpt_path = Path(ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')

    # Do more fancy transforms later.
    transform = torchvision.transforms.ToTensor()

    train_dataset = torchvision.datasets.CIFAR100(data_root,
                                                  train=True,
                                                  transform=transform,
                                                  download=True)
    val_dataset = torchvision.datasets.CIFAR100(data_root,
                                                train=False,
                                                transform=transform,
                                                download=True)

    train_loader = DataLoader(train_dataset,
                              batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True)

    # Define model, optimizer, etc.
    model = torchvision.models.resnet50(pretrained=False,
                                        num_classes=100).to(device,
                                                            non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = CrossEntropyLoss().to(device, non_blocking=True)

    # Accumulators for epoch-level loss and accuracy metrics.
    train_loss_sum = torch.as_tensor(0.)
    val_loss_sum = torch.as_tensor(0.)

    train_top1_correct = torch.as_tensor(0)
    val_top1_correct = torch.as_tensor(0)

    train_top5_correct = torch.as_tensor(0)
    val_top5_correct = torch.as_tensor(0)

    previous_best = 0.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, num_epochs + 1):
        tic = time()
        model.train()
        torch.autograd.set_grad_enabled(True)
        for idx, (images, labels) in enumerate(train_loader, start=1):

            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            preds = model(images)
            # Pytorch uses (input, target) ordering. Their shapes are also different.
            step_loss = loss_func(preds, labels)
            step_loss.backward()
            optimizer.step()

            # no_grad keeps the metric bookkeeping out of the autograd graph,
            # avoiding memory growth from gradient accumulation.
            with torch.no_grad():
                train_loss_sum += step_loss
                top1 = torch.argmax(preds, dim=1)
                _, top5 = torch.topk(preds, k=5)  # Get only the indices.
                train_top1_correct += torch.sum(torch.eq(top1, labels))
                # Broadcasting (N, 5) indices against (N, 1) labels marks at most
                # one match per row, so the sum counts correct top-5 predictions.
                train_top5_correct += torch.sum(torch.eq(top5, labels.unsqueeze(-1)))

                if verbose:
                    print(f'Training loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')

        else:
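            # A for-else clause: this block runs once the loop above completes
            # normally, making it a natural place for end-of-epoch bookkeeping.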
            toc = int(time() - tic)
            # Last step with small batch causes some inaccuracy but that is tolerable.
            epoch_loss = train_loss_sum.item() * batch_size / len(
                train_loader.dataset)
            epoch_top1_acc = train_top1_correct.item() / len(
                train_loader.dataset) * 100
            epoch_top5_acc = train_top5_correct.item() / len(
                train_loader.dataset) * 100

            msg = f'loss: {epoch_loss:.4f}, top1 accuracy: {epoch_top1_acc:.2f}%, top5 accuracy: {epoch_top5_acc:.2f}%'
            logger.info(f'Epoch {epoch:03d} Training. {msg} Time: {toc}s')

            # writer.add_scalar('train_loss', epoch_loss, global_step=epoch)
            # writer.add_scalar('train_acc', epoch_top1_acc, global_step=epoch)

            # Reset to 0. There must be a better way though...
            train_loss_sum = torch.as_tensor(0.)
            train_top1_correct = torch.as_tensor(0)
            train_top5_correct = torch.as_tensor(0)

        tic = time()
        model.eval()
        torch.autograd.set_grad_enabled(False)
        for idx, (images, labels) in enumerate(val_loader, start=1):

            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            preds = model(images)

            top1 = torch.argmax(preds, dim=1)
            _, top5 = torch.topk(preds, k=5)

            val_top1_correct += torch.sum(torch.eq(top1, labels))
            val_top5_correct += torch.sum(torch.eq(top5, labels.unsqueeze(-1)))

            step_loss = loss_func(preds, labels)
            val_loss_sum += step_loss

            if verbose:
                print(f'Validation loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')

        else:
            toc = int(time() - tic)
            epoch_loss = val_loss_sum.item() * batch_size / len(
                val_loader.dataset)
            epoch_top1_acc = val_top1_correct.item() / len(
                val_loader.dataset) * 100
            epoch_top5_acc = val_top5_correct.item() / len(
                val_loader.dataset) * 100

            msg = f'loss: {epoch_loss:.4f}, top1 accuracy: {epoch_top1_acc:.2f}%, top5 accuracy: {epoch_top5_acc:.2f}%'
            logger.info(f'Epoch {epoch:03d} Validation. {msg} Time: {toc}s')

            # writer.add_scalar('val_loss', epoch_loss, global_step=epoch)
            # writer.add_scalar('val_acc', epoch_top1_acc, global_step=epoch)

            # Reset to 0. There must be a better way though...
            val_loss_sum = torch.as_tensor(0.)
            val_top1_correct = torch.as_tensor(0)
            val_top5_correct = torch.as_tensor(0)

        # Checkpoint generation. Only implemented for single GPU models, not multi-gpu models.
        # All comparisons are done with python numbers, not tensors.
        if epoch_top5_acc > previous_best:  # Assumes larger metric is better.
            logger.info(
                f'Top 5 Validation Accuracy in Epoch {epoch} has improved from '
                f'{previous_best:.2f}% to {epoch_top5_acc:.2f}%')
            previous_best = epoch_top5_acc

            save_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }
            torch.save(save_dict, ckpt_path / f'epoch_{epoch:03d}.tar')

        else:
            logger.info(
                f'Top 5 Validation Accuracy in Epoch {epoch} has not improved from the previous best epoch'
            )

            if not save_best_only:
                save_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }
                torch.save(save_dict, ckpt_path / f'epoch_{epoch:03d}.tar')
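
# A self-contained check (not from the original code) of the top-1/top-5
# counting used in the loop above; the logits and labels are toy values.
import torch

demo_preds = torch.tensor([[0.1, 0.9, 0.0, 0.3, 0.2, 0.4],
                           [0.8, 0.1, 0.05, 0.02, 0.02, 0.01]])
demo_labels = torch.tensor([1, 5])
demo_top1 = torch.argmax(demo_preds, dim=1)
_, demo_top5 = torch.topk(demo_preds, k=5)
print(torch.sum(torch.eq(demo_top1, demo_labels)).item())  # 1: row 0 is a top-1 hit.
print(torch.sum(torch.eq(demo_top5, demo_labels.unsqueeze(-1))).item())  # 1: row 1 misses even top-5.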
Example #8
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 input_train_transform,
                 input_val_transform,
                 output_train_transform,
                 output_val_transform,
                 losses,
                 scheduler=None):

        # Allow multiple processes to access tensors on the GPU ('spawn' is required for CUDA tensors).
        # Checking the current start method first prevents an error when successive runs share one session.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_train_transform, nn.Module) and isinstance(output_val_transform, nn.Module), \
            '`output_train_transform` and `output_val_transform` must be Pytorch Modules.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with a tuple as its output.
        losses = nn.ModuleDict(losses)

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(val_loader.dataset) // (args.max_images * args.batch_size))

        self.manager = CheckpointManager(model,
                                         optimizer,
                                         mode='min',
                                         save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path,
                                         max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt,
                              load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_train_transform = output_train_transform
        self.output_val_transform = output_val_transform
        self.losses = losses
        self.scheduler = scheduler
        self.writer = SummaryWriter(str(args.log_path))

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.shrink_scale = args.shrink_scale
        self.use_slice_metrics = args.use_slice_metrics

        # This should compute SSIM itself, not the 1 - SSIM loss.
        # Instantiated here so that the SSIM kernel is cached on the device.
        self.ssim = SSIM(filter_size=7).to(device=args.device)

        # Logging all components of the Model Trainer.
        # Train and Val input and output transforms are assumed to use the same input transform class.
        self.logger.info(f'''
        Summary of Model Trainer Components:
        Model: {get_class_name(model)}.
        Optimizer: {get_class_name(optimizer)}.
        Input Transforms: {get_class_name(input_val_transform)}.
        Output Transform: {get_class_name(output_val_transform)}.
        Image Domain Loss: {get_class_name(losses['img_loss'])}.
        Learning-Rate Scheduler: {get_class_name(scheduler)}.
        ''')  # NOTE: this summary block differs between the IMG and CMG trainers.
Example #9
def train_cmg_and_img(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:  # Same as in the challenge
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                               weight_func=no_weight,
                                               challenge=args.challenge,
                                               device=device,
                                               use_seed=False)
    input_val_transform = PreProcessWSemiKCC(mask_func=mask_func,
                                             weight_func=no_weight,
                                             challenge=args.challenge,
                                             device=device,
                                             use_seed=True)

    output_train_transform = PostProcessWSemiKCC(
        args.challenge, weighted=False, residual_acs=args.residual_acs)
    output_val_transform = PostProcessWSemiKCC(args.challenge,
                                               weighted=False,
                                               residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        cmg_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss())

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNet(in_chans=data_chans,
                 out_chans=data_chans,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 num_groups=args.num_groups,
                 negative_slope=args.negative_slope,
                 use_residual=args.use_residual,
                 interp_mode=args.interp_mode,
                 use_ca=args.use_ca,
                 reduction=args.reduction,
                 use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerCI(args, model, optimizer, train_loader, val_loader,
                             input_train_transform, input_val_transform,
                             output_train_transform, output_val_transform,
                             losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing TensorBoard writer and flushing remaining '
                       'outputs due to KeyboardInterrupt.')
Example #10
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 post_processing,
                 loss_func,
                 metrics=None,
                 scheduler=None):
        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        # I think this would be best practice.
        assert isinstance(post_processing, nn.Module), '`post_processing` must be a Pytorch Module.'

        # This is not a mistake. Pytorch implements loss functions as modules.
        assert isinstance(loss_func, nn.Module), '`loss_func` must be a callable Pytorch Module.'

        if metrics is not None:
            assert isinstance(metrics, (list, tuple)), '`metrics` must be a list or tuple.'
            for metric in metrics:
                assert callable(metric), 'All metrics must be callable functions.'

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.post_processing_func = post_processing
        self.loss_func = loss_func  # I don't think it is necessary to send loss_func or metrics to device.
        self.metrics = metrics
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(logdir=str(args.log_path))

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images > 0:
            self.display_interval = int(len(self.val_loader.dataset) // args.max_images)
        else:
            self.display_interval = 0

        # Writing model graph to TensorBoard. Results might not be very good.
        if args.add_graph:
            num_chans = 30 if args.challenge == 'multicoil' else 2
            example_inputs = torch.ones(size=(1, num_chans, 640, 328),
                                        device=args.device)
            self.writer.add_graph(model=model, input_to_model=example_inputs)
            del example_inputs  # Remove unnecessary tensor taking up memory.

        self.checkpointer = CheckpointManager(
            model=self.model,
            optimizer=self.optimizer,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path,
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt,
                                   load_optimizer=False)
Example #11
def train_model(model, args):

    assert isinstance(model, nn.Module)

    # Beginning session.
    run_number, run_name = initialize(args.ckpt_root)

    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    # Saving args for later use.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    dataset_kwargs = dict(root=args.data_root, download=True)
    train_dataset = torchvision.datasets.CIFAR100(train=True,
                                                  transform=train_transform(),
                                                  **dataset_kwargs)
    val_dataset = torchvision.datasets.CIFAR100(train=False,
                                                transform=val_transform(),
                                                **dataset_kwargs)

    loader_kwargs = dict(batch_size=args.batch_size,
                         num_workers=args.num_workers,
                         pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Define model, optimizer, etc.
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss(reduction='mean').to(device)

    # LR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=args.gamma)

    # Create checkpoint manager
    checkpointer = CheckpointManager(model,
                                     optimizer,
                                     mode='max',
                                     save_best_only=args.save_best_only,
                                     ckpt_dir=ckpt_path,
                                     max_to_keep=args.max_to_keep)

    # Tensorboard Writer
    writer = SummaryWriter(log_dir=str(log_path))

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, args.num_epochs + 1):
        # Start of training
        tic = time()
        train_loss_sum, train_top1_correct, train_top5_correct = \
            train_epoch(model, optimizer, loss_func, train_loader, device, epoch, args.verbose)

        toc = int(time() - tic)
        # Last step with small batch causes some inaccuracy but that is tolerable.
        train_epoch_loss = train_loss_sum.item() * args.batch_size / len(
            train_loader.dataset)
        train_epoch_top1_acc = train_top1_correct.item() / len(
            train_loader.dataset) * 100
        train_epoch_top5_acc = train_top5_correct.item() / len(
            train_loader.dataset) * 100

        logger.info(
            f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4e}, top1 accuracy: '
            f'{train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s'
        )

        # Writing to Tensorboard
        writer.add_scalar('train_epoch_loss', train_epoch_loss, epoch)
        writer.add_scalar('train_epoch_top1_acc', train_epoch_top1_acc, epoch)
        writer.add_scalar('train_epoch_top5_acc', train_epoch_top5_acc, epoch)

        # Start of evaluation
        tic = time()
        val_loss_sum, val_top1_correct, val_top5_correct = \
            eval_epoch(model, loss_func, val_loader, device, epoch, args.verbose)

        toc = int(time() - tic)
        val_epoch_loss = val_loss_sum.item() * args.batch_size / len(
            val_loader.dataset)
        val_epoch_top1_acc = val_top1_correct.item() / len(
            val_loader.dataset) * 100
        val_epoch_top5_acc = val_top5_correct.item() / len(
            val_loader.dataset) * 100

        logger.info(
            f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4e}, top1 accuracy: '
            f'{val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}%  Time: {toc}s'
        )

        # Writing to Tensorboard
        writer.add_scalar('val_epoch_loss', val_epoch_loss, epoch)
        writer.add_scalar('val_epoch_top1_acc', val_epoch_top1_acc, epoch)
        writer.add_scalar('val_epoch_top5_acc', val_epoch_top5_acc, epoch)
        for idx, group in enumerate(optimizer.param_groups, start=1):
            writer.add_scalar(f'learning_rate_{idx}', group['lr'], epoch)

        # Things to do after each epoch.
        # Reduces the LR at the designated epochs. The scheduler presumably
        # counts epochs from 0, unlike the 1-based loop above.
        scheduler.step()
        checkpointer.save(metric=val_epoch_top5_acc)
Example #12
def main():

    # Put these in args later.
    batch_size = 12
    num_workers = 8
    init_lr = 2E-4
    gpu = 0  # Set to None for CPU mode.
    num_epochs = 500
    verbose = False
    save_best_only = True
    max_to_keep = 100
    data_root = '/home/veritas/PycharmProjects/PA1/data'
    ckpt_root = '/home/veritas/PycharmProjects/PA1/checkpoints'
    log_root = '/home/veritas/PycharmProjects/PA1/logs'

    # Beginning session.
    run_number, run_name = initialize(ckpt_root)

    ckpt_path = Path(ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')

    # Do more fancy transforms later.
    transform = torchvision.transforms.ToTensor()

    train_dataset = torchvision.datasets.CIFAR100(data_root,
                                                  train=True,
                                                  transform=transform,
                                                  download=True)
    val_dataset = torchvision.datasets.CIFAR100(data_root,
                                                train=False,
                                                transform=transform,
                                                download=True)

    train_loader = DataLoader(train_dataset,
                              batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=True)

    # Define model, optimizer, etc.
    model = se_resnet50_cifar100().to(device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss().to(device, non_blocking=True)

    # Create checkpoint manager
    checkpointer = CheckpointManager(model, optimizer, ckpt_path,
                                     save_best_only, max_to_keep)

    # For recording data.
    previous_best = 0.  # Accuracy should improve.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, num_epochs + 1):
        # Start of training
        tic = time()
        train_loss_sum, train_top1_correct, train_top5_correct = \
            train_epoch(model, optimizer, loss_func, train_loader, device, epoch, verbose)

        toc = int(time() - tic)
        # Last step with small batch causes some inaccuracy but that is tolerable.
        train_epoch_loss = train_loss_sum.item() * batch_size / len(
            train_loader.dataset)
        train_epoch_top1_acc = train_top1_correct.item() / len(
            train_loader.dataset) * 100
        train_epoch_top5_acc = train_top5_correct.item() / len(
            train_loader.dataset) * 100

        msg = f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4f}, ' \
            f'top1 accuracy: {train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s'
        logger.info(msg)

        # Start of evaluation
        tic = time()
        val_loss_sum, val_top1_correct, val_top5_correct = \
            eval_epoch(model, loss_func, val_loader, device, epoch, verbose)

        toc = int(time() - tic)
        val_epoch_loss = val_loss_sum.item() * batch_size / len(
            val_loader.dataset)
        val_epoch_top1_acc = val_top1_correct.item() / len(
            val_loader.dataset) * 100
        val_epoch_top5_acc = val_top5_correct.item() / len(
            val_loader.dataset) * 100

        msg = f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4f}, ' \
            f'top1 accuracy: {val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}%  Time: {toc}s'
        logger.info(msg)

        if val_epoch_top5_acc > previous_best:  # Assumes larger metric is better.
            logger.info(
                f'Top 5 Validation Accuracy in Epoch {epoch} has improved from '
                f'{previous_best:.2f}% to {val_epoch_top5_acc:.2f}%')
            previous_best = val_epoch_top5_acc
            checkpointer.save(is_best=True)

        else:
            logger.info(
                f'Top 5 Validation Accuracy in Epoch {epoch} has not improved from the previous best epoch'
            )
            checkpointer.save(is_best=False)
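
The epoch averages above multiply the summed step losses by a constant batch size, so the final, smaller batch is weighted slightly wrong (as the comment concedes). A minimal sketch of an exact evaluation-epoch average that weights each step by its true batch size; the helper name is illustrative, not from the original:

import torch

def exact_epoch_loss(model, loss_func, loader, device):
    """Average loss over a dataset, weighting each step by its actual batch size."""
    loss_sum, n_samples = 0.0, 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            loss = loss_func(model(inputs), targets)
            loss_sum += loss.item() * inputs.size(0)  # Undo the mean over the batch.
            n_samples += inputs.size(0)
    return loss_sum / n_samples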
Example No. 13
def main(argv):  # argv are non-flag parameters
    del argv
    start = time()
    tf.print('TensorFlow engaged')
    run_number, run_name = initialize(FLAGS.ckpt_dir)

    log_path = Path(FLAGS.log_dir)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=False)
    logger = get_logger(__name__)

    data_path = Path(FLAGS.data_dir)
    train_path = data_path / f'{FLAGS.challenge}_train'
    val_path = data_path / f'{FLAGS.challenge}_val'

    normalize = False

    def loss_func(labels, predictions):
        # max_val=15.0 is a heuristic estimate of this data's dynamic range.
        return 1 - tf.image.ssim(labels, predictions, max_val=15.0)

    train_dataset = HDF5Sequence(data_dir=train_path,
                                 batch_size=FLAGS.batch_size,
                                 training=True,
                                 normalize=normalize)
    val_dataset = HDF5Sequence(data_dir=val_path,
                               batch_size=FLAGS.batch_size,
                               training=False,
                               normalize=normalize)

    model = make_unet_model(scope='UNET', input_shape=(320, 320, 1))
    tf.keras.utils.plot_model(model,
                              to_file=str(log_path / f'model_{run_number:02d}.png'),
                              show_shapes=True)
    model.summary()

    ckpt_path = Path(FLAGS.ckpt_dir) / run_name
    ckpt_path.mkdir(exist_ok=True)

    optimizer = tf.train.AdamOptimizer(
        learning_rate=FLAGS.lr)  # Can't change lr easily until TF2.0 update...
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

    if FLAGS.restore_dir:
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.restore_dir)
        status = checkpoint.restore(latest_checkpoint)
        status.assert_nontrivial_match().assert_existing_objects_matched()
        print(f'Restored model from {latest_checkpoint}')

    manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                         directory=str(ckpt_path),
                                         max_to_keep=FLAGS.max_to_keep)
    writer = tf.contrib.summary.create_file_writer(
        logdir=str(log_path))  # Graph display not possible in eager...

    model_config = json.loads(model.to_json())
    # Save FLAGS to a json file next to the tensorboard data
    save_dict_as_json(dict_data=FLAGS.flag_values_dict(),
                      log_dir=str(log_path),
                      save_name=f'{run_name}_FLAGS')
    save_dict_as_json(dict_data=model_config,
                      log_dir=str(log_path),
                      save_name=f'{run_name}_config')

    with writer.as_default(), tf.contrib.summary.always_record_summaries():
        train_and_eval(model=model,
                       optimizer=optimizer,
                       manager=manager,
                       train_dataset=train_dataset,
                       val_dataset=val_dataset,
                       num_epochs=FLAGS.num_epochs,
                       loss_func=loss_func,
                       save_best_only=FLAGS.save_best_only,
                       use_train_metrics=True,
                       use_val_metrics=True,
                       verbose=FLAGS.verbose,
                       max_images=FLAGS.max_images)

    finish = int(time() - start)
    logger.info(
        f'Finished Training model. Time: {finish // 3600:02d}hrs {(finish // 60) % 60}min {finish % 60}s'
    )
    return 0
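
A quick standalone check of the SSIM loss defined above, assuming TF 1.x eager execution as used in this example; the shapes and values are illustrative only:

import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x eager mode, as in this example.

labels = tf.random.uniform(shape=(4, 320, 320, 1), maxval=15.0)
predictions = tf.random.uniform(shape=(4, 320, 320, 1), maxval=15.0)

# tf.image.ssim returns one similarity score per batch element in [-1, 1],
# so 1 - SSIM is a per-image loss that reaches 0 for identical images.
loss = 1 - tf.image.ssim(labels, predictions, max_val=15.0)
print(float(tf.reduce_mean(loss)))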
Example No. 14
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 post_processing,
                 c_loss,
                 metrics=None,
                 scheduler=None):

        # 'spawn' is needed so worker processes can access CUDA tensors.
        # set_start_method raises RuntimeError if the context was already set (e.g. on repeated runs).
        try:
            multiprocessing.set_start_method(method='spawn')
        except RuntimeError:
            pass

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model,
                          nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(
            optimizer,
            optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        # I think this would be best practice.
        assert isinstance(
            post_processing,
            nn.Module), '`post_processing` must be a Pytorch Module.'

        # This is not a mistake. Pytorch implements loss functions as modules.
        assert isinstance(
            c_loss, nn.Module), '`c_loss` must be a callable Pytorch Module.'

        if metrics is not None:
            assert isinstance(
                metrics, Iterable
            ), '`metrics` must be an iterable, preferably a list or tuple.'
            for metric in metrics:
                assert callable(
                    metric), 'All metrics must be callable functions.'

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.post_processing_func = post_processing
        self.c_loss_func = c_loss
        self.metrics = metrics
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(logdir=str(args.log_path))

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(self.val_loader.dataset) //
                (args.max_images * args.batch_size))

        self.checkpointer = CheckpointManager(
            model=self.model,
            optimizer=self.optimizer,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path,
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt,
                                   load_optimizer=False)
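
The self.metric_scheduler flag above only records the scheduler type. A sketch of how it might be consumed at the end of each epoch; the function name is illustrative, not from the original trainer:

def step_scheduler(scheduler, metric_scheduler, val_metric):
    """Step a scheduler, passing the metric only when the scheduler requires one."""
    if scheduler is None:
        return
    if metric_scheduler:
        scheduler.step(val_metric)  # ReduceLROnPlateau keys off a validation metric.
    else:
        scheduler.step()  # StepLR, MultiStepLR, etc. step unconditionally.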
Example No. 15
def main(args):

    ckpt_path = Path('checkpoints')
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    args.ckpt_path = ckpt_path

    log_path = Path('logs')
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    args.log_path = log_path

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assigning new attributes to args at runtime works fine.
    if (args.gpu is not None) and torch.cuda.is_available():
        args.device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        args.device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Create Datasets
    train_dataset = TrainingDataset(args.train_root,
                                    transform=slice_normalize_and_clip,
                                    single_coil=False)
    val_dataset = TrainingDataset(  # Halve validation time; the fixed seed keeps comparisons consistent.
        args.val_root,
        transform=slice_normalize_and_clip,
        single_coil=False,
        use_double=False,
        seed=9872)

    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True)

    # Define model, optimizer, and loss function.
    model = UnetModel(in_chans=1, out_chans=1, chans=32,
                      num_pool_layers=4).to(args.device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    loss_func = nn.L1Loss(reduction='mean').to(args.device, non_blocking=True)

    checkpointer = CheckpointManager(model, optimizer, mode='min',
                                     save_best_only=args.save_best_only,
                                     ckpt_dir=ckpt_path,
                                     max_to_keep=args.max_to_keep)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     factor=0.1,
                                                     patience=5,
                                                     verbose=True,
                                                     cooldown=0,
                                                     min_lr=1E-7)

    writer = SummaryWriter(log_dir=str(log_path))

    example_inputs = torch.ones(size=(args.batch_size, 1, 320,
                                      320)).to(args.device)
    writer.add_graph(model=model, input_to_model=example_inputs, verbose=False)

    previous_best = np.inf

    # Training loop. Please excuse my use of 1-based indexing here.
    logger.info('Beginning Training loop')

    for epoch in range(1, args.num_epochs + 1):
        tic = time()
        model.train()
        torch.autograd.set_grad_enabled(True)
        logger.info(f'Beginning training for Epoch {epoch:03d}')

        # Reset to 0. There must be a better way though...
        train_loss_sum = torch.as_tensor(0.)

        tot_len = len(train_loader.dataset) // args.batch_size
        for idx, (ds_imgs, labels) in tqdm(enumerate(train_loader, start=1),
                                           total=tot_len):

            ds_imgs = ds_imgs.to(args.device).unsqueeze(dim=1)
            labels = labels.to(args.device).unsqueeze(dim=1)

            optimizer.zero_grad()
            pred_residuals = model(ds_imgs)  # Predicted residuals are outputs.

            recons = pred_residuals + ds_imgs

            step_loss = loss_func(recons, labels)
            step_loss.backward()
            optimizer.step()

            # Detach from the graph so accumulating the metric does not accumulate gradients.
            with torch.no_grad():
                train_loss_sum += step_loss

                if args.verbose:
                    print(
                        f'Training loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}'
                    )

        else:  # for-else: runs after the loop finishes without a break.
            toc = int(time() - tic)
            # Last step with small batch causes some inaccuracy but that is tolerable.
            epoch_loss = train_loss_sum.item() * args.batch_size / len(
                train_loader.dataset)

            logger.info(
                f'Epoch {epoch:03d} Training. loss: {float(epoch_loss):.4f}, Time: {toc // 60}min {toc % 60}s'
            )

            writer.add_scalar('train_loss', epoch_loss, global_step=epoch)

        tic = time()
        model.eval()
        torch.autograd.set_grad_enabled(False)
        logger.info(f'Beginning validation for Epoch {epoch:03d}')

        # Reset to 0. There must be a better way though...
        val_loss_sum = torch.as_tensor(0.)

        # Need to return residual and labels for residual learning.
        tot_len = len(val_loader.dataset) // args.batch_size
        for idx, (ds_imgs, labels) in tqdm(enumerate(val_loader, start=1),
                                           total=tot_len):

            ds_imgs = ds_imgs.to(args.device).unsqueeze(dim=1)
            labels = labels.to(args.device).unsqueeze(dim=1)
            pred_residuals = model(ds_imgs)
            recons = pred_residuals + ds_imgs

            step_loss = loss_func(recons, labels)
            val_loss_sum += step_loss

            if args.verbose:
                print(
                    f'Validation loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}'
                )

            if args.max_imgs:
                pass  # TODO: Implement saving images to TensorBoard.

        else:  # for-else: runs after the loop finishes without a break.
            toc = int(time() - tic)
            epoch_loss = val_loss_sum.item() * args.batch_size / len(
                val_loader.dataset)

            logger.info(
                f'Epoch {epoch:03d} Validation. loss: {float(epoch_loss):.4f} Time: {toc // 60}min {toc % 60}s'
            )

            writer.add_scalar('val_loss', epoch_loss, global_step=epoch)

            scheduler.step(metrics=epoch_loss,
                           epoch=epoch)  # Changes optimizer lr.

        # Checkpoint generation. Only implemented for single GPU models, not multi-gpu models.
        # All comparisons are done with python numbers, not tensors.
        if epoch_loss < previous_best:  # Assumes smaller metric is better.
            logger.info(
                f'Loss in Epoch {epoch} has improved from {float(previous_best):.4f} to {float(epoch_loss):.4f}'
            )
            previous_best = epoch_loss
            checkpointer.save(is_best=True)

        else:
            logger.info(
                f'Loss in Epoch {epoch} has not improved from the previous best epoch'
            )
            checkpointer.save(
                is_best=False)  # No save if save_best_only is True.
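
The residual-learning pattern above (recons = pred_residuals + ds_imgs) appears in both loops. A small wrapper module can centralize it; the class name is illustrative, not from the original:

from torch import nn

class ResidualWrapper(nn.Module):
    """Make an image-to-image model predict a residual that is added back to its input."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return x + self.model(x)  # Reconstruction = input + predicted correction.

With such a wrapper, both loops reduce to recons = wrapped_model(ds_imgs).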
Example No. 16
def train_and_eval(model,
                   optimizer,
                   manager,
                   train_dataset,
                   val_dataset,
                   num_epochs,
                   loss_func,
                   save_best_only=True,
                   use_train_metrics=False,
                   use_val_metrics=True,
                   verbose=True,
                   max_images=36):

    logger = get_logger(__name__)

    prev_loss = float('inf')  # Any real loss will be smaller.

    for epoch in range(1, num_epochs + 1):
        # Training
        tic = time()
        logger.info(f'\nStarting Epoch {epoch:03d} Training')
        train_epoch_loss, train_epoch_metrics = \
            train_epoch(model=model, optimizer=optimizer, dataset=train_dataset, loss_func=loss_func,
                        epoch=epoch, use_metrics=use_train_metrics, verbose=verbose)

        # After Epoch training is over.
        toc = int(time() - tic)
        logger.info(
            f'Epoch {epoch:03d} Training Finished. Time: {toc // 60}min {toc % 60}s.'
        )

        tf.contrib.summary.scalar(name='train_epoch_loss',
                                  tensor=train_epoch_loss,
                                  step=epoch)
        logger.info(f'Epoch Training loss: {float(train_epoch_loss):.4f}')

        if use_train_metrics:
            logger.info('Epoch Training Metrics:')
            for idx, metric in enumerate(train_epoch_metrics, start=1):
                tf.contrib.summary.scalar(name=f'train_metric_{idx}',
                                          tensor=metric,
                                          step=epoch)
                logger.info(f'Train Metric {idx}: {float(metric):.4f}')

        # Validation
        tic = time()
        logger.info(f'\nStarting Epoch {epoch:03d} Validation')
        val_epoch_loss, val_epoch_metrics = \
            val_epoch(model=model, dataset=val_dataset, loss_func=loss_func, epoch=epoch,
                      max_images=max_images, verbose=verbose, use_metrics=use_val_metrics)

        # After Epoch validation is over.
        toc = int(time() - tic)

        tf.contrib.summary.scalar(name='val_epoch_loss',
                                  tensor=val_epoch_loss,
                                  step=epoch)
        logger.info(
            f'Epoch {epoch:03d} Validation Finished. Time: {toc // 60}min {toc % 60}s.'
        )
        logger.info(f'Epoch Validation loss: {float(val_epoch_loss):.4f}')

        if use_val_metrics:
            logger.info('Epoch Validation Metrics:')
            for idx, metric in enumerate(val_epoch_metrics, start=1):
                tf.contrib.summary.scalar(name=f'val_metric_{idx}',
                                          tensor=metric,
                                          step=epoch)
                logger.info(f'Validation Metric {idx}: {float(metric):.4f}')

        # Checkpoint the epoch if there has been improvement.
        # (Improvement tracking is unreliable when the loss function itself keeps changing.)
        if val_epoch_loss < prev_loss:
            prev_loss = val_epoch_loss
            logger.info('Validation loss has improved from previous epoch')
            logger.info(f'Last checkpoint file: {manager.latest_checkpoint}')
            manager.save(checkpoint_number=epoch)
        else:
            logger.info('Validation loss has not improved from previous epoch')
            logger.info(f'Previous minimum loss: {float(prev_loss):.4f}')
            if not save_best_only:
                manager.save(checkpoint_number=epoch)
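
The save-on-improvement branch above recurs throughout these examples. A minimal tracker that encapsulates the comparison, assuming lower is better; the class name is illustrative:

class BestLossTracker:
    """Track the lowest validation loss seen so far."""

    def __init__(self):
        self.best = float('inf')

    def update(self, loss):
        """Record and return True if `loss` improves on the best so far."""
        if loss < self.best:
            self.best = loss
            return True
        return False

The epoch end then becomes: if tracker.update(val_epoch_loss) or not save_best_only: manager.save(checkpoint_number=epoch).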
Example No. 17
def train_img(args):

    # Maybe move this to args later.
    train_method = 'K2CI'

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assigning new attributes to args at runtime works fine.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires input height and width to be divisible by 2**num_pool_layers.
    divisor = 2**args.num_pool_layers

    mask_func = MaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    input_train_transform = TrainPreProcessK(mask_func,
                                             args.challenge,
                                             args.device,
                                             use_seed=False,
                                             divisor=divisor)
    input_val_transform = TrainPreProcessK(mask_func,
                                           args.challenge,
                                           args.device,
                                           use_seed=True,
                                           divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(
        cmg_loss=nn.MSELoss(reduction='mean'),
        # img_loss=L1CSSIM7(reduction='mean', alpha=0.5)
        img_loss=CSSIM(filter_size=7, reduction='mean'))

    output_transform = OutputReplaceTransformK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UnetASE(in_chans=data_chans,
                    out_chans=data_chans,
                    ext_chans=args.chans,
                    chans=args.chans,
                    num_pool_layers=args.num_pool_layers,
                    min_ext_size=args.min_ext_size,
                    max_ext_size=args.max_ext_size,
                    use_ext_bias=args.use_ext_bias,
                    use_att=False).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)

    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.lr_red_epoch,
                                          gamma=args.lr_red_rate)

    trainer = ModelTrainerK2CI(args, model, optimizer, train_loader,
                               val_loader, input_train_transform,
                               input_val_transform, output_transform, losses,
                               scheduler)

    trainer.train_model()
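
The divisor = 2**args.num_pool_layers above encodes the constraint that U-Net inputs stay divisible by 2 at every pooling stage. A sketch of the padding step such a transform would perform; the function name is illustrative:

import torch.nn.functional as F

def pad_to_divisor(image, divisor):
    """Zero-pad a (..., H, W) tensor so H and W become multiples of `divisor`."""
    h, w = image.shape[-2:]
    pad_h = (divisor - h % divisor) % divisor
    pad_w = (divisor - w % divisor) % divisor
    # F.pad pads the last two dimensions as (left, right, top, bottom).
    return F.pad(image, (0, pad_w, 0, pad_h))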
Example No. 18
    def __init__(self,
                 args,
                 generator,
                 discriminator,
                 gen_optim,
                 disc_optim,
                 train_loader,
                 val_loader,
                 loss_funcs,
                 gen_scheduler=None,
                 disc_scheduler=None):

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(generator, nn.Module) and isinstance(discriminator, nn.Module), \
            '`generator` and `discriminator` must be Pytorch Modules.'
        assert isinstance(gen_optim, optim.Optimizer) and isinstance(disc_optim, optim.Optimizer), \
            '`gen_optim` and `disc_optim` must be Pytorch Optimizers.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        # Expected to be a dictionary mapping names to loss functions.
        loss_funcs = nn.ModuleDict(loss_funcs)

        if gen_scheduler is not None:
            if isinstance(gen_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_gen_scheduler = True
            elif isinstance(gen_scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_gen_scheduler = False
            else:
                raise TypeError(
                    '`gen_scheduler` must be a Pytorch Learning Rate Scheduler.'
                )

        if disc_scheduler is not None:
            if isinstance(disc_scheduler,
                          optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_disc_scheduler = True
            elif isinstance(disc_scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_disc_scheduler = False
            else:
                raise TypeError(
                    '`disc_scheduler` must be a Pytorch Learning Rate Scheduler.'
                )

        self.generator = generator
        self.discriminator = discriminator
        self.gen_optim = gen_optim
        self.disc_optim = disc_optim
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_funcs = loss_funcs
        self.gen_scheduler = gen_scheduler
        self.disc_scheduler = disc_scheduler
        self.device = args.device
        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(str(args.log_path))

        self.recon_lambda = torch.tensor(args.recon_lambda,
                                         dtype=torch.float32,
                                         device=args.device)
        self.lambda_gp = torch.tensor(args.lambda_gp,
                                      dtype=torch.float32,
                                      device=args.device)

        # Scalar targets that broadcast against the discriminator output.
        # This assumes batch size 1, as recommended; it may not generalize to larger batches.
        self.target_real = torch.tensor(1,
                                        dtype=torch.float32,
                                        device=args.device)
        self.target_fake = torch.tensor(0,
                                        dtype=torch.float32,
                                        device=args.device)

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(self.val_loader.dataset) //
                (args.max_images * args.batch_size))

        self.gen_checkpoint_manager = CheckpointManager(
            model=self.generator,
            optimizer=self.gen_optim,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path / 'Generator',
            max_to_keep=args.max_to_keep)

        self.disc_checkpoint_manager = CheckpointManager(
            model=self.discriminator,
            optimizer=self.disc_optim,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path / 'Discriminator',
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('gen_prev_model_ckpt'):
            self.gen_checkpoint_manager.load(load_dir=args.gen_prev_model_ckpt,
                                             load_optimizer=False)

        if vars(args).get('disc_prev_model_ckpt'):
            self.disc_checkpoint_manager.load(
                load_dir=args.disc_prev_model_ckpt, load_optimizer=False)
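
The trainer above stores self.lambda_gp, which points to a WGAN-GP style gradient penalty. A standard sketch of that term, not taken from the original code:

import torch

def gradient_penalty(discriminator, real, fake):
    """WGAN-GP: penalize the discriminator's gradient norm at interpolated points."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
    scores = discriminator(interp)
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

The result would be scaled by lambda_gp and added to the discriminator loss.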
Example No. 19
def train_k2i(args):

    # Maybe move this to args later.
    train_method = 'W2I'  # Weighted K-space to real-valued image.

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assigning new attributes to args at runtime works fine.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # The U-Net architecture requires input height and width to be divisible by 2**num_pool_layers.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # This is optimized for SSD storage.
    # Sending to device should be inside the input transform for optimal performance on HDD.
    data_prefetch = Prefetch2Device(device)

    input_train_transform = WeightedPreProcessK(mask_func,
                                                args.challenge,
                                                device,
                                                use_seed=False,
                                                divisor=divisor)
    input_val_transform = WeightedPreProcessK(mask_func,
                                              args.challenge,
                                              device,
                                              use_seed=True,
                                              divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(
        args, transform=data_prefetch)

    losses = dict(
        img_loss=nn.L1Loss(reduction='mean'),
        # img_loss=L1CSSIM7(reduction='mean', alpha=args.alpha)
    )

    output_transform = WeightedReplacePostProcessK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNetSkipGN(in_chans=data_chans,
                       out_chans=data_chans,
                       chans=args.chans,
                       num_pool_layers=args.num_pool_layers,
                       num_groups=args.num_groups,
                       pool_type=args.pool_type,
                       use_skip=args.use_skip,
                       use_att=args.use_att,
                       reduction=args.reduction,
                       use_gap=args.use_gap,
                       use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerK2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_transform, losses, scheduler)

    trainer.train_model()
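
MaskFunc and UniformMaskFunc come from the fastMRI-style data pipeline. As a rough sketch of what a random k-space column mask with a fully sampled center looks like (simplified; not the original implementation):

import numpy as np

def random_column_mask(num_cols, center_fraction, acceleration, rng=np.random):
    """1D column mask: fully sampled low-frequency center, random columns elsewhere."""
    num_center = int(round(num_cols * center_fraction))
    # Sampling probability for the remaining columns so the total rate is 1/acceleration.
    prob = (num_cols / acceleration - num_center) / (num_cols - num_center)
    mask = rng.uniform(size=num_cols) < prob
    pad = (num_cols - num_center) // 2
    mask[pad:pad + num_center] = True  # Always keep the center columns.
    return mask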
Example No. 20
def train_xnet(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    args.center_fractions_train = arguments.get(
        'center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get(
        'center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train',
                                             arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val',
                                           arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train,
                                         args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val,
                                       args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train,
                                          args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val,
                                        args.accelerations_val)

    input_train_transform = PreProcessXNet(mask_func=train_mask_func,
                                           challenge=args.challenge,
                                           device=device,
                                           augment_data=args.augment_data,
                                           use_seed=False,
                                           crop_center=args.crop_center)
    input_val_transform = PreProcessXNet(mask_func=val_mask_func,
                                         challenge=args.challenge,
                                         device=device,
                                         augment_data=False,
                                         use_seed=True,
                                         crop_center=args.crop_center)

    output_train_transform = PostProcessXNet(challenge=args.challenge)
    output_val_transform = PostProcessXNet(challenge=args.challenge)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        phase_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        img_loss=LogSSIMLoss(filter_size=5).to(device=device),
        # img_loss=nn.L1Loss()
        x_loss=AlignmentLoss())

    data_chans = 1 if args.challenge == 'singlecoil' else 15  # Multicoil data has 15 coils.

    model = XNet(in_chans=data_chans,
                 out_chans=data_chans,
                 chans=args.chans,
                 num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks,
                 dilation=args.dilation,
                 res_scale=args.res_scale).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = XNetModelTrainer(args, model, optimizer, train_loader,
                               val_loader, input_train_transform,
                               input_val_transform, output_train_transform,
                               output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.')
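
The handler above flushes TensorBoard events only on manual interrupts. A try/finally variant (a sketch, assuming the trainer and logger defined above) would close the writer on any exit path, including unexpected exceptions:

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        logger.warning('Training stopped by KeyboardInterrupt.')
    finally:
        trainer.writer.close()  # Flush remaining TensorBoard outputs on every exit.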
Example No. 21
def train_gan_model(args):
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_dir)
    ckpt_path.mkdir(exist_ok=True)

    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_dir)
    log_path.mkdir(exist_ok=True)

    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)

    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assigning new attributes to args at runtime works fine.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Save args now, before non-JSON-serializable objects (such as Path objects) are added below.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    train_transform = InputTrainTransform(is_training=True)
    val_transform = InputTrainTransform(is_training=False)

    # DataLoaders
    train_loader, val_loader = create_data_loaders(args, train_transform,
                                                   val_transform)

    # Loss Function and output post-processing functions.
    gan_loss_func = nn.BCELoss(reduction='mean')
    recon_loss_func = L1CSSIM(l1_weight=args.l1_weight,
                              default_range=12,
                              filter_size=7,
                              reduction='mean')
    # recon_loss_func = nn.L1Loss(reduction='mean')
    loss_funcs = {
        'gan_loss_func': gan_loss_func,
        'recon_loss_func': recon_loss_func
    }

    # Define model.
    generator = UnetModel(1, 1, args.chans, args.num_pool_layers,
                          drop_prob=0).to(device)
    discriminator = SimpleCNN(chans=args.chans).to(device)

    gen_optim = optim.Adam(params=generator.parameters(), lr=args.init_lr)
    disc_optim = optim.Adam(params=discriminator.parameters(), lr=args.init_lr)

    gen_scheduler = optim.lr_scheduler.StepLR(gen_optim,
                                              step_size=args.step_size,
                                              gamma=args.lr_reduction_rate)
    disc_scheduler = optim.lr_scheduler.StepLR(disc_optim,
                                               step_size=args.step_size,
                                               gamma=args.lr_reduction_rate)

    trainer = GANModelTrainer(args, generator, discriminator, gen_optim,
                              disc_optim, train_loader, val_loader, loss_funcs,
                              gen_scheduler, disc_scheduler)

    trainer.train_model()
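
Example No. 18's trainer keeps target_real and target_fake for exactly this kind of BCE-based GAN objective. A sketch of the discriminator loss they imply, assuming the discriminator ends in a sigmoid as nn.BCELoss requires:

import torch
from torch import nn

gan_loss_func = nn.BCELoss(reduction='mean')

def discriminator_loss(discriminator, real_imgs, fake_imgs):
    """Push D(real) toward 1 and D(fake) toward 0."""
    pred_real = discriminator(real_imgs)
    pred_fake = discriminator(fake_imgs.detach())  # No generator gradients here.
    loss_real = gan_loss_func(pred_real, torch.ones_like(pred_real))
    loss_fake = gan_loss_func(pred_fake, torch.zeros_like(pred_fake))
    return loss_real + loss_fake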