def __init__(self, args, model, optimizer, train_loader, val_loader, input_train_transform,
             input_val_transform, output_transform, losses, scheduler=None):
    """Set up a model trainer with a single output transform.

    Args:
        args: run-settings namespace; this constructor reads log_path, run_name,
            max_images, batch_size, save_best_only, ckpt_path, max_to_keep, verbose,
            num_epochs, smoothing_factor and use_slice_metrics, and optionally
            prev_model_ckpt for resuming from a checkpoint.
        model (nn.Module): network to be trained.
        optimizer (optim.Optimizer): optimizer for ``model``.
        train_loader (DataLoader): training data loader.
        val_loader (DataLoader): validation data loader.
        input_train_transform: callable pre-processing for training inputs.
        input_val_transform: callable pre-processing for validation inputs.
        output_transform (nn.Module): post-processing applied to model outputs.
        losses (dict): name -> loss Module mapping; wrapped in an nn.ModuleDict below.
        scheduler: optional Pytorch learning-rate scheduler; ReduceLROnPlateau is
            treated as metric-driven, any other _LRScheduler as epoch-driven.
    """
    # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method(method='spawn')

    self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

    # Checking whether inputs are correct.
    assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
    assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
    assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
        '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'
    assert callable(input_train_transform) and callable(input_val_transform), \
        'input_transforms must be callable functions.'
    # I think this would be best practice.
    assert isinstance(output_transform, nn.Module), '`output_transform` must be a Pytorch Module.'

    # 'losses' is expected to be a dictionary.
    # Even composite losses should be a single loss module with multiple outputs.
    losses = nn.ModuleDict(losses)

    if scheduler is not None:
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.metric_scheduler = True
        elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
            self.metric_scheduler = False
        else:
            raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

    # Display interval of 0 means no display of validation images on TensorBoard.
    if args.max_images <= 0:
        self.display_interval = 0
    else:
        self.display_interval = int(len(val_loader.dataset) // (args.max_images * args.batch_size))

    self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                     ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

    # loading from checkpoint if specified.
    # NOTE: must happen before the attributes below are bound so the restored
    # weights are the ones the trainer uses.
    if vars(args).get('prev_model_ckpt'):
        self.manager.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

    self.model = model
    self.optimizer = optimizer
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.input_train_transform = input_train_transform
    self.input_val_transform = input_val_transform
    self.output_transform = output_transform
    self.losses = losses  # nn.ModuleDict built above, not the raw input dict.
    self.scheduler = scheduler
    self.verbose = args.verbose
    self.num_epochs = args.num_epochs
    self.smoothing_factor = args.smoothing_factor
    self.use_slice_metrics = args.use_slice_metrics
    self.writer = SummaryWriter(str(args.log_path))
def train_img_to_rss(args):
    """Train a model that maps undersampled inputs to RSS images.

    Builds the run directories and run name, selects the device, stores run
    metadata on ``args``, constructs transforms, loaders, model, optimizer and
    scheduler, then delegates training to ModelTrainerRSS.
    """
    def ensure_dirs(root, *subdirs):
        # Create root and each nested subdirectory, tolerating pre-existing ones.
        path = Path(root)
        path.mkdir(exist_ok=True)
        for name in subdirs:
            path = path / name
            path.mkdir(exist_ok=True)
        return path

    # Checkpoint and logging directories, plus the run name.
    ckpt_path = ensure_dirs(args.ckpt_root, args.train_method)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ensure_dirs(ckpt_path, run_name)
    log_path = ensure_dirs(args.log_root, args.train_method, run_name)

    logger = get_logger(name=__name__)

    # Device selection. Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    # Fall back to the shared sampling settings when separate train/val values are absent.
    args.center_fractions_train = arguments.get('center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train', arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val', arguments.get('accelerations'))

    mask_cls = RandomMaskFunc if args.random_sampling else UniformMaskFunc
    train_mask_func = mask_cls(args.center_fractions_train, args.accelerations_train)
    val_mask_func = mask_cls(args.center_fractions_val, args.accelerations_val)

    input_train_transform = PreProcessRSS(mask_func=train_mask_func, challenge=args.challenge, device=device,
                                          augment_data=args.augment_data, use_seed=False, fat_info=args.fat_info)
    input_val_transform = PreProcessRSS(mask_func=val_mask_func, challenge=args.challenge, device=device,
                                        augment_data=False, use_seed=True, fat_info=args.fat_info)
    output_train_transform = PostProcessRSS(challenge=args.challenge, residual_rss=args.residual_rss)
    output_val_transform = PostProcessRSS(challenge=args.challenge, residual_rss=args.residual_rss)

    # DataLoaders.
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = {'rss_loss': SSIMLoss(filter_size=7).to(device=device)}

    # One extra input channel when fat information is used -- TODO confirm channel layout.
    in_chans = 16 if args.fat_info else 15
    model = UNet(in_chans=in_chans, out_chans=1, chans=args.chans, num_pool_layers=args.num_pool_layers,
                 num_res_groups=args.num_res_groups, num_res_blocks_per_group=args.num_res_blocks_per_group,
                 growth_rate=args.growth_rate, num_dense_layers=args.num_dense_layers,
                 use_dense_ca=args.use_dense_ca, num_res_layers=args.num_res_layers,
                 res_scale=args.res_scale, reduction=args.reduction, thick_base=args.thick_base).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=args.milestones,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerRSS(args, model, optimizer, train_loader, val_loader, input_train_transform,
                              input_val_transform, output_train_transform, output_val_transform, losses, scheduler)
    try:
        trainer.train_model_concat()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
def train_cmg_to_img_direct(args):
    """Train a model mapping complex-image inputs directly to multicoil images.

    Prepares run directories and metadata, builds CMG->IMG transforms, data
    loaders, the UNet model and its optimization objects, then runs training
    through ModelTrainerI2I.
    """
    def ensure_dirs(root, *subdirs):
        # Create root and each nested subdirectory, tolerating pre-existing ones.
        path = Path(root)
        path.mkdir(exist_ok=True)
        for name in subdirs:
            path = path / name
            path.mkdir(exist_ok=True)
        return path

    # Checkpoint and logging directories, plus the run name.
    ckpt_path = ensure_dirs(args.ckpt_root, args.train_method)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ensure_dirs(ckpt_path, run_name)
    log_path = ensure_dirs(args.log_root, args.train_method, run_name)

    logger = get_logger(name=__name__)

    # Device selection. Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    # Fall back to the shared sampling settings when separate train/val values are absent.
    args.center_fractions_train = arguments.get('center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train', arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val', arguments.get('accelerations'))

    mask_cls = RandomMaskFunc if args.random_sampling else UniformMaskFunc
    train_mask_func = mask_cls(args.center_fractions_train, args.accelerations_train)
    val_mask_func = mask_cls(args.center_fractions_val, args.accelerations_val)

    input_train_transform = PreProcessCMGIMG(mask_func=train_mask_func, challenge=args.challenge, device=device,
                                             augment_data=args.augment_data, use_seed=False,
                                             crop_center=args.crop_center)
    input_val_transform = PreProcessCMGIMG(mask_func=val_mask_func, challenge=args.challenge, device=device,
                                           augment_data=False, use_seed=True, crop_center=args.crop_center)
    output_train_transform = PostProcessCMGIMG(challenge=args.challenge, output_mode='img')
    output_val_transform = PostProcessCMGIMG(challenge=args.challenge, output_mode='img')

    # DataLoaders.
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = {'img_loss': nn.L1Loss()}

    # 45 input channels vs 15 output channels -- presumably 3 planes per coil; TODO confirm layout.
    model = UNet(in_chans=45, out_chans=15, chans=args.chans, num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks, res_scale=args.res_scale, use_residual=False,
                 use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap,
                 use_gmp=args.use_gmp).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerI2I(args, model, optimizer, train_loader, val_loader, input_train_transform,
                              input_val_transform, output_train_transform, output_val_transform, losses, scheduler)
    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
def train_complex(args):
    """Train a complex-domain model on weighted k-space or semi-k-space data.

    Depending on ``args.train_method`` ('WS2C' semi-k-space, 'WK2C' k-space),
    builds the matching weighting function, pre/post-processing transforms and
    a UNetModelKSSE, then runs training through ModelTrainerCOMPLEX.

    Raises:
        NotImplementedError: if ``args.train_method`` is neither 'WS2C' nor 'WK2C'.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # UNET architecture requires that all inputs be dividable by some power of 2.
    divisor = 2 ** args.num_pool_layers

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func, weight_func, args.challenge, device,
                                              use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func, weight_func, args.challenge, device,
                                            use_seed=True, divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(weighted=True, replace=args.replace)
    elif args.train_method == 'WK2C':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type, y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func, weight_func, args.challenge, device,
                                             use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWK(mask_func, weight_func, args.challenge, device,
                                           use_seed=True, divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True, replace=args.replace)
    else:
        raise NotImplementedError('Invalid train method!')

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNetModelKSSE(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                          num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
                          use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
                          min_ext_size=args.min_ext_size, max_ext_size=args.max_ext_size, ext_mode=args.ext_mode,
                          use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
                          use_sa=args.use_sa, sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation,
                          use_cap=args.use_cap, use_cmp=args.use_cmp).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(args, model, optimizer, train_loader, val_loader, input_train_transform,
                                  input_val_transform, output_transform, losses, scheduler)
    # Fix: guard training so the TensorBoard SummaryWriter is closed on manual
    # interruption, matching every other training entry point in this file.
    # Previously the writer was never closed when training was interrupted.
    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
def train_cmg_to_img(args):
    """Train a model mapping complex-image (CMG) inputs to image outputs.

    Sets up run directories and metadata, CMG pre/post-processing transforms,
    data loaders, a 30-channel-in/30-channel-out UNet, optimizer and scheduler,
    then trains via ModelTrainerIMG. The SummaryWriter is closed on
    KeyboardInterrupt.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessCMG(mask_func, args.challenge, device, augment_data=args.augment_data,
                                          use_seed=False, crop_center=args.crop_center)
    input_val_transform = PreProcessCMG(mask_func, args.challenge, device, augment_data=False,
                                        use_seed=True, crop_center=args.crop_center)
    output_train_transform = PostProcessCMG(challenge=args.challenge, residual_acs=args.residual_acs)
    output_val_transform = PostProcessCMG(challenge=args.challenge, residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    # Alternative loss choices kept for experimentation.
    losses = dict(
        img_loss=LogSSIMLoss(filter_size=7).to(device)
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=nn.L1Loss()
    )

    # model = UNet(
    #     in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers,
    #     num_groups=args.num_groups, negative_slope=args.negative_slope, use_residual=args.use_residual,
    #     interp_mode=args.interp_mode, use_ca=args.use_ca, reduction=args.reduction,
    #     use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)
    model = UNet(in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks, use_residual=args.use_residual, use_ca=args.use_ca,
                 reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerIMG(args, model, optimizer, train_loader, val_loader, input_train_transform,
                              input_val_transform, output_train_transform, output_val_transform, losses, scheduler)
    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
def main(argv):  # argv are non-flag parameters
    """Entry point for Keras UNet training on HDF5 MRI data, driven by absl FLAGS.

    Builds train/val HDF5Sequence generators, compiles a UNet with MAE loss and
    SSIM/PSNR/MS-SSIM/NMSE metrics, trains with checkpoint and TensorBoard
    callbacks, and logs the elapsed time. Returns 0 on completion.
    """
    del argv
    start = time()
    tf.print('Tensorflow Engaged')

    run_number, run_name = initialize(FLAGS.ckpt_dir)

    log_path = Path(FLAGS.log_dir)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    # exist_ok=False: a pre-existing log directory for this run name is an error.
    log_path.mkdir(exist_ok=False)

    logger = get_logger(__name__)

    data_path = Path(FLAGS.data_dir)
    train_path = data_path / f'{FLAGS.challenge}_train'
    val_path = data_path / f'{FLAGS.challenge}_val'
    train_dataset = HDF5Sequence(data_dir=train_path, batch_size=FLAGS.batch_size, training=True, as_tensors=False)
    val_dataset = HDF5Sequence(data_dir=val_path, batch_size=FLAGS.batch_size, training=False, as_tensors=False)

    model = make_unet_model(scope='UNET', input_shape=(320, 320, 1))
    # multi_model = tf.keras.utils.multi_gpu_model(model, gpus=2)
    multi_model = model  # Single-GPU path; multi-GPU wrapper above is disabled.
    tf.keras.utils.plot_model(model, to_file=log_path / f'model_{run_number:02d}.png', show_shapes=True)
    model.summary()

    ckpt_path = Path(FLAGS.ckpt_dir) / run_name
    ckpt_path.mkdir(exist_ok=True)
    # Keras fills in the epoch number from this filename template.
    ckpt_path = ckpt_path / '{epoch:04d}.ckpt'
    ckpt_path = str(ckpt_path)

    optimizer = tf.keras.optimizers.Adam(lr=FLAGS.lr)
    checkpoint = ModelCheckpoint(filepath=ckpt_path, verbose=1, save_best_only=True)
    visualizer = TensorBoard(log_dir=log_path, batch_size=FLAGS.batch_size)

    model_config = json.loads(model.to_json())
    # Save FLAGS to a json file next to the tensorboard data
    save_dict_as_json(dict_data=FLAGS.flag_values_dict(), log_dir=str(log_path), save_name=f'{run_name}_FLAGS')
    save_dict_as_json(dict_data=model_config, log_dir=str(log_path), save_name=f'{run_name}_config')

    multi_model.compile(optimizer=optimizer, loss='mae',
                        metrics=[batch_ssim, batch_psnr, batch_msssim, batch_nmse])
    multi_model.fit_generator(train_dataset, epochs=FLAGS.num_epochs, verbose=1,
                              callbacks=[checkpoint, visualizer], validation_data=val_dataset,
                              workers=4, use_multiprocessing=True)

    finish = int(time() - start)
    logger.info(f'Finished Training model. Time: {finish // 3600:02d}hrs {(finish // 60) % 60}min {finish % 60}s')
    return 0
def main():
    """Train ResNet50 on CIFAR100, tracking loss and top-1/top-5 accuracy.

    Saves a checkpoint whenever top-5 validation accuracy improves (and every
    epoch when ``save_best_only`` is False).

    Fixes over the previous version:
    - Epoch losses are now exact: each step loss is weighted by its true batch
      size before averaging, instead of the approximation
      ``loss_sum * batch_size / dataset_len`` which was inaccurate for the
      smaller final batch (acknowledged in the old comments).
    - Validation runs inside a ``torch.no_grad()`` context instead of a global
      ``set_grad_enabled`` toggle.
    - Metric accumulators are plain Python numbers reset per epoch, removing
      the awkward ``torch.as_tensor(0.)`` reset dance.
    """
    batch_size = 8
    num_workers = 4
    init_lr = 2E-4
    gpu = 1  # Set to None for CPU mode.
    num_epochs = 500
    verbose = False
    save_best_only = True
    data_root = '/home/veritas/PycharmProjects/PA1/data'
    log_root = '/home/veritas/PycharmProjects/PA1/logs'
    ckpt_root = '/home/veritas/PycharmProjects/PA1/checkpoints'

    ckpt_path = Path(ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')

    # Do more fancy transforms later.
    transform = torchvision.transforms.ToTensor()
    train_dataset = torchvision.datasets.CIFAR100(data_root, train=True, transform=transform, download=True)
    val_dataset = torchvision.datasets.CIFAR100(data_root, train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # Define model, optimizer, etc.
    model = torchvision.models.resnet50(pretrained=False, num_classes=100).to(device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = CrossEntropyLoss().to(device, non_blocking=True)

    previous_best = 0.  # Best top-5 validation accuracy so far; larger is better.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, num_epochs + 1):
        # --- Training phase ---
        tic = time()
        model.train()
        train_loss_sum = 0.
        train_top1_correct = 0
        train_top5_correct = 0
        for idx, (images, labels) in enumerate(train_loader, start=1):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            preds = model(images)
            # Pytorch uses (input, target) ordering. Their shapes are also different.
            step_loss = loss_func(preds, labels)
            step_loss.backward()
            optimizer.step()
            with torch.no_grad():  # Metric calculation must not accumulate gradients.
                # Weight by the actual batch size so the smaller final batch
                # does not skew the epoch average.
                train_loss_sum += step_loss.item() * images.size(0)
                top1 = torch.argmax(preds, dim=1)
                _, top5 = torch.topk(preds, k=5)  # Get only the indices.
                train_top1_correct += torch.sum(torch.eq(top1, labels)).item()
                # Broadcasting labels against the top-5 indices counts a hit if
                # the true label appears anywhere in the top 5.
                train_top5_correct += torch.sum(torch.eq(top5, labels.unsqueeze(-1))).item()
            if verbose:
                print(f'Training loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')
        toc = int(time() - tic)
        num_train = len(train_loader.dataset)
        epoch_loss = train_loss_sum / num_train
        epoch_top1_acc = train_top1_correct / num_train * 100
        epoch_top5_acc = train_top5_correct / num_train * 100
        msg = f'loss: {epoch_loss:.4f}, top1 accuracy: {epoch_top1_acc:.2f}%, top5 accuracy: {epoch_top5_acc:.2f}%'
        logger.info(f'Epoch {epoch:03d} Training. {msg} Time: {toc}s')

        # --- Validation phase ---
        tic = time()
        model.eval()
        val_loss_sum = 0.
        val_top1_correct = 0
        val_top5_correct = 0
        with torch.no_grad():
            for idx, (images, labels) in enumerate(val_loader, start=1):
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                preds = model(images)
                top1 = torch.argmax(preds, dim=1)
                _, top5 = torch.topk(preds, k=5)
                val_top1_correct += torch.sum(torch.eq(top1, labels)).item()
                val_top5_correct += torch.sum(torch.eq(top5, labels.unsqueeze(-1))).item()
                step_loss = loss_func(preds, labels)
                val_loss_sum += step_loss.item() * images.size(0)
                if verbose:
                    print(f'Validation loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')
        toc = int(time() - tic)
        num_val = len(val_loader.dataset)
        epoch_loss = val_loss_sum / num_val
        epoch_top1_acc = val_top1_correct / num_val * 100
        epoch_top5_acc = val_top5_correct / num_val * 100
        msg = f'loss: {epoch_loss:.4f}, top1 accuracy: {epoch_top1_acc:.2f}%, top5 accuracy: {epoch_top5_acc:.2f}%'
        logger.info(f'Epoch {epoch:03d} Validation. {msg} Time: {toc}s')

        # Checkpoint generation. Only implemented for single GPU models, not multi-gpu models.
        # All comparisons are done with python numbers, not tensors.
        if epoch_top5_acc > previous_best:  # Assumes larger metric is better.
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has improved from '
                        f'{previous_best:.2f}% to {epoch_top5_acc:.2f}%')
            previous_best = epoch_top5_acc
            save_dict = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}
            torch.save(save_dict, ckpt_path / f'epoch_{epoch:03d}.tar')
        else:
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has not improved from the previous best epoch')
            if not save_best_only:
                save_dict = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}
                torch.save(save_dict, ckpt_path / f'epoch_{epoch:03d}.tar')
def __init__(self, args, model, optimizer, train_loader, val_loader, input_train_transform,
             input_val_transform, output_train_transform, output_val_transform, losses, scheduler=None):
    """Set up an image-to-image model trainer with separate train/val output transforms.

    Args:
        args: run-settings namespace; reads log_path, run_name, max_images,
            batch_size, save_best_only, ckpt_path, max_to_keep, verbose,
            num_epochs, smoothing_factor, shrink_scale, use_slice_metrics,
            device, and optionally prev_model_ckpt for resuming.
        model (nn.Module): network to be trained.
        optimizer (optim.Optimizer): optimizer for ``model``.
        train_loader (DataLoader): training data loader.
        val_loader (DataLoader): validation data loader.
        input_train_transform: callable pre-processing for training inputs.
        input_val_transform: callable pre-processing for validation inputs.
        output_train_transform (nn.Module): post-processing for training outputs.
        output_val_transform (nn.Module): post-processing for validation outputs.
        losses (dict): must contain an 'img_loss' entry (logged below);
            wrapped in an nn.ModuleDict.
        scheduler: optional Pytorch learning-rate scheduler; ReduceLROnPlateau
            is treated as metric-driven, any other _LRScheduler as epoch-driven.
    """
    # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method(method='spawn')

    self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

    # Checking whether inputs are correct.
    assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
    assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
    assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
        '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'
    assert callable(input_train_transform) and callable(input_val_transform), \
        'input_transforms must be callable functions.'
    # I think this would be best practice.
    assert isinstance(output_train_transform, nn.Module) and isinstance(output_val_transform, nn.Module), \
        '`output_train_transform` and `output_val_transform` must be Pytorch Modules.'

    # 'losses' is expected to be a dictionary.
    # Even composite losses should be a single loss module with a tuple as its output.
    losses = nn.ModuleDict(losses)

    if scheduler is not None:
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.metric_scheduler = True
        elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
            self.metric_scheduler = False
        else:
            raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

    # Display interval of 0 means no display of validation images on TensorBoard.
    if args.max_images <= 0:
        self.display_interval = 0
    else:
        self.display_interval = int(len(val_loader.dataset) // (args.max_images * args.batch_size))

    self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                     ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

    # loading from checkpoint if specified.
    if vars(args).get('prev_model_ckpt'):
        self.manager.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

    self.model = model
    self.optimizer = optimizer
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.input_train_transform = input_train_transform
    self.input_val_transform = input_val_transform
    self.output_train_transform = output_train_transform
    self.output_val_transform = output_val_transform
    self.losses = losses  # nn.ModuleDict built above, not the raw input dict.
    self.scheduler = scheduler
    self.writer = SummaryWriter(str(args.log_path))

    self.verbose = args.verbose
    self.num_epochs = args.num_epochs
    self.smoothing_factor = args.smoothing_factor
    self.shrink_scale = args.shrink_scale
    self.use_slice_metrics = args.use_slice_metrics

    # This part should get SSIM, not 1 - SSIM.
    self.ssim = SSIM(filter_size=7).to(device=args.device)  # Needed to cache the kernel.

    # Logging all components of the Model Trainer.
    # Train and Val input and output transforms are assumed to use the same input transform class.
    self.logger.info(f'''
Summary of Model Trainer Components:
Model: {get_class_name(model)}.
Optimizer: {get_class_name(optimizer)}.
Input Transforms: {get_class_name(input_val_transform)}.
Output Transform: {get_class_name(output_val_transform)}.
Image Domain Loss: {get_class_name(losses['img_loss'])}.
Learning-Rate Scheduler: {get_class_name(scheduler)}.
''')  # This part has parts different for IMG and CMG losses!!
def train_cmg_and_img(args):
    """Train with a combined complex-image (CMG) and image-domain objective.

    Uses unweighted semi-k-space pre/post-processing (weight_func=no_weight,
    weighted=False), an MSE loss on the complex images plus an L1 loss on the
    images, and trains via ModelTrainerCI. The SummaryWriter is closed on
    KeyboardInterrupt.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:  # Same as in the challenge
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    input_train_transform = PreProcessWSemiKCC(mask_func=mask_func, weight_func=no_weight,
                                               challenge=args.challenge, device=device, use_seed=False)
    input_val_transform = PreProcessWSemiKCC(mask_func=mask_func, weight_func=no_weight,
                                             challenge=args.challenge, device=device, use_seed=True)
    output_train_transform = PostProcessWSemiKCC(args.challenge, weighted=False, residual_acs=args.residual_acs)
    output_val_transform = PostProcessWSemiKCC(args.challenge, weighted=False, residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    # Alternative image losses kept for experimentation.
    losses = dict(
        cmg_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss())

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag
    model = UNet(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                 num_pool_layers=args.num_pool_layers, num_depth_blocks=args.num_depth_blocks,
                 num_groups=args.num_groups, negative_slope=args.negative_slope,
                 use_residual=args.use_residual, interp_mode=args.interp_mode, use_ca=args.use_ca,
                 reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerCI(args, model, optimizer, train_loader, val_loader, input_train_transform,
                             input_val_transform, output_train_transform, output_val_transform, losses, scheduler)
    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        logger.warning(f'Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.')
def __init__(self, args, model, optimizer, train_loader, val_loader, post_processing, loss_func, metrics=None, scheduler=None): self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name) # Checking whether inputs are correct. assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.' assert isinstance( optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.' assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \ '`train_loader` and `val_loader` must be Pytorch DataLoader objects.' # I think this would be best practice. assert isinstance( post_processing, nn.Module), '`post_processing_func` must be a Pytorch Module.' # This is not a mistake. Pytorch implements loss functions as modules. assert isinstance( loss_func, nn.Module), '`loss_func` must be a callable Pytorch Module.' if metrics is not None: assert isinstance( metrics, (list, tuple)), '`metrics` must be a list or tuple.' for metric in metrics: assert callable( metric), 'All metrics must be callable functions.' if scheduler is not None: if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): self.metric_scheduler = True elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): self.metric_scheduler = False else: raise TypeError( '`scheduler` must be a Pytorch Learning Rate Scheduler.') self.model = model self.optimizer = optimizer self.train_loader = train_loader self.val_loader = val_loader self.post_processing_func = post_processing self.loss_func = loss_func # I don't think it is necessary to send loss_func or metrics to device. self.metrics = metrics self.scheduler = scheduler self.verbose = args.verbose self.num_epochs = args.num_epochs self.writer = SummaryWriter(logdir=str(args.log_path)) # Display interval of 0 means no display of validation images on TensorBoard. 
self.display_interval = int( len(self.val_loader.dataset) // args.max_images) if (args.max_images > 0) else 0 # # Writing model graph to TensorBoard. Results might not be very good. if args.add_graph: num_chans = 30 if args.challenge == 'multicoil' else 2 example_inputs = torch.ones(size=(1, num_chans, 640, 328), device=args.device) self.writer.add_graph(model=model, input_to_model=example_inputs) del example_inputs # Remove unnecessary tensor taking up memory. self.checkpointer = CheckpointManager( model=self.model, optimizer=self.optimizer, mode='min', save_best_only=args.save_best_only, ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep) # loading from checkpoint if specified. if vars(args).get('prev_model_ckpt'): self.checkpointer.load(load_dir=args.prev_model_ckpt, load_optimizer=False)
def train_model(model, args):
    """Train `model` on CIFAR100, tracking loss and top-1/top-5 accuracy.

    Creates run directories, CIFAR100 data loaders, Adam optimizer, StepLR
    scheduler, and a checkpoint manager, then runs the epoch loop, logging
    metrics to both the logger and TensorBoard. Checkpoint selection is
    driven by top-5 validation accuracy (mode='max').
    """
    assert isinstance(model, nn.Module)

    # Beginning session.
    run_number, run_name = initialize(args.ckpt_root)
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)
    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    # Saving args for later use.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    dataset_kwargs = dict(root=args.data_root, download=True)
    train_dataset = torchvision.datasets.CIFAR100(train=True, transform=train_transform(), **dataset_kwargs)
    val_dataset = torchvision.datasets.CIFAR100(train=False, transform=val_transform(), **dataset_kwargs)
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Define model, optimizer, etc.
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss(reduction='mean').to(device)
    # LR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    # Create checkpoint manager
    checkpointer = CheckpointManager(model, optimizer, mode='max', save_best_only=args.save_best_only,
                                     ckpt_dir=ckpt_path, max_to_keep=args.max_to_keep)
    # Tensorboard Writer
    writer = SummaryWriter(log_dir=str(log_path))

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, args.num_epochs + 1):
        # Start of training
        tic = time()
        train_loss_sum, train_top1_correct, train_top5_correct = \
            train_epoch(model, optimizer, loss_func, train_loader, device, epoch, args.verbose)
        toc = int(time() - tic)
        # Last step with small batch causes some inaccuracy but that is tolerable.
        train_epoch_loss = train_loss_sum.item() * args.batch_size / len(train_loader.dataset)
        train_epoch_top1_acc = train_top1_correct.item() / len(train_loader.dataset) * 100
        train_epoch_top5_acc = train_top5_correct.item() / len(train_loader.dataset) * 100
        logger.info(
            f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4e}, top1 accuracy: '
            f'{train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s')
        # Writing to Tensorboard
        writer.add_scalar('train_epoch_loss', train_epoch_loss, epoch)
        writer.add_scalar('train_epoch_top1_acc', train_epoch_top1_acc, epoch)
        writer.add_scalar('train_epoch_top5_acc', train_epoch_top5_acc, epoch)

        # Start of evaluation
        tic = time()
        val_loss_sum, val_top1_correct, val_top5_correct = \
            eval_epoch(model, loss_func, val_loader, device, epoch, args.verbose)
        toc = int(time() - tic)
        val_epoch_loss = val_loss_sum.item() * args.batch_size / len(val_loader.dataset)
        val_epoch_top1_acc = val_top1_correct.item() / len(val_loader.dataset) * 100
        val_epoch_top5_acc = val_top5_correct.item() / len(val_loader.dataset) * 100
        logger.info(
            f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4e}, top1 accuracy: '
            f'{val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}% Time: {toc}s')
        # Writing to Tensorboard
        writer.add_scalar('val_epoch_loss', val_epoch_loss, epoch)
        writer.add_scalar('val_epoch_top1_acc', val_epoch_top1_acc, epoch)
        writer.add_scalar('val_epoch_top5_acc', val_epoch_top5_acc, epoch)
        for idx, group in enumerate(optimizer.param_groups, start=1):
            writer.add_scalar(f'learning_rate_{idx}', group['lr'], epoch)

        # Things to do after each epoch.
        scheduler.step()  # Reduces LR at the designated times. Probably does not use 1 indexing like me.
        checkpointer.save(metric=val_epoch_top5_acc)
def main():
    """Train SE-ResNet50 on CIFAR100 with hard-coded hyperparameters.

    Earlier script variant: configuration lives in local variables rather than
    argparse, and best-checkpoint selection is tracked manually via top-5
    validation accuracy (larger is better).
    """
    # Put these in args later.
    batch_size = 12
    num_workers = 8
    init_lr = 2E-4
    gpu = 0  # Set to None for CPU mode.
    num_epochs = 500
    verbose = False
    save_best_only = True
    max_to_keep = 100
    data_root = '/home/veritas/PycharmProjects/PA1/data'
    ckpt_root = '/home/veritas/PycharmProjects/PA1/checkpoints'
    log_root = '/home/veritas/PycharmProjects/PA1/logs'

    # Beginning session.
    run_number, run_name = initialize(ckpt_root)
    ckpt_path = Path(ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)
    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')

    # Do more fancy transforms later.
    transform = torchvision.transforms.ToTensor()
    train_dataset = torchvision.datasets.CIFAR100(data_root, train=True, transform=transform, download=True)
    val_dataset = torchvision.datasets.CIFAR100(data_root, train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # Define model, optimizer, etc.
    model = se_resnet50_cifar100().to(device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss().to(device, non_blocking=True)
    # Create checkpoint manager
    checkpointer = CheckpointManager(model, optimizer, ckpt_path, save_best_only, max_to_keep)

    # For recording data.
    previous_best = 0.  # Accuracy should improve.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, num_epochs + 1):
        # Start of training
        tic = time()
        train_loss_sum, train_top1_correct, train_top5_correct = \
            train_epoch(model, optimizer, loss_func, train_loader, device, epoch, verbose)
        toc = int(time() - tic)
        # Last step with small batch causes some inaccuracy but that is tolerable.
        train_epoch_loss = train_loss_sum.item() * batch_size / len(train_loader.dataset)
        train_epoch_top1_acc = train_top1_correct.item() / len(train_loader.dataset) * 100
        train_epoch_top5_acc = train_top5_correct.item() / len(train_loader.dataset) * 100
        msg = f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4f}, ' \
            f'top1 accuracy: {train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s'
        logger.info(msg)

        # Start of evaluation
        tic = time()
        val_loss_sum, val_top1_correct, val_top5_correct = \
            eval_epoch(model, loss_func, val_loader, device, epoch, verbose)
        toc = int(time() - tic)
        val_epoch_loss = val_loss_sum.item() * batch_size / len(val_loader.dataset)
        val_epoch_top1_acc = val_top1_correct.item() / len(val_loader.dataset) * 100
        val_epoch_top5_acc = val_top5_correct.item() / len(val_loader.dataset) * 100
        msg = f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4f}, ' \
            f'top1 accuracy: {val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}% Time: {toc}s'
        logger.info(msg)

        if val_epoch_top5_acc > previous_best:  # Assumes larger metric is better.
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has improved from '
                        f'{previous_best:.2f}% to {val_epoch_top5_acc:.2f}%')
            previous_best = val_epoch_top5_acc
            checkpointer.save(is_best=True)
        else:
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has not improved from the previous best epoch')
            checkpointer.save(is_best=False)
def main(argv):  # argv are non-flag parameters
    """TF 1.x eager-mode training entry point for the U-Net reconstructor.

    Builds datasets, model, optimizer, and checkpointing from absl FLAGS,
    optionally restores a previous checkpoint, then runs train_and_eval
    inside a contrib summary-writer context.
    """
    del argv
    start = time()
    tf.print('Tensorflow Engaged')

    run_number, run_name = initialize(FLAGS.ckpt_dir)
    log_path = Path(FLAGS.log_dir)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=False)  # A fresh run must not reuse an existing log directory.
    logger = get_logger(__name__)

    data_path = Path(FLAGS.data_dir)
    train_path = data_path / f'{FLAGS.challenge}_train'
    val_path = data_path / f'{FLAGS.challenge}_val'
    normalize = False

    def loss_func(labels, predictions):
        # NOTE(review): max_val=15.0 is described below as a heuristic for this
        # data's dynamic range — confirm against the dataset's intensity scale.
        return 1 - tf.image.ssim(labels, predictions, max_val=15.0)  # Sort of heuristic method...

    train_dataset = HDF5Sequence(data_dir=train_path, batch_size=FLAGS.batch_size,
                                 training=True, normalize=normalize)
    val_dataset = HDF5Sequence(data_dir=val_path, batch_size=FLAGS.batch_size,
                               training=False, normalize=normalize)

    model = make_unet_model(scope='UNET', input_shape=(320, 320, 1))
    tf.keras.utils.plot_model(model, to_file=log_path / f'model_{run_number:02d}.png', show_shapes=True)
    model.summary()

    ckpt_path = Path(FLAGS.ckpt_dir) / run_name
    ckpt_path.mkdir(exist_ok=True)

    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr)  # Can't change lr easily until TF2.0 update...
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

    if FLAGS.restore_dir:
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.restore_dir)
        checkpoint.restore(latest_checkpoint).assert_nontrivial_match().assert_existing_objects_matched()
        print(f'Restored model from {latest_checkpoint}')

    manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=str(ckpt_path),
                                         max_to_keep=FLAGS.max_to_keep)
    writer = tf.contrib.summary.create_file_writer(logdir=str(log_path))  # Graph display not possible in eager...

    model_config = json.loads(model.to_json())
    # Save FLAGS to a json file next to the tensorboard data
    save_dict_as_json(dict_data=FLAGS.flag_values_dict(), log_dir=str(log_path), save_name=f'{run_name}_FLAGS')
    save_dict_as_json(dict_data=model_config, log_dir=str(log_path), save_name=f'{run_name}_config')

    with writer.as_default(), tf.contrib.summary.always_record_summaries():
        train_and_eval(model=model, optimizer=optimizer, manager=manager,
                       train_dataset=train_dataset, val_dataset=val_dataset,
                       num_epochs=FLAGS.num_epochs, loss_func=loss_func,
                       save_best_only=FLAGS.save_best_only, use_train_metrics=True,
                       use_val_metrics=True, verbose=FLAGS.verbose, max_images=FLAGS.max_images)

    finish = int(time() - start)
    logger.info(f'Finished Training model. Time: {finish // 3600:02d}hrs {(finish // 60) % 60}min {finish % 60}s')
    return 0
def __init__(self, args, model, optimizer, train_loader, val_loader,
             post_processing, c_loss, metrics=None, scheduler=None):
    """Trainer initializer.

    Validates all inputs, configures file logging, the TensorBoard writer,
    and checkpointing. `c_loss` is the training loss module; `metrics` is an
    optional iterable of callables evaluated alongside the loss.
    """
    # BUG FIX: set_start_method() raises RuntimeError when the start method has
    # already been set (e.g. a second trainer constructed in the same process).
    # Guard it the same way the sibling trainer classes in this project do.
    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method(method='spawn')

    self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

    # Checking whether inputs are correct.
    assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
    assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
    assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
        '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'
    # I think this would be best practice.
    assert isinstance(post_processing, nn.Module), '`post_processing` must be a Pytorch Module.'
    # This is not a mistake. Pytorch implements loss functions as modules.
    assert isinstance(c_loss, nn.Module), '`c_loss` must be a callable Pytorch Module.'

    if metrics is not None:
        assert isinstance(metrics, Iterable), '`metrics` must be an iterable, preferably a list or tuple.'
        for metric in metrics:
            assert callable(metric), 'All metrics must be callable functions.'

    if scheduler is not None:
        # ReduceLROnPlateau requires a metric passed to step(); other schedulers
        # do not, so record which calling convention the training loop must use.
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.metric_scheduler = True
        elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
            self.metric_scheduler = False
        else:
            raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

    self.model = model
    self.optimizer = optimizer
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.post_processing_func = post_processing
    self.c_loss_func = c_loss
    self.metrics = metrics
    self.scheduler = scheduler
    self.verbose = args.verbose
    self.num_epochs = args.num_epochs
    self.writer = SummaryWriter(logdir=str(args.log_path))

    # Display interval of 0 means no display of validation images on TensorBoard.
    if args.max_images <= 0:
        self.display_interval = 0
    else:
        self.display_interval = int(len(self.val_loader.dataset) // (args.max_images * args.batch_size))

    self.checkpointer = CheckpointManager(model=self.model, optimizer=self.optimizer, mode='min',
                                          save_best_only=args.save_best_only, ckpt_dir=args.ckpt_path,
                                          max_to_keep=args.max_to_keep)

    # loading from checkpoint if specified.
    if vars(args).get('prev_model_ckpt'):
        self.checkpointer.load(load_dir=args.prev_model_ckpt, load_optimizer=False)
def main(args):
    """Train a residual-learning U-Net for image reconstruction.

    Builds run directories, datasets, loaders, model, optimizer, and a
    ReduceLROnPlateau scheduler, then runs the train/validation epoch loop
    with manual best-loss checkpoint tracking (smaller loss is better).
    """
    ckpt_path = Path('checkpoints')
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    args.ckpt_path = ckpt_path

    log_path = Path('logs')
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)
    args.log_path = log_path

    # BUG FIX: the original called save_dict_as_json(vars(save_dict_as_json), ...),
    # serializing the function's own attribute dict instead of the run
    # configuration. Save vars(args) as every other launcher in this file does.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        args.device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        args.device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Create Datasets
    train_dataset = TrainingDataset(args.train_root, transform=slice_normalize_and_clip, single_coil=False)
    val_dataset = TrainingDataset(  # Shrink val time by half. Consistent comparison remains.
        args.val_root, transform=slice_normalize_and_clip, single_coil=False, use_double=False, seed=9872)

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True)

    # Define model, optimizer, and loss function.
    model = UnetModel(in_chans=1, out_chans=1, chans=32, num_pool_layers=4).to(args.device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    loss_func = nn.L1Loss(reduction='mean').to(args.device, non_blocking=True)

    checkpointer = CheckpointManager(model, optimizer, args.save_best_only, ckpt_path, args.max_to_keep)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
                                                     verbose=True, cooldown=0, min_lr=1E-7)

    writer = SummaryWriter(log_dir=str(log_path))
    example_inputs = torch.ones(size=(args.batch_size, 1, 320, 320)).to(args.device)
    writer.add_graph(model=model, input_to_model=example_inputs, verbose=False)

    previous_best = np.inf  # Loss should decrease.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, args.num_epochs + 1):
        # --- Training phase ---
        tic = time()
        model.train()
        torch.autograd.set_grad_enabled(True)
        logger.info(f'Beginning training for Epoch {epoch:03d}')
        # Reset to 0. There must be a better way though...
        train_loss_sum = torch.as_tensor(0.)
        tot_len = len(train_loader.dataset) // args.batch_size
        for idx, (ds_imgs, labels) in tqdm(enumerate(train_loader, start=1), total=tot_len):
            ds_imgs = ds_imgs.to(args.device).unsqueeze(dim=1)
            labels = labels.to(args.device).unsqueeze(dim=1)

            optimizer.zero_grad()
            pred_residuals = model(ds_imgs)  # Predicted residuals are outputs.
            recons = pred_residuals + ds_imgs
            step_loss = loss_func(recons, labels)
            step_loss.backward()
            optimizer.step()

            with torch.no_grad():  # Necessary for metric calculation without errors due to gradient accumulation.
                train_loss_sum += step_loss

            if args.verbose:
                print(f'Training loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')
        else:
            toc = int(time() - tic)
            # Last step with small batch causes some inaccuracy but that is tolerable.
            epoch_loss = train_loss_sum.item() * args.batch_size / len(train_loader.dataset)
            logger.info(f'Epoch {epoch:03d} Training. loss: {float(epoch_loss):.4f}, '
                        f'Time: {toc // 60}min {toc % 60}s')
            writer.add_scalar('train_loss', epoch_loss, global_step=epoch)

        # --- Validation phase ---
        tic = time()
        model.eval()
        torch.autograd.set_grad_enabled(False)
        logger.info(f'Beginning validation for Epoch {epoch:03d}')
        # Reset to 0. There must be a better way though...
        val_loss_sum = torch.as_tensor(0.)
        # Need to return residual and labels for residual learning.
        tot_len = len(val_loader.dataset) // args.batch_size
        for idx, (ds_imgs, labels) in tqdm(enumerate(val_loader, start=1), total=tot_len):
            ds_imgs = ds_imgs.to(args.device).unsqueeze(dim=1)
            labels = labels.to(args.device).unsqueeze(dim=1)
            pred_residuals = model(ds_imgs)
            recons = pred_residuals + ds_imgs
            step_loss = loss_func(recons, labels)
            val_loss_sum += step_loss

            if args.verbose:
                print(f'Validation loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')

            if args.max_imgs:
                pass  # TODO: Implement saving images to TensorBoard.
        else:
            toc = int(time() - tic)
            epoch_loss = val_loss_sum.item() * args.batch_size / len(val_loader.dataset)
            logger.info(f'Epoch {epoch:03d} Validation. loss: {float(epoch_loss):.4f} '
                        f'Time: {toc // 60}min {toc % 60}s')
            writer.add_scalar('val_loss', epoch_loss, global_step=epoch)

        scheduler.step(metrics=epoch_loss, epoch=epoch)  # Changes optimizer lr.

        # Checkpoint generation. Only implemented for single GPU models, not multi-gpu models.
        # All comparisons are done with python numbers, not tensors.
        if epoch_loss < previous_best:  # Assumes smaller metric is better.
            logger.info(f'Loss in Epoch {epoch} has improved from '
                        f'{float(previous_best):.4f} to {float(epoch_loss):.4f}')
            previous_best = epoch_loss
            checkpointer.save(is_best=True)
        else:
            logger.info(f'Loss in Epoch {epoch} has not improved from the previous best epoch')
            checkpointer.save(is_best=False)  # No save if save_best_only is True.
def train_and_eval(model, optimizer, manager, train_dataset, val_dataset, num_epochs, loss_func,
                   save_best_only=True, use_train_metrics=False, use_val_metrics=True, verbose=True,
                   max_images=36):
    """Epoch loop for TF eager training: train, validate, log, checkpoint.

    Checkpoints via `manager` whenever validation loss improves; when
    `save_best_only` is False, every epoch is checkpointed regardless.
    Metrics are logged per-epoch through tf.contrib.summary scalars, which
    assumes a summary writer context is active in the caller.
    """
    logger = get_logger(__name__)
    prev_loss = 2**30  # Sentinel "infinity" so the first epoch always checkpoints.
    for epoch in range(1, num_epochs + 1):
        # Training
        tic = time()
        logger.info(f'\nStarting Epoch {epoch:03d} Training')
        train_epoch_loss, train_epoch_metrics = \
            train_epoch(model=model, optimizer=optimizer, dataset=train_dataset,
                        loss_func=loss_func, epoch=epoch, use_metrics=use_train_metrics, verbose=verbose)
        # After Epoch training is over.
        toc = int(time() - tic)
        logger.info(f'Epoch {epoch:03d} Training Finished. Time: {toc // 60}min {toc % 60}s.')
        tf.contrib.summary.scalar(name='train_epoch_loss', tensor=train_epoch_loss, step=epoch)
        logger.info(f'Epoch Training loss: {float(train_epoch_loss):.4f}')
        if use_train_metrics:
            logger.info(f'Epoch Training Metrics:')
            for idx, metric in enumerate(train_epoch_metrics, start=1):
                tf.contrib.summary.scalar(name=f'train_metric_{idx}', tensor=metric, step=epoch)
                logger.info(f'Train Metric {idx}: {float(metric):.4f}')

        # Validation
        tic = time()
        logger.info(f'\nStarting Epoch {epoch:03d} Validation')
        val_epoch_loss, val_epoch_metrics = \
            val_epoch(model=model, dataset=val_dataset, loss_func=loss_func, epoch=epoch,
                      max_images=max_images, verbose=verbose, use_metrics=use_val_metrics)
        # After Epoch validation is over.
        toc = int(time() - tic)
        tf.contrib.summary.scalar(name='val_epoch_loss', tensor=val_epoch_loss, step=epoch)
        logger.info(f'Epoch {epoch:03d} Validation Finished. Time: {toc // 60}min {toc % 60}s.')
        logger.info(f'Epoch Validation loss: {float(val_epoch_loss):.4f}')
        if use_val_metrics:
            logger.info(f'Epoch Validation Metrics:')
            for idx, metric in enumerate(val_epoch_metrics, start=1):
                tf.contrib.summary.scalar(name=f'val_metric_{idx}', tensor=metric, step=epoch)
                logger.info(f'Validation Metric {idx}: {float(metric):.4f}')

        # Checkpoint the Epoch if there has been improvement.
        # Not possible when the loss function keeps changing...
        if val_epoch_loss < prev_loss:
            prev_loss = val_epoch_loss
            logger.info('Validation loss has improved from previous epoch')
            logger.info(f'Last checkpoint file: {manager.latest_checkpoint}')
            manager.save(checkpoint_number=epoch)
        else:
            logger.info('Validation loss has not improved from previous epoch')
            logger.info(f'Previous minimum loss: {float(prev_loss):.4f}')
            if not save_best_only:
                manager.save(checkpoint_number=epoch)
def train_img(args):
    """K-space to complex-image (K2CI) training launcher.

    Creates run directories, chooses the device, builds the k-space input
    transforms, data loaders, UnetASE model, optimizer, and StepLR scheduler,
    then runs the ModelTrainerK2CI epoch loop.
    """
    # Maybe move this to args later.
    train_method = 'K2CI'

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)
    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # UNET architecture requires that all inputs be dividable by some power of 2.
    divisor = 2**args.num_pool_layers

    mask_func = MaskFunc(args.center_fractions, args.accelerations)
    data_prefetch = Prefetch2Device(device)
    input_train_transform = TrainPreProcessK(mask_func, args.challenge, args.device,
                                             use_seed=False, divisor=divisor)
    input_val_transform = TrainPreProcessK(mask_func, args.challenge, args.device,
                                           use_seed=True, divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(
        cmg_loss=nn.MSELoss(reduction='mean'),
        # img_loss=L1CSSIM7(reduction='mean', alpha=0.5)
        img_loss=CSSIM(filter_size=7, reduction='mean'))

    output_transform = OutputReplaceTransformK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag
    model = UnetASE(in_chans=data_chans, out_chans=data_chans, ext_chans=args.chans, chans=args.chans,
                    num_pool_layers=args.num_pool_layers, min_ext_size=args.min_ext_size,
                    max_ext_size=args.max_ext_size, use_ext_bias=args.use_ext_bias, use_att=False).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_red_epoch, gamma=args.lr_red_rate)
    trainer = ModelTrainerK2CI(args, model, optimizer, train_loader, val_loader,
                               input_train_transform, input_val_transform, output_transform, losses, scheduler)
    trainer.train_model()
def __init__(self, args, generator, discriminator, gen_optim, disc_optim,
             train_loader, val_loader, loss_funcs, gen_scheduler=None, disc_scheduler=None):
    """GAN trainer initializer.

    Validates generator/discriminator modules and their optimizers, wraps
    `loss_funcs` into an nn.ModuleDict, materializes the GAN targets and
    loss-weight scalars on the training device, and creates separate
    checkpoint managers for the generator and discriminator.
    """
    self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

    # Checking whether inputs are correct.
    assert isinstance(generator, nn.Module) and isinstance(discriminator, nn.Module), \
        '`generator` and `discriminator` must be Pytorch Modules.'
    assert isinstance(gen_optim, optim.Optimizer) and isinstance(disc_optim, optim.Optimizer), \
        '`gen_optim` and `disc_optim` must be Pytorch Optimizers.'
    assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
        '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

    loss_funcs = nn.ModuleDict(loss_funcs)  # Expected to be a dictionary with names and loss functions.

    if gen_scheduler is not None:
        # ReduceLROnPlateau needs a metric passed to step(); other schedulers do not.
        if isinstance(gen_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.metric_gen_scheduler = True
        elif isinstance(gen_scheduler, optim.lr_scheduler._LRScheduler):
            self.metric_gen_scheduler = False
        else:
            raise TypeError('`gen_scheduler` must be a Pytorch Learning Rate Scheduler.')

    if disc_scheduler is not None:
        if isinstance(disc_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.metric_disc_scheduler = True
        elif isinstance(disc_scheduler, optim.lr_scheduler._LRScheduler):
            self.metric_disc_scheduler = False
        else:
            raise TypeError('`disc_scheduler` must be a Pytorch Learning Rate Scheduler.')

    self.generator = generator
    self.discriminator = discriminator
    self.gen_optim = gen_optim
    self.disc_optim = disc_optim
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.loss_funcs = loss_funcs
    self.gen_scheduler = gen_scheduler
    self.disc_scheduler = disc_scheduler
    self.device = args.device
    self.verbose = args.verbose
    self.num_epochs = args.num_epochs
    self.writer = SummaryWriter(str(args.log_path))

    # Loss weights kept as device tensors so they combine with losses without transfers.
    self.recon_lambda = torch.tensor(args.recon_lambda, dtype=torch.float32, device=args.device)
    self.lambda_gp = torch.tensor(args.lambda_gp, dtype=torch.float32, device=args.device)

    # This will work best if batch size is 1, as is recommended. I don't know whether this generalizes.
    self.target_real = torch.tensor(1, dtype=torch.float32, device=args.device)
    self.target_fake = torch.tensor(0, dtype=torch.float32, device=args.device)

    # Display interval of 0 means no display of validation images on TensorBoard.
    if args.max_images <= 0:
        self.display_interval = 0
    else:
        self.display_interval = int(len(self.val_loader.dataset) // (args.max_images * args.batch_size))

    # Separate checkpoint managers so generator and discriminator can be
    # saved/restored independently.
    self.gen_checkpoint_manager = CheckpointManager(
        model=self.generator, optimizer=self.gen_optim, mode='min',
        save_best_only=args.save_best_only, ckpt_dir=args.ckpt_path / 'Generator',
        max_to_keep=args.max_to_keep)
    self.disc_checkpoint_manager = CheckpointManager(
        model=self.discriminator, optimizer=self.disc_optim, mode='min',
        save_best_only=args.save_best_only, ckpt_dir=args.ckpt_path / 'Discriminator',
        max_to_keep=args.max_to_keep)

    # loading from checkpoint if specified.
    if vars(args).get('gen_prev_model_ckpt'):
        self.gen_checkpoint_manager.load(load_dir=args.gen_prev_model_ckpt, load_optimizer=False)
    if vars(args).get('disc_prev_model_ckpt'):
        self.disc_checkpoint_manager.load(load_dir=args.disc_prev_model_ckpt, load_optimizer=False)
def train_k2i(args):
    """Weighted k-space to real-valued image (W2I) training launcher.

    Creates run directories, chooses the device, builds weighted k-space
    input transforms, data loaders, the UNetSkipGN model, optimizer, and
    MultiStepLR scheduler, then runs the ModelTrainerK2I epoch loop.
    """
    # Maybe move this to args later.
    train_method = 'W2I'  # Weighted K-space to real-valued image.

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)
    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # UNET architecture requires that all inputs be dividable by some power of 2.
    divisor = 2**args.num_pool_layers

    if args.random_sampling:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # This is optimized for SSD storage.
    # Sending to device should be inside the input transform for optimal performance on HDD.
    data_prefetch = Prefetch2Device(device)
    input_train_transform = WeightedPreProcessK(mask_func, args.challenge, device, use_seed=False, divisor=divisor)
    input_val_transform = WeightedPreProcessK(mask_func, args.challenge, device, use_seed=True, divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(img_loss=nn.L1Loss(reduction='mean')
                  # img_loss=L1CSSIM7(reduction='mean', alpha=args.alpha)
                  )

    output_transform = WeightedReplacePostProcessK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag
    model = UNetSkipGN(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                       num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
                       pool_type=args.pool_type, use_skip=args.use_skip, use_att=args.use_att,
                       reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)
    trainer = ModelTrainerK2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform, output_transform, losses, scheduler)
    trainer.train_model()
def train_xnet(args):
    """Entry point for XNet training: sets up directories, logging, data,
    transforms, model, optimizer and scheduler, then runs XNetModelTrainer.

    On KeyboardInterrupt the TensorBoard writer is closed so pending
    outputs are flushed before the process exits.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    # Fix: also log to a file, matching the other train_* entry points
    # (train_k2i / train_gan_model both pass save_file).
    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter
    # and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    # Fall back to the shared center_fractions/accelerations settings when the
    # train/val-specific ones are absent (older configurations).
    args.center_fractions_train = arguments.get('center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train', arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val', arguments.get('accelerations'))

    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val, args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val, args.accelerations_val)

    # Validation never augments and uses a fixed seed for reproducible masks.
    input_train_transform = PreProcessXNet(mask_func=train_mask_func, challenge=args.challenge,
                                           device=device, augment_data=args.augment_data,
                                           use_seed=False, crop_center=args.crop_center)
    input_val_transform = PreProcessXNet(mask_func=val_mask_func, challenge=args.challenge,
                                         device=device, augment_data=False,
                                         use_seed=True, crop_center=args.crop_center)
    output_train_transform = PostProcessXNet(challenge=args.challenge)
    output_val_transform = PostProcessXNet(challenge=args.challenge)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        phase_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        img_loss=LogSSIMLoss(filter_size=5).to(device=device),
        # img_loss=nn.L1Loss()
        x_loss=AlignmentLoss()
    )

    data_chans = 1 if args.challenge == 'singlecoil' else 15  # Multicoil has 15 coils with 2 for real/imag
    model = XNet(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                 num_pool_layers=args.num_pool_layers, num_depth_blocks=args.num_depth_blocks,
                 dilation=args.dilation, res_scale=args.res_scale).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs,
                                               gamma=args.lr_red_rate)

    trainer = XNetModelTrainer(args, model, optimizer, train_loader, val_loader,
                               input_train_transform, input_val_transform,
                               output_train_transform, output_val_transform,
                               losses, scheduler)
    try:
        trainer.train_model()
    except KeyboardInterrupt:
        # Flush TensorBoard outputs before exiting on manual interrupt.
        trainer.writer.close()
        # Fix: plain string (the original used an f-string with no placeholders).
        logger.warning('Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.')
def train_gan_model(args):
    """Entry point for GAN training: builds directories, data loaders,
    generator/discriminator pairs with their optimizers and schedulers,
    then runs GANModelTrainer."""
    # Checkpoint directory layout: ckpt_dir / train_method / run_name.
    ckpt_path = Path(args.ckpt_dir)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)

    run_number, run_name = initialize(ckpt_path)

    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    # Logging directory mirrors the checkpoint layout.
    log_path = Path(args.log_dir)
    for part in (args.train_method, run_name):
        log_path.mkdir(exist_ok=True)
        log_path = log_path / part
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Save args to json BEFORE attaching non-serializable objects below —
    # many objects (such as Path objects) cannot be serialized to json files.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Stash peripheral variables and objects on args to reduce clutter
    # and keep the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    train_transform = InputTrainTransform(is_training=True)
    val_transform = InputTrainTransform(is_training=False)

    # DataLoaders
    train_loader, val_loader = create_data_loaders(args, train_transform, val_transform)

    # Loss functions: adversarial term plus reconstruction term.
    loss_funcs = {
        'gan_loss_func': nn.BCELoss(reduction='mean'),
        'recon_loss_func': L1CSSIM(l1_weight=args.l1_weight, default_range=12,
                                   filter_size=7, reduction='mean')
        # 'recon_loss_func': nn.L1Loss(reduction='mean')
    }

    # Define model: generator / discriminator with separate optimizers and schedulers.
    generator = UnetModel(1, 1, args.chans, args.num_pool_layers, drop_prob=0).to(device)
    discriminator = SimpleCNN(chans=args.chans).to(device)

    gen_optim = optim.Adam(params=generator.parameters(), lr=args.init_lr)
    disc_optim = optim.Adam(params=discriminator.parameters(), lr=args.init_lr)

    gen_scheduler = optim.lr_scheduler.StepLR(gen_optim, step_size=args.step_size,
                                              gamma=args.lr_reduction_rate)
    disc_scheduler = optim.lr_scheduler.StepLR(disc_optim, step_size=args.step_size,
                                               gamma=args.lr_reduction_rate)

    trainer = GANModelTrainer(args, generator, discriminator, gen_optim, disc_optim,
                              train_loader, val_loader, loss_funcs,
                              gen_scheduler, disc_scheduler)
    trainer.train_model()