from utils.device import setup_device from utils.losses import ReconstructionLoss, SceneLatentLoss, LightLatentLoss, GANLoss, FoolGANLoss, ColorPredictionLoss, DirectionPredictionLoss import utils.tensorboard as tensorboard from utils.dataset import InputTargetGroundtruthDataset, TRAIN_DATA_PATH, VALIDATION_DATA_PATH from torch.utils.data import DataLoader from models.swapModels import SwapNet512x1x1 as SwapNet #from models.swapModels import IlluminationSwapNet as SwapNet #from models.swapModels import AnOtherSwapNet as SwapNet from models.swapModels import IlluminationPredicter from models.patchGan import NLayerDiscriminator # Get used device GPU_IDS = [3] device = setup_device(GPU_IDS) # Parameters NAME = 'SwapNet512x1x1WithoutGANIlluminationPredictor-testOverfit' TRAIN_BATCH_SIZE = 20 TRAIN_NUM_WORKERS = 8 TEST_BATCH_SIZE = 20 TEST_NUM_WORKERS = 8 SIZE = 256 TRAIN_DURATION = 60000 # Configure training generator = SwapNet(last_kernel_size=1).to(device) illumination = IlluminationPredicter(in_size=16).to(device) discriminator = NLayerDiscriminator().to(device) optimizerG = torch.optim.Adam(list(generator.parameters()) +
def main(config): # Device to use device = setup_device(config["gpus"]) # Configure training objects # Generator model_name = config["model"] generator = get_generator_model(model_name)().to(device) weight_decay = config["L2_regularization_generator"] if config["use_illumination_predicter"]: light_latent_size = get_light_latent_size(model_name) illumination_predicter = IlluminationPredicter( in_size=light_latent_size).to(device) optimizerG = torch.optim.Adam( list(generator.parameters()) + list(illumination_predicter.parameters()), weight_decay=weight_decay) else: optimizerG = torch.optim.Adam(generator.parameters(), weight_decay=weight_decay) # Discriminator if config["use_discriminator"]: if config["discriminator_everything_as_input"]: raise NotImplementedError # TODO else: discriminator = NLayerDiscriminator().to(device) optimizerD = torch.optim.Adam( discriminator.parameters(), weight_decay=config["L2_regularization_discriminator"]) # Losses reconstruction_loss = ReconstructionLoss().to(device) if config["use_illumination_predicter"]: color_prediction_loss = ColorPredictionLoss().to(device) direction_prediction_loss = DirectionPredictionLoss().to(device) if config["use_discriminator"]: gan_loss = GANLoss().to(device) fool_gan_loss = FoolGANLoss().to(device) # Metrics if "scene_latent" in config["metrics"]: scene_latent_loss = SceneLatentLoss().to(device) if "light_latent" in config["metrics"]: light_latent_loss = LightLatentLoss().to(device) if "LPIPS" in config["metrics"]: lpips_loss = LPIPS( net_type= 'alex', # choose a network type from ['alex', 'squeeze', 'vgg'] version='0.1' # Currently, v0.1 is supported ).to(device) # Configure dataloader size = config['image_resize'] # train try: file = open( 'traindataset' + str(config['overfit_test']) + str(size) + '.pickle', 'rb') print("Restoring train dataset from pickle file") train_dataset = pickle.load(file) file.close() print("Restored train dataset from pickle file") except: train_dataset = InputTargetGroundtruthDataset( transform=transforms.Resize(size), data_path=TRAIN_DATA_PATH, locations=['scene_abandonned_city_54'] if config['overfit_test'] else None, input_directions=["S", "E"] if config['overfit_test'] else None, target_directions=["S", "E"] if config['overfit_test'] else None, input_colors=["2500", "6500"] if config['overfit_test'] else None, target_colors=["2500", "6500"] if config['overfit_test'] else None) file = open( "traindataset" + str(config['overfit_test']) + str(size) + '.pickle', 'wb') pickle.dump(train_dataset, file) file.close() print("saved traindataset" + str(config['overfit_test']) + str(size) + '.pickle') train_dataloader = DataLoader(train_dataset, batch_size=config['train_batch_size'], shuffle=config['shuffle_data'], num_workers=config['train_num_workers']) # test try: file = open( "testdataset" + str(config['overfit_test']) + str(size) + '.pickle', 'rb') print("Restoring full test dataset from pickle file") test_dataset = pickle.load(file) file.close() print("Restored full test dataset from pickle file") except: test_dataset = InputTargetGroundtruthDataset( transform=transforms.Resize(size), data_path=VALIDATION_DATA_PATH, locations=["scene_city_24"] if config['overfit_test'] else None, input_directions=["S", "E"] if config['overfit_test'] else None, target_directions=["S", "E"] if config['overfit_test'] else None, input_colors=["2500", "6500"] if config['overfit_test'] else None, target_colors=["2500", "6500"] if config['overfit_test'] else None) file = open( "testdataset" + str(config['overfit_test']) + str(size) + '.pickle', 'wb') pickle.dump(test_dataset, file) file.close() print("saved testdataset" + str(config['overfit_test']) + str(size) + '.pickle') test_dataloader = DataLoader(test_dataset, batch_size=config['test_batch_size'], shuffle=config['shuffle_data'], num_workers=config['test_num_workers']) test_dataloaders = {"full": test_dataloader} if config["testing_on_subsets"]: additional_pairing_strategies = [[SameLightColor()], [SameLightDirection()]] #[SameScene()], #[SameScene(), SameLightColor()], #[SameScene(), SameLightDirection()], #[SameLightDirection(), SameLightColor()], #[SameScene(), SameLightDirection(), SameLightColor()]] for pairing_strategies in additional_pairing_strategies: try: file = open( "testdataset" + str(config['overfit_test']) + str(size) + str(pairing_strategies) + '.pickle', 'rb') print("Restoring test dataset " + str(pairing_strategies) + " from pickle file") test_dataset = pickle.load(file) file.close() print("Restored test dataset " + str(pairing_strategies) + " from pickle file") except: test_dataset = InputTargetGroundtruthDataset( transform=transforms.Resize(size), data_path=VALIDATION_DATA_PATH, pairing_strategies=pairing_strategies, locations=["scene_city_24"] if config['overfit_test'] else None, input_directions=["S", "E"] if config['overfit_test'] else None, target_directions=["S", "E"] if config['overfit_test'] else None, input_colors=["2500", "6500"] if config['overfit_test'] else None, target_colors=["2500", "6500"] if config['overfit_test'] else None) file = open( "testdataset" + str(config['overfit_test']) + str(size) + str(pairing_strategies) + '.pickle', 'wb') pickle.dump(test_dataset, file) file.close() print("saved testdataset" + str(config['overfit_test']) + str(size) + str(pairing_strategies) + '.pickle') test_dataloader = DataLoader( test_dataset, batch_size=config['test_batch_size'], shuffle=config['shuffle_data'], num_workers=config['test_num_workers']) test_dataloaders[str(pairing_strategies)] = test_dataloader print( f'Dataset contains {len(train_dataset)} train samples and {len(test_dataset)} test samples.' ) print( f'{config["shown_samples_grid"]} samples will be visualized every {config["testing_frequence"]} batches.' ) print( f'Evaluation will be made every {config["testing_frequence"]} batches on {config["batches_for_testing"]} batches' ) # Configure tensorboard writer = tensorboard.setup_summary_writer(config['name']) tensorboard_process = tensorboard.start_tensorboard_process( ) # TODO: config["tensorboard_port"] # Train loop # Init train scalars (train_generator_loss, train_discriminator_loss, train_score, train_lpips, train_ssim, train_psnr, train_scene_latent_loss_input_gt, train_scene_latent_loss_input_target, train_scene_latent_loss_gt_target, train_light_latent_loss_input_gt, train_light_latent_loss_input_target, train_light_latent_loss_gt_target, train_color_prediction_loss, train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.) # Init train loop train_dataloader_iter = iter(train_dataloader) train_batches_counter = 0 print(f'Running for {config["train_duration"]} batches.') last_save_t = 0 # Train loop while train_batches_counter < config['train_duration']: # Store trained model t = time.time() if t - last_save_t > config["checkpoint_period"]: last_save_t = t save_trained(generator, "generator" + config['name'] + str(t)) if config["use_illumination_predicter"]: save_trained( illumination_predicter, "illumination_predicter" + config['name'] + str(t)) if config["use_discriminator"]: save_trained(discriminator, "discriminator" + config['name'] + str(t)) #with torch.autograd.detect_anomaly(): # Load batch if config["debug"]: print('Load batch', get_gpu_memory_map()) with torch.no_grad(): train_batch, train_dataloader_iter = next_batch( train_dataloader_iter, train_dataloader) (input_image, target_image, groundtruth_image, input_color, target_color, groundtruth_color, input_direction, target_direction, groundtruth_direction) = extract_from_batch(train_batch, device) # Generator # Generator: Forward if config["debug"]: print('Generator: Forward', get_gpu_memory_map()) output = generator(input_image, target_image, groundtruth_image) (relit_image, input_light_latent, target_light_latent, groundtruth_light_latent, input_scene_latent, target_scene_latent, groundtruth_scene_latent) = output r = reconstruction_loss(relit_image, groundtruth_image) generator_loss = config['generator_loss_reconstruction_l2_factor'] * r if config["use_illumination_predicter"]: input_illumination = illumination_predicter(input_light_latent) target_illumination = illumination_predicter(target_light_latent) groundtruth_illumination = illumination_predicter( groundtruth_light_latent) c = (1 / 3 * color_prediction_loss(input_illumination[:, 0], input_color) + 1 / 3 * color_prediction_loss(target_illumination[:, 0], target_color) + 1 / 3 * color_prediction_loss( groundtruth_illumination[:, 0], groundtruth_color)) d = (1 / 3 * direction_prediction_loss(input_illumination[:, 1], input_direction) + 1 / 3 * direction_prediction_loss(target_illumination[:, 1], target_direction) + 1 / 3 * direction_prediction_loss( groundtruth_illumination[:, 1], groundtruth_direction)) generator_loss += config['generator_loss_color_l2_factor'] * c generator_loss += config['generator_loss_direction_l2_factor'] * d train_color_prediction_loss += c.item() train_direction_prediction_loss += d.item() train_generator_loss += generator_loss.item() train_score += reconstruction_loss( input_image, groundtruth_image).item() / reconstruction_loss( relit_image, groundtruth_image).item() if "scene_latent" in config["metrics"]: train_scene_latent_loss_input_gt += scene_latent_loss( input_image, groundtruth_image).item() train_scene_latent_loss_input_target += scene_latent_loss( input_image, target_image).item() train_scene_latent_loss_gt_target += scene_latent_loss( target_image, groundtruth_image).item() if "light_latent" in config["metrics"]: train_light_latent_loss_input_gt += light_latent_loss( input_image, groundtruth_image).item() train_light_latent_loss_input_target += light_latent_loss( input_image, target_image).item() train_light_latent_loss_gt_target += light_latent_loss( target_image, groundtruth_image).item() if "LPIPS" in config["metrics"]: train_lpips += lpips_loss(relit_image, groundtruth_image).item() if "SSIM" in config["metrics"]: train_ssim += ssim(relit_image, groundtruth_image).item() if "PSNR" in config["metrics"]: train_psnr += psnr(relit_image, groundtruth_image).item() # Generator: Backward if config["debug"]: print('Generator: Backward', get_gpu_memory_map()) optimizerG.zero_grad() if config["use_discriminator"]: optimizerD.zero_grad() if config["use_discriminator"]: discriminator.zero_grad() generator_loss.backward( ) # use requires_grad = False for speed? Et pour enlever ces zero_grad en double! # Generator: Update parameters if config["debug"]: print('Generator: Update parameters', get_gpu_memory_map()) optimizerG.step() # Discriminator if config["use_discriminator"]: if config["debug"]: print('Discriminator', get_gpu_memory_map()) # Discriminator : Forward output = generator(input_image, target_image, groundtruth_image) (relit_image, input_light_latent, target_light_latent, groundtruth_light_latent, input_scene_latent, target_scene_latent, groundtruth_scene_latent) = output disc_out_fake = discriminator(relit_image) disc_out_real = discriminator(groundtruth_image) discriminator_loss = config[ 'discriminator_loss_gan_factor'] * gan_loss( disc_out_fake, disc_out_real) train_discriminator_loss += discriminator_loss.item() # Discriminator : Backward optimizerD.zero_grad() discriminator.zero_grad() optimizerG.zero_grad() generator.zero_grad() discriminator_loss.backward() generator.zero_grad() optimizerG.zero_grad() # Discriminator : Update parameters optimizerD.step() # Update train_batches_counter train_batches_counter += 1 # If it is time to do so, test and visualize current progress step, modulo = divmod(train_batches_counter, config['testing_frequence']) if modulo == 0: with torch.no_grad(): # Visualize train if config["debug"]: print('Visualize train', get_gpu_memory_map()) write_images( writer=writer, header="Train", step=step, inputs=input_image[:config['shown_samples_grid']], input_light_latents=input_light_latent[:config[ 'shown_samples_grid']], targets=target_image[:config['shown_samples_grid']], target_light_latents=target_light_latent[:config[ 'shown_samples_grid']], groundtruths=groundtruth_image[:config[ 'shown_samples_grid']], groundtruth_light_latents=groundtruth_light_latent[:config[ 'shown_samples_grid']], relits=relit_image[:config['shown_samples_grid']]) write_measures( writer=writer, header="Train", step=step, generator_loss=train_generator_loss / config['testing_frequence'], discriminator_loss=train_discriminator_loss / config['testing_frequence'], score=train_score / config['testing_frequence'], ssim=train_ssim / config['testing_frequence'], lpips=train_lpips / config['testing_frequence'], psnr=train_psnr / config['testing_frequence'], scene_input_gt=train_scene_latent_loss_input_gt / config['testing_frequence'], scene_input_target=train_scene_latent_loss_input_target / config['testing_frequence'], scene_gt_target=train_scene_latent_loss_gt_target / config['testing_frequence'], light_input_gt=train_light_latent_loss_input_gt / config['testing_frequence'], light_input_target=train_light_latent_loss_input_target / config['testing_frequence'], light_gt_target=train_light_latent_loss_gt_target / config['testing_frequence'], color_prediction=train_color_prediction_loss / config['testing_frequence'], direction_prediction=train_direction_prediction_loss / config['testing_frequence']) print('Train', 'Loss:', train_generator_loss / config['testing_frequence'], 'Score:', train_score / config['testing_frequence']) if config["debug_memory"]: print(get_gpu_memory_map()) # del generator_loss # torch.cuda.empty_cache() # print(get_gpu_memory_map()) # Reset train scalars if config["debug"]: print('Reset train scalars', get_gpu_memory_map()) (train_generator_loss, train_discriminator_loss, train_score, train_lpips, train_ssim, train_psnr, train_scene_latent_loss_input_gt, train_scene_latent_loss_input_target, train_scene_latent_loss_gt_target, train_light_latent_loss_input_gt, train_light_latent_loss_input_target, train_light_latent_loss_gt_target, train_color_prediction_loss, train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.) # Test loop if config["debug"]: print('Test loop', get_gpu_memory_map()) for header, test_dataloader in test_dataloaders.items(): # Init test scalars if config["debug"]: print('Init test scalars', get_gpu_memory_map()) (test_generator_loss, test_discriminator_loss, test_score, test_lpips, test_ssim, test_psnr, test_scene_latent_loss_input_gt, test_scene_latent_loss_input_target, test_scene_latent_loss_gt_target, test_light_latent_loss_input_gt, test_light_latent_loss_input_target, test_light_latent_loss_gt_target, test_color_prediction_loss, test_direction_prediction_loss) = (0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.) # Init test loop if config["debug"]: print('Init test loop', get_gpu_memory_map()) test_dataloader_iter = iter(test_dataloader) testing_batches_counter = 0 while testing_batches_counter < config[ 'batches_for_testing']: # Load batch if config["debug"]: print('Load batch', get_gpu_memory_map()) test_batch, test_dataloader_iter = next_batch( test_dataloader_iter, test_dataloader) (input_image, target_image, groundtruth_image, input_color, target_color, groundtruth_color, input_direction, target_direction, groundtruth_direction) = extract_from_batch( test_batch, device) # Forward # Generator if config["debug"]: print('Generator', get_gpu_memory_map()) output = generator(input_image, target_image, groundtruth_image) (relit_image, input_light_latent, target_light_latent, groundtruth_light_latent, input_scene_latent, target_scene_latent, groundtruth_scene_latent) = output r = reconstruction_loss(relit_image, groundtruth_image) generator_loss = config[ 'generator_loss_reconstruction_l2_factor'] * r if config["use_illumination_predicter"]: input_illumination = illumination_predicter( input_light_latent) target_illumination = illumination_predicter( target_light_latent) groundtruth_illumination = illumination_predicter( groundtruth_light_latent) c = (1 / 3 * color_prediction_loss( input_illumination[:, 0], input_color) + 1 / 3 * color_prediction_loss( target_illumination[:, 0], target_color) + 1 / 3 * color_prediction_loss( groundtruth_illumination[:, 0], groundtruth_color)) d = (1 / 3 * direction_prediction_loss( input_illumination[:, 1], input_direction) + 1 / 3 * direction_prediction_loss( target_illumination[:, 1], target_direction) + 1 / 3 * direction_prediction_loss( groundtruth_illumination[:, 1], groundtruth_direction)) generator_loss += config[ 'generator_loss_color_l2_factor'] * c generator_loss += config[ 'generator_loss_direction_l2_factor'] * d test_color_prediction_loss += c.item() test_direction_prediction_loss += d.item() test_generator_loss += generator_loss.item() test_score += reconstruction_loss( input_image, groundtruth_image).item() / reconstruction_loss( relit_image, groundtruth_image).item() if "scene_latent" in config["metrics"]: test_scene_latent_loss_input_gt += scene_latent_loss( input_image, groundtruth_image).item() test_scene_latent_loss_input_target += scene_latent_loss( input_image, target_image).item() test_scene_latent_loss_gt_target += scene_latent_loss( target_image, groundtruth_image).item() if "light_latent" in config["metrics"]: test_light_latent_loss_input_gt += light_latent_loss( input_image, groundtruth_image).item() test_light_latent_loss_input_target += light_latent_loss( input_image, target_image).item() test_light_latent_loss_gt_target += light_latent_loss( target_image, groundtruth_image).item() if "LPIPS" in config["metrics"]: test_lpips += lpips_loss(relit_image, groundtruth_image).item() if "SSIM" in config["metrics"]: test_ssim += ssim(relit_image, groundtruth_image).item() if "PSNR" in config["metrics"]: test_psnr += psnr(relit_image, groundtruth_image).item() # Discriminator if config["debug"]: print('Discriminator', get_gpu_memory_map()) if config["use_discriminator"]: disc_out_fake = discriminator(relit_image) disc_out_real = discriminator(groundtruth_image) discriminator_loss = config[ 'discriminator_loss_gan_factor'] * gan_loss( disc_out_fake, disc_out_real) test_discriminator_loss += discriminator_loss.item( ) # Update testing_batches_counter if config["debug"]: print('Update testing_batches_counter', get_gpu_memory_map()) testing_batches_counter += 1 # Visualize test if config["debug"]: print('Visualize test', get_gpu_memory_map()) write_images( writer=writer, header="Test-" + header, step=step, inputs=input_image[:config['shown_samples_grid']], input_light_latents=input_light_latent[:config[ 'shown_samples_grid']], targets=target_image[:config['shown_samples_grid']], target_light_latents=target_light_latent[:config[ 'shown_samples_grid']], groundtruths=groundtruth_image[:config[ 'shown_samples_grid']], groundtruth_light_latents= groundtruth_light_latent[: config['shown_samples_grid']], relits=relit_image[:config['shown_samples_grid']]) write_measures( writer=writer, header="Test-" + header, step=step, generator_loss=test_generator_loss / config['batches_for_testing'], discriminator_loss=test_discriminator_loss / config['batches_for_testing'], score=test_score / config['batches_for_testing'], ssim=test_ssim / config['batches_for_testing'], lpips=test_lpips / config['batches_for_testing'], psnr=test_psnr / config['batches_for_testing'], scene_input_gt=test_scene_latent_loss_input_gt / config['batches_for_testing'], scene_input_target=test_scene_latent_loss_input_target / config['batches_for_testing'], scene_gt_target=test_scene_latent_loss_gt_target / config['batches_for_testing'], light_input_gt=test_light_latent_loss_input_gt / config['batches_for_testing'], light_input_target=test_light_latent_loss_input_target / config['batches_for_testing'], light_gt_target=test_light_latent_loss_gt_target / config['batches_for_testing'], color_prediction=test_color_prediction_loss / config['batches_for_testing'], direction_prediction=test_direction_prediction_loss / config['batches_for_testing']) print('Test-' + header, 'Loss:', test_generator_loss / config['testing_frequence'], 'Score:', test_score / config['testing_frequence']) if config["debug_memory"]: print(get_gpu_memory_map())