def train_cmg_to_img(args):
    """Train a complex-image (CMG) to image model.

    Builds checkpoint/log directories for this run, selects the device,
    records all settings to JSON, wires up transforms, data loaders, loss,
    model, optimizer and scheduler, then delegates to ModelTrainerIMG.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)  # presumably derives a fresh run number/name from existing runs — TODO confirm
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Sampling mask: random or uniform spacing of k-space lines.
    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # Training inputs are augmented and unseeded; validation is deterministic (seeded, no augmentation).
    input_train_transform = PreProcessCMG(mask_func, args.challenge, device, augment_data=args.augment_data,
                                          use_seed=False, crop_center=args.crop_center)
    input_val_transform = PreProcessCMG(mask_func, args.challenge, device, augment_data=False,
                                        use_seed=True, crop_center=args.crop_center)

    output_train_transform = PostProcessCMG(challenge=args.challenge, residual_acs=args.residual_acs)
    output_val_transform = PostProcessCMG(challenge=args.challenge, residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        img_loss=LogSSIMLoss(filter_size=7).to(device)
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=nn.L1Loss()
    )

    # model = UNet(
    #     in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
    #     negative_slope=args.negative_slope, use_residual=args.use_residual, interp_mode=args.interp_mode,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    # 30 channels: presumably 15 coils x 2 (real/imag) — TODO confirm against PreProcessCMG output.
    model = UNet(in_chans=30, out_chans=30, chans=args.chans, num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks, use_residual=args.use_residual, use_ca=args.use_ca,
                 reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerIMG(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        # Close the TensorBoard writer so buffered events are flushed on manual interrupt.
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
def train_xnet(args):
    """Train an XNet model with separate train/val sampling masks.

    Builds checkpoint/log directories for this run, selects the device,
    records all settings to JSON, resolves possibly-missing train/val mask
    settings from the generic ones, wires up transforms, data loaders,
    losses, model, optimizer and scheduler, then delegates to XNetModelTrainer.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    # Fall back to the generic center_fractions/accelerations when the
    # train/val-specific settings are absent from args.
    args.center_fractions_train = arguments.get('center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train', arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val', arguments.get('accelerations'))

    # Separate sampling masks for training and validation.
    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val, args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val, args.accelerations_val)

    # Training inputs are augmented and unseeded; validation is deterministic (seeded, no augmentation).
    input_train_transform = PreProcessXNet(mask_func=train_mask_func, challenge=args.challenge, device=device,
                                           augment_data=args.augment_data, use_seed=False,
                                           crop_center=args.crop_center)
    input_val_transform = PreProcessXNet(mask_func=val_mask_func, challenge=args.challenge, device=device,
                                         augment_data=False, use_seed=True, crop_center=args.crop_center)

    output_train_transform = PostProcessXNet(challenge=args.challenge)
    output_val_transform = PostProcessXNet(challenge=args.challenge)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        phase_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        img_loss=LogSSIMLoss(filter_size=5).to(device=device),
        # img_loss=nn.L1Loss()
        x_loss=AlignmentLoss())

    data_chans = 1 if args.challenge == 'singlecoil' else 15  # Multicoil has 15 coils with 2 for real/imag
    model = XNet(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                 num_pool_layers=args.num_pool_layers, num_depth_blocks=args.num_depth_blocks,
                 dilation=args.dilation, res_scale=args.res_scale).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = XNetModelTrainer(args, model, optimizer, train_loader, val_loader,
                               input_train_transform, input_val_transform,
                               output_train_transform, output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        trainer.writer.close()
        # Fixed: message had a pointless f-string prefix (no placeholders).
        logger.warning('Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.')
def train_complex(args):
    """Train a complex-valued (real/imag channel) model on weighted (semi-)k-space data.

    Supports two train methods: 'WS2C' (weighted semi-k-space to complex) and
    'WK2C' (weighted k-space to complex). Any other value raises
    NotImplementedError. Delegates the actual loop to ModelTrainerCOMPLEX.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    # Unlike some sibling functions, this logger also writes to a file under log_path.
    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # UNET architecture requires that all inputs be dividable by some power of 2.
    divisor = 2 ** args.num_pool_layers

    if args.random_sampling:
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    if args.train_method == 'WS2C':  # Semi-k-space learning.
        weight_func = SemiDistanceWeight(weight_type=args.weight_type)
        input_train_transform = PreProcessWSK(mask_func, weight_func, args.challenge, device,
                                              use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWSK(mask_func, weight_func, args.challenge, device,
                                            use_seed=True, divisor=divisor)
        output_transform = WeightedReplacePostProcessSemiK(weighted=True, replace=args.replace)

    elif args.train_method == 'WK2C':  # k-space learning.
        weight_func = TiltedDistanceWeight(weight_type=args.weight_type, y_scale=args.y_scale)
        input_train_transform = PreProcessWK(mask_func, weight_func, args.challenge, device,
                                             use_seed=False, divisor=divisor)
        input_val_transform = PreProcessWK(mask_func, weight_func, args.challenge, device,
                                           use_seed=True, divisor=divisor)
        output_transform = WeightedReplacePostProcessK(weighted=True, replace=args.replace)

    else:
        raise NotImplementedError('Invalid train method!')

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(cmg_loss=nn.MSELoss(reduction='mean'))

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    # model = UNetModel(
    #     in_chans=data_chans, out_chans=data_chans, chans=args.chans, num_pool_layers=args.num_pool_layers,
    #     num_groups=args.num_groups, use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
    #     use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
    #     use_sa=args.use_sa, sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation, use_cap=args.use_cap,
    #     use_cmp=args.use_cmp).to(device)

    model = UNetModelKSSE(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                          num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
                          use_residual=args.use_residual, pool_type=args.pool_type, use_skip=args.use_skip,
                          min_ext_size=args.min_ext_size, max_ext_size=args.max_ext_size, ext_mode=args.ext_mode,
                          use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp,
                          use_sa=args.use_sa, sa_kernel_size=args.sa_kernel_size, sa_dilation=args.sa_dilation,
                          use_cap=args.use_cap, use_cmp=args.use_cmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerCOMPLEX(args, model, optimizer, train_loader, val_loader,
                                  input_train_transform, input_val_transform,
                                  output_transform, losses, scheduler)

    # NOTE(review): no KeyboardInterrupt handling here, unlike sibling train functions — confirm intended.
    trainer.train_model()
def train_img_to_rss(args):
    """Train an image to RSS (root-sum-of-squares) reconstruction model.

    Sets up run directories, device, JSON settings dump, train/val mask
    functions, RSS pre/post-processing transforms, SSIM loss, UNet model and
    optimizer/scheduler, then delegates to ModelTrainerRSS.train_model_concat().
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    # Fall back to the generic center_fractions/accelerations when the
    # train/val-specific settings are absent from args.
    args.center_fractions_train = arguments.get('center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train', arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val', arguments.get('accelerations'))

    # Separate sampling masks for training and validation.
    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val, args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val, args.accelerations_val)

    # Training inputs are augmented and unseeded; validation is deterministic (seeded, no augmentation).
    input_train_transform = PreProcessRSS(mask_func=train_mask_func, challenge=args.challenge, device=device,
                                          augment_data=args.augment_data, use_seed=False, fat_info=args.fat_info)
    input_val_transform = PreProcessRSS(mask_func=val_mask_func, challenge=args.challenge, device=device,
                                        augment_data=False, use_seed=True, fat_info=args.fat_info)

    output_train_transform = PostProcessRSS(challenge=args.challenge, residual_rss=args.residual_rss)
    output_val_transform = PostProcessRSS(challenge=args.challenge, residual_rss=args.residual_rss)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(rss_loss=SSIMLoss(filter_size=7).to(device=device))

    # One extra input channel carries the fat-info map when enabled.
    in_chans = 16 if args.fat_info else 15

    model = UNet(in_chans=in_chans, out_chans=1, chans=args.chans, num_pool_layers=args.num_pool_layers,
                 num_res_groups=args.num_res_groups, num_res_blocks_per_group=args.num_res_blocks_per_group,
                 growth_rate=args.growth_rate, num_dense_layers=args.num_dense_layers,
                 use_dense_ca=args.use_dense_ca, num_res_layers=args.num_res_layers, res_scale=args.res_scale,
                 reduction=args.reduction, thick_base=args.thick_base).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    # NOTE(review): this function uses args.milestones while sibling functions use args.lr_red_epochs —
    # confirm this naming difference is intentional.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=args.milestones,
                                               gamma=args.lr_red_rate)

    trainer = ModelTrainerRSS(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform, losses, scheduler)

    try:
        trainer.train_model_concat()
    except KeyboardInterrupt:
        # Close the TensorBoard writer so buffered events are flushed on manual interrupt.
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
def train_k2i(args):
    """Train a weighted k-space to real-valued image ('W2I') model.

    The train method is hard-coded locally rather than read from args.
    Delegates the training loop to ModelTrainerK2I.
    """
    # Maybe move this to args later.
    train_method = 'W2I'  # Weighted K-space to real-valued image.

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # UNET architecture requires that all inputs be dividable by some power of 2.
    divisor = 2 ** args.num_pool_layers

    # NOTE(review): sibling functions use RandomMaskFunc for random sampling; this one uses MaskFunc —
    # confirm MaskFunc is the intended (random) mask here.
    if args.random_sampling:
        mask_func = MaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # This is optimized for SSD storage.
    # Sending to device should be inside the input transform for optimal performance on HDD.
    data_prefetch = Prefetch2Device(device)

    input_train_transform = WeightedPreProcessK(mask_func, args.challenge, device, use_seed=False, divisor=divisor)
    input_val_transform = WeightedPreProcessK(mask_func, args.challenge, device, use_seed=True, divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(img_loss=nn.L1Loss(reduction='mean')
                  # img_loss=L1CSSIM7(reduction='mean', alpha=args.alpha)
                  )

    output_transform = WeightedReplacePostProcessK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNetSkipGN(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                       num_pool_layers=args.num_pool_layers, num_groups=args.num_groups,
                       pool_type=args.pool_type, use_skip=args.use_skip, use_att=args.use_att,
                       reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerK2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_transform, losses, scheduler)

    trainer.train_model()
def train_cmg_to_img_direct(args):
    """Train a direct complex-image to image model (45 input, 15 output channels).

    Sets up run directories, device, JSON settings dump, train/val mask
    functions and CMGIMG transforms with output_mode='img', then delegates
    the training loop to ModelTrainerI2I.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # # UNET architecture requires that all inputs be dividable by some power of 2.
    # divisor = 2 ** args.num_pool_layers

    arguments = vars(args)  # Placed here for backward compatibility and convenience.
    # Fall back to the generic center_fractions/accelerations when the
    # train/val-specific settings are absent from args.
    args.center_fractions_train = arguments.get('center_fractions_train', arguments.get('center_fractions'))
    args.center_fractions_val = arguments.get('center_fractions_val', arguments.get('center_fractions'))
    args.accelerations_train = arguments.get('accelerations_train', arguments.get('accelerations'))
    args.accelerations_val = arguments.get('accelerations_val', arguments.get('accelerations'))

    # Separate sampling masks for training and validation.
    if args.random_sampling:
        train_mask_func = RandomMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = RandomMaskFunc(args.center_fractions_val, args.accelerations_val)
    else:
        train_mask_func = UniformMaskFunc(args.center_fractions_train, args.accelerations_train)
        val_mask_func = UniformMaskFunc(args.center_fractions_val, args.accelerations_val)

    # Training inputs are augmented and unseeded; validation is deterministic (seeded, no augmentation).
    input_train_transform = PreProcessCMGIMG(mask_func=train_mask_func, challenge=args.challenge, device=device,
                                             augment_data=args.augment_data, use_seed=False,
                                             crop_center=args.crop_center)
    input_val_transform = PreProcessCMGIMG(mask_func=val_mask_func, challenge=args.challenge, device=device,
                                           augment_data=False, use_seed=True, crop_center=args.crop_center)

    output_train_transform = PostProcessCMGIMG(challenge=args.challenge, output_mode='img')
    output_val_transform = PostProcessCMGIMG(challenge=args.challenge, output_mode='img')

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss()
        # img_loss=L1SSIMLoss(filter_size=7, l1_ratio=args.l1_ratio).to(device=device)
    )

    # data_chans = 1 if args.challenge == 'singlecoil' else 15
    # 45 in / 15 out channels; presumably 15 coils x 3 input representations — TODO confirm
    # against PreProcessCMGIMG output.
    model = UNet(in_chans=45, out_chans=15, chans=args.chans, num_pool_layers=args.num_pool_layers,
                 num_depth_blocks=args.num_depth_blocks, res_scale=args.res_scale, use_residual=False,
                 use_ca=args.use_ca, reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerI2I(args, model, optimizer, train_loader, val_loader,
                              input_train_transform, input_val_transform,
                              output_train_transform, output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        # Close the TensorBoard writer so buffered events are flushed on manual interrupt.
        trainer.writer.close()
        logger.warning('Closing summary writer due to KeyboardInterrupt.')
def main(args):
    """Run a residual-learning training loop for a single-channel UNet.

    Creates run directories, saves settings, builds datasets/loaders, then
    trains with L1 loss on (residual + input) vs. label, logging to
    TensorBoard and checkpointing on validation-loss improvement.

    Bug fixed: settings JSON was serialized from ``vars(save_dict_as_json)``
    (the function object's attributes) instead of ``vars(args)``.
    """
    ckpt_path = Path('checkpoints')
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)
    args.ckpt_path = ckpt_path

    log_path = Path('logs')
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)
    args.log_path = log_path

    # Fixed: was vars(save_dict_as_json), which saved the wrong object's attributes.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        args.device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        args.device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Create Datasets
    train_dataset = TrainingDataset(args.train_root, transform=slice_normalize_and_clip, single_coil=False)
    val_dataset = TrainingDataset(  # Shrink val time by half. Consistent comparison remains.
        args.val_root, transform=slice_normalize_and_clip, single_coil=False, use_double=False, seed=9872)

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True)

    # Define model, optimizer, and loss function.
    model = UnetModel(in_chans=1, out_chans=1, chans=32, num_pool_layers=4).to(args.device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    loss_func = nn.L1Loss(reduction='mean').to(args.device, non_blocking=True)

    checkpointer = CheckpointManager(model, optimizer, args.save_best_only, ckpt_path, args.max_to_keep)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
                                                     verbose=True, cooldown=0, min_lr=1E-7)

    writer = SummaryWriter(log_dir=str(log_path))
    example_inputs = torch.ones(size=(args.batch_size, 1, 320, 320)).to(args.device)
    writer.add_graph(model=model, input_to_model=example_inputs, verbose=False)

    previous_best = np.inf  # Smaller validation loss is better.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, args.num_epochs + 1):
        tic = time()
        model.train()
        torch.autograd.set_grad_enabled(True)
        logger.info(f'Beginning training for Epoch {epoch:03d}')

        # Reset to 0. There must be a better way though...
        train_loss_sum = torch.as_tensor(0.)
        tot_len = len(train_loader.dataset) // args.batch_size

        for idx, (ds_imgs, labels) in tqdm(enumerate(train_loader, start=1), total=tot_len):
            ds_imgs = ds_imgs.to(args.device).unsqueeze(dim=1)
            labels = labels.to(args.device).unsqueeze(dim=1)

            optimizer.zero_grad()
            pred_residuals = model(ds_imgs)  # Predicted residuals are outputs.
            recons = pred_residuals + ds_imgs
            step_loss = loss_func(recons, labels)
            step_loss.backward()
            optimizer.step()

            with torch.no_grad():  # Necessary for metric calculation without errors due to gradient accumulation.
                train_loss_sum += step_loss
                if args.verbose:
                    print(f'Training loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')
        else:  # for-else: runs after the loop completes normally.
            toc = int(time() - tic)
            # Last step with small batch causes some inaccuracy but that is tolerable.
            epoch_loss = train_loss_sum.item() * args.batch_size / len(train_loader.dataset)
            logger.info(f'Epoch {epoch:03d} Training. loss: {float(epoch_loss):.4f}, '
                        f'Time: {toc // 60}min {toc % 60}s')
            writer.add_scalar('train_loss', epoch_loss, global_step=epoch)

        tic = time()
        model.eval()
        torch.autograd.set_grad_enabled(False)
        logger.info(f'Beginning validation for Epoch {epoch:03d}')

        # Reset to 0. There must be a better way though...
        val_loss_sum = torch.as_tensor(0.)

        # Need to return residual and labels for residual learning.
        tot_len = len(val_loader.dataset) // args.batch_size
        for idx, (ds_imgs, labels) in tqdm(enumerate(val_loader, start=1), total=tot_len):
            ds_imgs = ds_imgs.to(args.device).unsqueeze(dim=1)
            labels = labels.to(args.device).unsqueeze(dim=1)

            pred_residuals = model(ds_imgs)
            recons = pred_residuals + ds_imgs
            step_loss = loss_func(recons, labels)
            val_loss_sum += step_loss

            if args.verbose:
                print(f'Validation loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')

            if args.max_imgs:
                pass  # TODO: Implement saving images to TensorBoard.
        else:  # for-else: runs after the loop completes normally.
            toc = int(time() - tic)
            epoch_loss = val_loss_sum.item() * args.batch_size / len(val_loader.dataset)
            logger.info(f'Epoch {epoch:03d} Validation. loss: {float(epoch_loss):.4f} '
                        f'Time: {toc // 60}min {toc % 60}s')
            writer.add_scalar('val_loss', epoch_loss, global_step=epoch)

            scheduler.step(metrics=epoch_loss, epoch=epoch)  # Changes optimizer lr.

            # Checkpoint generation. Only implemented for single GPU models, not multi-gpu models.
            # All comparisons are done with python numbers, not tensors.
            if epoch_loss < previous_best:  # Assumes smaller metric is better.
                logger.info(f'Loss in Epoch {epoch} has improved from '
                            f'{float(previous_best):.4f} to {float(epoch_loss):.4f}')
                previous_best = epoch_loss
                checkpointer.save(is_best=True)
            else:
                logger.info(f'Loss in Epoch {epoch} has not improved from the previous best epoch')
                checkpointer.save(is_best=False)  # No save if save_best_only is True.
def train_img(args):
    """Train a k-space to complex-image ('K2CI') model.

    The train method is hard-coded locally rather than read from args.
    Delegates the training loop to ModelTrainerK2CI.
    """
    # Maybe move this to args later.
    train_method = 'K2CI'

    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    # Assignment inside running code appears to work.
    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # Input transforms. These are on a per-slice basis.
    # UNET architecture requires that all inputs be dividable by some power of 2.
    divisor = 2 ** args.num_pool_layers

    # NOTE(review): no random/uniform sampling branch here, unlike sibling functions — MaskFunc only.
    mask_func = MaskFunc(args.center_fractions, args.accelerations)

    data_prefetch = Prefetch2Device(device)

    # NOTE(review): these pass args.device (set just above to the same object as the local `device`).
    input_train_transform = TrainPreProcessK(mask_func, args.challenge, args.device,
                                             use_seed=False, divisor=divisor)
    input_val_transform = TrainPreProcessK(mask_func, args.challenge, args.device,
                                           use_seed=True, divisor=divisor)

    # DataLoaders
    train_loader, val_loader = create_custom_data_loaders(args, transform=data_prefetch)

    losses = dict(
        cmg_loss=nn.MSELoss(reduction='mean'),
        # img_loss=L1CSSIM7(reduction='mean', alpha=0.5)
        img_loss=CSSIM(filter_size=7, reduction='mean'))

    output_transform = OutputReplaceTransformK()

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UnetASE(in_chans=data_chans, out_chans=data_chans, ext_chans=args.chans, chans=args.chans,
                    num_pool_layers=args.num_pool_layers, min_ext_size=args.min_ext_size,
                    max_ext_size=args.max_ext_size, use_ext_bias=args.use_ext_bias, use_att=False).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    # Single-step decay here (StepLR with args.lr_red_epoch), unlike the MultiStepLR used elsewhere.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_red_epoch, gamma=args.lr_red_rate)

    trainer = ModelTrainerK2CI(args, model, optimizer, train_loader, val_loader,
                               input_train_transform, input_val_transform,
                               output_transform, losses, scheduler)

    trainer.train_model()
def train_cmg_and_img(args):
    """Train a model supervised on both complex-image and image losses.

    Uses weighted semi-k-space CC transforms with a no-op weight function
    (weighted=False in post-processing). Delegates the loop to ModelTrainerCI.
    """
    # Creating checkpoint and logging directories, as well as the run name.
    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / args.train_method
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / args.train_method
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        device = torch.device('cpu')
        logger.info(f'Using CPU for {run_name}')

    # Saving peripheral variables and objects in args to reduce clutter and make the structure flexible.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = ckpt_path
    args.log_path = log_path
    args.device = device

    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    if args.random_sampling:  # Same as in the challenge
        mask_func = RandomMaskFunc(args.center_fractions, args.accelerations)
    else:
        mask_func = UniformMaskFunc(args.center_fractions, args.accelerations)

    # no_weight: transforms run un-weighted; post-processing matches with weighted=False.
    input_train_transform = PreProcessWSemiKCC(mask_func=mask_func, weight_func=no_weight,
                                               challenge=args.challenge, device=device, use_seed=False)
    input_val_transform = PreProcessWSemiKCC(mask_func=mask_func, weight_func=no_weight,
                                             challenge=args.challenge, device=device, use_seed=True)

    output_train_transform = PostProcessWSemiKCC(args.challenge, weighted=False,
                                                 residual_acs=args.residual_acs)
    output_val_transform = PostProcessWSemiKCC(args.challenge, weighted=False,
                                               residual_acs=args.residual_acs)

    # DataLoaders
    train_loader, val_loader = create_prefetch_data_loaders(args)

    losses = dict(
        cmg_loss=nn.MSELoss(),
        # img_loss=SSIMLoss(filter_size=7).to(device=device)
        # img_loss=LogSSIMLoss(filter_size=7).to(device=device)
        img_loss=nn.L1Loss())

    data_chans = 2 if args.challenge == 'singlecoil' else 30  # Multicoil has 15 coils with 2 for real/imag

    model = UNet(in_chans=data_chans, out_chans=data_chans, chans=args.chans,
                 num_pool_layers=args.num_pool_layers, num_depth_blocks=args.num_depth_blocks,
                 num_groups=args.num_groups, negative_slope=args.negative_slope,
                 use_residual=args.use_residual, interp_mode=args.interp_mode, use_ca=args.use_ca,
                 reduction=args.reduction, use_gap=args.use_gap, use_gmp=args.use_gmp).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_red_epochs, gamma=args.lr_red_rate)

    trainer = ModelTrainerCI(args, model, optimizer, train_loader, val_loader,
                             input_train_transform, input_val_transform,
                             output_train_transform, output_val_transform, losses, scheduler)

    try:
        trainer.train_model()
    except KeyboardInterrupt:
        # Close the TensorBoard writer so buffered events are flushed on manual interrupt.
        trainer.writer.close()
        logger.warning(
            f'Closing TensorBoard writer and flushing remaining outputs due to KeyboardInterrupt.'
        )
def main():
    """Train ResNet-50 on CIFAR-100 with a hand-rolled train/val loop.

    Tracks top-1/top-5 accuracy per epoch and checkpoints whenever the
    top-5 validation accuracy improves (or every epoch when
    ``save_best_only`` is False).
    """
    # Hyper-parameters and paths are hard-coded here rather than taken
    # from argparse.
    batch_size = 8
    num_workers = 4
    init_lr = 2E-4
    gpu = 1  # Set to None for CPU mode.
    num_epochs = 500
    verbose = False
    save_best_only = True
    data_root = '/home/veritas/PycharmProjects/PA1/data'
    log_root = '/home/veritas/PycharmProjects/PA1/logs'
    ckpt_root = '/home/veritas/PycharmProjects/PA1/checkpoints'

    # Per-run checkpoint and log directories named by initialize().
    ckpt_path = Path(ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    run_number, run_name = initialize(ckpt_path)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')

    # Do more fancy transforms later.
    transform = torchvision.transforms.ToTensor()
    train_dataset = torchvision.datasets.CIFAR100(data_root, train=True, transform=transform, download=True)
    val_dataset = torchvision.datasets.CIFAR100(data_root, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # Define model, optimizer, etc.
    model = torchvision.models.resnet50(pretrained=False, num_classes=100).to(device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = CrossEntropyLoss().to(device, non_blocking=True)

    # Getting data collection
    # Running sums/counters kept as tensors so in-place += with GPU step
    # losses works; read back with .item() per epoch.
    train_loss_sum = torch.as_tensor(0.)
    val_loss_sum = torch.as_tensor(0.)
    train_top1_correct = torch.as_tensor(0)
    val_top1_correct = torch.as_tensor(0)
    train_top5_correct = torch.as_tensor(0)
    val_top5_correct = torch.as_tensor(0)

    previous_best = 0.  # Best top-5 validation accuracy so far.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, num_epochs + 1):
        tic = time()
        model.train()
        torch.autograd.set_grad_enabled(True)
        for idx, (images, labels) in enumerate(train_loader, start=1):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            preds = model(images)
            # Pytorch uses (input, target) ordering. Their shapes are also different.
            step_loss = loss_func(preds, labels)
            step_loss.backward()
            optimizer.step()

            with torch.no_grad():  # Necessary for metric calculation without errors due to gradient accumulation.
                train_loss_sum += step_loss
                top1 = torch.argmax(preds, dim=1)
                _, top5 = torch.topk(preds, k=5)  # Get only the indices.
                train_top1_correct += torch.sum(torch.eq(top1, labels))
                # Broadcast label against each of the 5 candidates; at most one can match.
                train_top5_correct += torch.sum(torch.eq(top5, labels.unsqueeze(-1)))  # This is probably right.

                if verbose:
                    print(f'Training loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')
        else:
            # NOTE(review): this 'else' is attached to the for-loop (runs once
            # after the epoch completes; there is no break) — the per-epoch
            # summary, not a branch of 'if verbose'.
            toc = int(time() - tic)
            # Last step with small batch causes some inaccuracy but that is tolerable.
            epoch_loss = train_loss_sum.item() * batch_size / len(train_loader.dataset)
            epoch_top1_acc = train_top1_correct.item() / len(train_loader.dataset) * 100
            epoch_top5_acc = train_top5_correct.item() / len(train_loader.dataset) * 100
            msg = f'loss: {epoch_loss:.4f}, top1 accuracy: {epoch_top1_acc:.2f}%, top5 accuracy: {epoch_top5_acc:.2f}%'
            logger.info(f'Epoch {epoch:03d} Training. {msg} Time: {toc}s')
            # writer.add_scalar('train_loss', epoch_loss, global_step=epoch)
            # writer.add_scalar('train_acc', epoch_top1_acc, global_step=epoch)

        # Reset to 0. There must be a better way though...
        train_loss_sum = torch.as_tensor(0.)
        train_top1_correct = torch.as_tensor(0)
        train_top5_correct = torch.as_tensor(0)

        # Validation pass: same bookkeeping with gradients globally disabled.
        tic = time()
        model.eval()
        torch.autograd.set_grad_enabled(False)
        for idx, (images, labels) in enumerate(val_loader, start=1):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            preds = model(images)

            top1 = torch.argmax(preds, dim=1)
            _, top5 = torch.topk(preds, k=5)
            val_top1_correct += torch.sum(torch.eq(top1, labels))
            val_top5_correct += torch.sum(torch.eq(top5, labels.unsqueeze(-1)))

            step_loss = loss_func(preds, labels)
            val_loss_sum += step_loss

            if verbose:
                print(f'Validation loss Epoch {epoch:03d} Step {idx:03d}: {step_loss.item()}')
        else:
            # Post-loop epoch summary (for-else, as above).
            toc = int(time() - tic)
            epoch_loss = val_loss_sum.item() * batch_size / len(val_loader.dataset)
            epoch_top1_acc = val_top1_correct.item() / len(val_loader.dataset) * 100
            epoch_top5_acc = val_top5_correct.item() / len(val_loader.dataset) * 100
            msg = f'loss: {epoch_loss:.4f}, top1 accuracy: {epoch_top1_acc:.2f}%, top5 accuracy: {epoch_top5_acc:.2f}%'
            logger.info(f'Epoch {epoch:03d} Validation. {msg} Time: {toc}s')
            # writer.add_scalar('val_loss', epoch_loss, global_step=epoch)
            # writer.add_scalar('val_acc', epoch_top1_acc, global_step=epoch)

        # Reset to 0. There must be a better way though...
        val_loss_sum = torch.as_tensor(0.)
        val_top1_correct = torch.as_tensor(0)
        val_top5_correct = torch.as_tensor(0)

        # Checkpoint generation. Only implemented for single GPU models, not multi-gpu models.
        # All comparisons are done with python numbers, not tensors.
        # epoch_top5_acc here holds the *validation* value (the variable is
        # reused for train and val summaries).
        if epoch_top5_acc > previous_best:  # Assumes larger metric is better.
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has improved from '
                        f'{previous_best:.2f}% to {epoch_top5_acc:.2f}%')
            previous_best = epoch_top5_acc

            save_dict = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}
            torch.save(save_dict, ckpt_path / f'epoch_{epoch:03d}.tar')
        else:
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has not improved from the previous best epoch')
            if not save_best_only:
                save_dict = {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}
                torch.save(save_dict, ckpt_path / f'epoch_{epoch:03d}.tar')
def train_model(model, args):
    """Train an arbitrary classifier on CIFAR-100 with TensorBoard logging.

    Refactored variant of ``main()``: per-epoch work is delegated to
    ``train_epoch``/``eval_epoch`` and checkpointing to ``CheckpointManager``.

    Args:
        model: The nn.Module to train (moved to the selected device here).
        args: Namespace with data/log/ckpt roots, loader and optimizer
            hyper-parameters, scheduler settings, and checkpoint options.
    """
    assert isinstance(model, nn.Module)

    # Beginning session.
    run_number, run_name = initialize(args.ckpt_root)

    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    # Saving args for later use.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    # train_transform()/val_transform() are project helpers returning the
    # augmentation pipelines — defined elsewhere in this file/package.
    dataset_kwargs = dict(root=args.data_root, download=True)
    train_dataset = torchvision.datasets.CIFAR100(train=True, transform=train_transform(), **dataset_kwargs)
    val_dataset = torchvision.datasets.CIFAR100(train=False, transform=val_transform(), **dataset_kwargs)

    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Define model, optimizer, etc.
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss(reduction='mean').to(device)

    # LR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    # Create checkpoint manager
    # mode='max': larger metric (top-5 accuracy) counts as an improvement.
    checkpointer = CheckpointManager(model, optimizer, mode='max', save_best_only=args.save_best_only,
                                     ckpt_dir=ckpt_path, max_to_keep=args.max_to_keep)

    # Tensorboard Writer
    writer = SummaryWriter(log_dir=str(log_path))

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, args.num_epochs + 1):
        # Start of training
        tic = time()
        train_loss_sum, train_top1_correct, train_top5_correct = \
            train_epoch(model, optimizer, loss_func, train_loader, device, epoch, args.verbose)
        toc = int(time() - tic)

        # Last step with small batch causes some inaccuracy but that is tolerable.
        train_epoch_loss = train_loss_sum.item() * args.batch_size / len(train_loader.dataset)
        train_epoch_top1_acc = train_top1_correct.item() / len(train_loader.dataset) * 100
        train_epoch_top5_acc = train_top5_correct.item() / len(train_loader.dataset) * 100

        logger.info(f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4e}, top1 accuracy: '
                    f'{train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s')

        # Writing to Tensorboard
        writer.add_scalar('train_epoch_loss', train_epoch_loss, epoch)
        writer.add_scalar('train_epoch_top1_acc', train_epoch_top1_acc, epoch)
        writer.add_scalar('train_epoch_top5_acc', train_epoch_top5_acc, epoch)

        # Start of evaluation
        tic = time()
        val_loss_sum, val_top1_correct, val_top5_correct = \
            eval_epoch(model, loss_func, val_loader, device, epoch, args.verbose)
        toc = int(time() - tic)

        val_epoch_loss = val_loss_sum.item() * args.batch_size / len(val_loader.dataset)
        val_epoch_top1_acc = val_top1_correct.item() / len(val_loader.dataset) * 100
        val_epoch_top5_acc = val_top5_correct.item() / len(val_loader.dataset) * 100

        logger.info(f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4e}, top1 accuracy: '
                    f'{val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}% Time: {toc}s')

        # Writing to Tensorboard
        writer.add_scalar('val_epoch_loss', val_epoch_loss, epoch)
        writer.add_scalar('val_epoch_top1_acc', val_epoch_top1_acc, epoch)
        writer.add_scalar('val_epoch_top5_acc', val_epoch_top5_acc, epoch)

        # Log the current learning rate of each parameter group.
        for idx, group in enumerate(optimizer.param_groups, start=1):
            writer.add_scalar(f'learning_rate_{idx}', group['lr'], epoch)

        # Things to do after each epoch.
        scheduler.step()  # Reduces LR at the designated times. Probably does not use 1 indexing like me.
        checkpointer.save(metric=val_epoch_top5_acc)
def main():
    """Train SE-ResNet-50 on CIFAR-100, delegating epochs to helper functions.

    Uses ``train_epoch``/``eval_epoch`` for the per-epoch work and
    ``CheckpointManager`` for saving; tracks best top-5 validation accuracy.
    """
    # Put these in args later.
    batch_size = 12
    num_workers = 8
    init_lr = 2E-4
    gpu = 0  # Set to None for CPU mode.
    num_epochs = 500
    verbose = False
    save_best_only = True
    max_to_keep = 100
    data_root = '/home/veritas/PycharmProjects/PA1/data'
    ckpt_root = '/home/veritas/PycharmProjects/PA1/checkpoints'
    log_root = '/home/veritas/PycharmProjects/PA1/logs'

    # Beginning session.
    run_number, run_name = initialize(ckpt_root)

    ckpt_path = Path(ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')

    # Do more fancy transforms later.
    transform = torchvision.transforms.ToTensor()
    train_dataset = torchvision.datasets.CIFAR100(data_root, train=True, transform=transform, download=True)
    val_dataset = torchvision.datasets.CIFAR100(data_root, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    # Define model, optimizer, etc.
    model = se_resnet50_cifar100().to(device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)
    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss().to(device, non_blocking=True)

    # Create checkpoint manager
    # NOTE(review): positional args here (ckpt_path, save_best_only,
    # max_to_keep) must match CheckpointManager's parameter order — the
    # keyword form used in train_model() suggests (ckpt_dir, save_best_only,
    # max_to_keep); verify.
    checkpointer = CheckpointManager(model, optimizer, ckpt_path, save_best_only, max_to_keep)

    # For recording data.
    previous_best = 0.  # Accuracy should improve.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, num_epochs + 1):
        # Start of training
        tic = time()
        train_loss_sum, train_top1_correct, train_top5_correct = \
            train_epoch(model, optimizer, loss_func, train_loader, device, epoch, verbose)
        toc = int(time() - tic)

        # Last step with small batch causes some inaccuracy but that is tolerable.
        train_epoch_loss = train_loss_sum.item() * batch_size / len(train_loader.dataset)
        train_epoch_top1_acc = train_top1_correct.item() / len(train_loader.dataset) * 100
        train_epoch_top5_acc = train_top5_correct.item() / len(train_loader.dataset) * 100

        msg = f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4f}, ' \
              f'top1 accuracy: {train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s'
        logger.info(msg)

        # Start of evaluation
        tic = time()
        val_loss_sum, val_top1_correct, val_top5_correct = \
            eval_epoch(model, loss_func, val_loader, device, epoch, verbose)
        toc = int(time() - tic)

        val_epoch_loss = val_loss_sum.item() * batch_size / len(val_loader.dataset)
        val_epoch_top1_acc = val_top1_correct.item() / len(val_loader.dataset) * 100
        val_epoch_top5_acc = val_top5_correct.item() / len(val_loader.dataset) * 100

        msg = f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4f}, ' \
              f'top1 accuracy: {val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}% Time: {toc}s'
        logger.info(msg)

        # Save a checkpoint; flag it as best when top-5 val accuracy improves.
        if val_epoch_top5_acc > previous_best:  # Assumes larger metric is better.
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has improved from '
                        f'{previous_best:.2f}% to {val_epoch_top5_acc:.2f}%')
            previous_best = val_epoch_top5_acc
            checkpointer.save(is_best=True)
        else:
            logger.info(f'Top 5 Validation Accuracy in Epoch {epoch} has not improved from the previous best epoch')
            checkpointer.save(is_best=False)
def main(argv):  # argv are non-flag parameters
    """Train a UNet for MRI reconstruction with TF 1.x eager execution.

    Uses absl FLAGS for configuration, tf.train.Checkpoint/CheckpointManager
    for saving, tf.contrib.summary for TensorBoard, and delegates the loop to
    ``train_and_eval``. Returns 0 on completion (absl-app convention).
    """
    del argv
    start = time()
    tf.print('Tensorflow Engaged')

    run_number, run_name = initialize(FLAGS.ckpt_dir)

    # exist_ok=False on the run directory: fail loudly if the run name
    # already exists rather than mixing logs from two runs.
    log_path = Path(FLAGS.log_dir)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=False)

    logger = get_logger(__name__)

    data_path = Path(FLAGS.data_dir)
    train_path = data_path / f'{FLAGS.challenge}_train'
    val_path = data_path / f'{FLAGS.challenge}_val'

    normalize = False

    # SSIM-based loss; max_val=15.0 is a data-range guess for this dataset.
    def loss_func(labels, predictions):
        return 1 - tf.image.ssim(labels, predictions, max_val=15.0)  # Sort of heuristic method...

    train_dataset = HDF5Sequence(data_dir=train_path, batch_size=FLAGS.batch_size, training=True, normalize=normalize)
    val_dataset = HDF5Sequence(data_dir=val_path, batch_size=FLAGS.batch_size, training=False, normalize=normalize)

    model = make_unet_model(scope='UNET', input_shape=(320, 320, 1))

    tf.keras.utils.plot_model(model, to_file=log_path / f'model_{run_number:02d}.png', show_shapes=True)
    model.summary()

    ckpt_path = Path(FLAGS.ckpt_dir) / run_name
    ckpt_path.mkdir(exist_ok=True)

    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr)  # Can't change lr easily until TF2.0 update...
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

    # Optionally resume: restore the latest checkpoint and assert that the
    # restore actually matched existing objects.
    if FLAGS.restore_dir:
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.restore_dir)
        checkpoint.restore(latest_checkpoint).assert_nontrivial_match().assert_existing_objects_matched()
        print(f'Restored model from {latest_checkpoint}')

    manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=str(ckpt_path), max_to_keep=FLAGS.max_to_keep)
    writer = tf.contrib.summary.create_file_writer(logdir=str(log_path))  # Graph display not possible in eager...

    model_config = json.loads(model.to_json())
    # Save FLAGS to a json file next to the tensorboard data
    save_dict_as_json(dict_data=FLAGS.flag_values_dict(), log_dir=str(log_path), save_name=f'{run_name}_FLAGS')
    save_dict_as_json(dict_data=model_config, log_dir=str(log_path), save_name=f'{run_name}_config')

    # Record all summaries emitted inside train_and_eval to this writer.
    with writer.as_default(), tf.contrib.summary.always_record_summaries():
        train_and_eval(model=model, optimizer=optimizer, manager=manager, train_dataset=train_dataset,
                       val_dataset=val_dataset, num_epochs=FLAGS.num_epochs, loss_func=loss_func,
                       save_best_only=FLAGS.save_best_only, use_train_metrics=True, use_val_metrics=True,
                       verbose=FLAGS.verbose, max_images=FLAGS.max_images)

    finish = int(time() - start)
    logger.info(f'Finished Training model. Time: {finish // 3600:02d}hrs {(finish // 60) % 60}min {finish % 60}s')
    return 0
def main(argv):  # argv are non-flag parameters
    """Train the reconstruction UNet via the Keras fit_generator API.

    Keras-native variant of the eager-mode trainer: MAE loss, SSIM/PSNR/
    MS-SSIM/NMSE batch metrics, ModelCheckpoint + TensorBoard callbacks.
    Returns 0 on completion (absl-app convention).
    """
    del argv
    start = time()
    tf.print('Tensorflow Engaged')

    run_number, run_name = initialize(FLAGS.ckpt_dir)

    # exist_ok=False: refuse to reuse an existing run's log directory.
    log_path = Path(FLAGS.log_dir)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=False)

    logger = get_logger(__name__)

    data_path = Path(FLAGS.data_dir)
    train_path = data_path / f'{FLAGS.challenge}_train'
    val_path = data_path / f'{FLAGS.challenge}_val'

    # as_tensors=False: Keras feeds numpy batches, not eager tensors.
    train_dataset = HDF5Sequence(data_dir=train_path, batch_size=FLAGS.batch_size, training=True, as_tensors=False)
    val_dataset = HDF5Sequence(data_dir=val_path, batch_size=FLAGS.batch_size, training=False, as_tensors=False)

    model = make_unet_model(scope='UNET', input_shape=(320, 320, 1))
    # multi_model = tf.keras.utils.multi_gpu_model(model, gpus=2)
    multi_model = model  # Single-GPU path; alias kept so compile/fit code is unchanged.

    tf.keras.utils.plot_model(model, to_file=log_path / f'model_{run_number:02d}.png', show_shapes=True)
    model.summary()

    # Checkpoint filename template: Keras fills in {epoch:04d} per save.
    ckpt_path = Path(FLAGS.ckpt_dir) / run_name
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / '{epoch:04d}.ckpt'
    ckpt_path = str(ckpt_path)

    optimizer = tf.keras.optimizers.Adam(lr=FLAGS.lr)

    checkpoint = ModelCheckpoint(filepath=ckpt_path, verbose=1, save_best_only=True)
    visualizer = TensorBoard(log_dir=log_path, batch_size=FLAGS.batch_size)

    model_config = json.loads(model.to_json())
    # Save FLAGS to a json file next to the tensorboard data
    save_dict_as_json(dict_data=FLAGS.flag_values_dict(), log_dir=str(log_path), save_name=f'{run_name}_FLAGS')
    save_dict_as_json(dict_data=model_config, log_dir=str(log_path), save_name=f'{run_name}_config')

    multi_model.compile(optimizer=optimizer, loss='mae', metrics=[batch_ssim, batch_psnr, batch_msssim, batch_nmse])

    multi_model.fit_generator(train_dataset, epochs=FLAGS.num_epochs, verbose=1, callbacks=[checkpoint, visualizer],
                              validation_data=val_dataset, workers=4, use_multiprocessing=True)

    finish = int(time() - start)
    logger.info(f'Finished Training model. Time: {finish // 3600:02d}hrs {(finish // 60) % 60}min {finish % 60}s')
    return 0
def train_gan_model(args):
    """Set up and run adversarial training of the reconstruction UNet.

    Creates per-run checkpoint/log directories, selects the compute device,
    records run metadata on ``args``, builds the generator/discriminator pair
    with their optimizers and schedulers, and delegates the training loop to
    ``GANModelTrainer``.

    Args:
        args: Namespace of training options. Mutated in place with
            run_number, run_name, ckpt_path, log_path, and device.
    """
    def _extend_dir(base, part):
        # Append one path component and make sure it exists on disk.
        grown = base / part
        grown.mkdir(exist_ok=True)
        return grown

    # Per-run checkpoint directory: <ckpt_dir>/<train_method>/<run_name>.
    checkpoint_dir = Path(args.ckpt_dir)
    checkpoint_dir.mkdir(exist_ok=True)
    checkpoint_dir = _extend_dir(checkpoint_dir, args.train_method)
    run_number, run_name = initialize(checkpoint_dir)
    checkpoint_dir = _extend_dir(checkpoint_dir, run_name)

    # Matching log directory tree: <log_dir>/<train_method>/<run_name>.
    logging_dir = Path(args.log_dir)
    logging_dir.mkdir(exist_ok=True)
    logging_dir = _extend_dir(logging_dir, args.train_method)
    logging_dir = _extend_dir(logging_dir, run_name)

    logger = get_logger(name=__name__, save_file=logging_dir / run_name)

    # Prefer the requested GPU; otherwise fall back to CPU.
    on_gpu = (args.gpu is not None) and torch.cuda.is_available()
    device = torch.device(f'cuda:{args.gpu}' if on_gpu else 'cpu')
    if on_gpu:
        logger.info(f'Using GPU {args.gpu} for {run_name}')
    else:
        logger.info(f'Using CPU for {run_name}')

    # Dump the hyper-parameters before attaching run objects to args.
    # Please note that many objects (such as Path objects) cannot be serialized to json files.
    save_dict_as_json(vars(args), log_dir=logging_dir, save_name=run_name)

    # Stash run-level objects on args so downstream code gets them in one handle.
    args.run_number = run_number
    args.run_name = run_name
    args.ckpt_path = checkpoint_dir
    args.log_path = logging_dir
    args.device = device

    # Input transforms and data loaders.
    train_transform = InputTrainTransform(is_training=True)
    val_transform = InputTrainTransform(is_training=False)
    train_loader, val_loader = create_data_loaders(args, train_transform, val_transform)

    # Adversarial loss for the discriminator game, L1+SSIM for reconstruction.
    loss_funcs = {
        'gan_loss_func': nn.BCELoss(reduction='mean'),
        'recon_loss_func': L1CSSIM(l1_weight=args.l1_weight, default_range=12, filter_size=7, reduction='mean'),
    }

    # Networks.
    generator = UnetModel(1, 1, args.chans, args.num_pool_layers, drop_prob=0).to(device)
    discriminator = SimpleCNN(chans=args.chans).to(device)

    # One optimizer/scheduler pair per network, sharing the same settings.
    gen_optim = optim.Adam(params=generator.parameters(), lr=args.init_lr)
    disc_optim = optim.Adam(params=discriminator.parameters(), lr=args.init_lr)
    gen_scheduler = optim.lr_scheduler.StepLR(gen_optim, step_size=args.step_size, gamma=args.lr_reduction_rate)
    disc_scheduler = optim.lr_scheduler.StepLR(disc_optim, step_size=args.step_size, gamma=args.lr_reduction_rate)

    trainer = GANModelTrainer(args, generator, discriminator, gen_optim, disc_optim, train_loader, val_loader,
                              loss_funcs, gen_scheduler, disc_scheduler)
    trainer.train_model()