def __init__(self, hparams: DictConfig = None, args: Namespace = None, loss_fns=None):
    """Build the GAN trainer: models, loss functions, loss weights, batch size.

    Args:
        hparams: experiment configuration (model, gan_mode, loss weights, batch size).
        args: command-line namespace, stored as-is.
        loss_fns: optional mapping of loss names to callables; defaults are built
            from ``hparams`` when omitted.
    """
    super().__init__()
    self.args = args
    self.hparams = hparams
    # Instantiate the generator/discriminator pair from the config.
    self.generator, self.discriminator = get_model(self.hparams)
    print(self.generator)
    print(self.discriminator)
    # init loss functions: caller-supplied dict wins; otherwise build defaults.
    if loss_fns:
        self.loss_fns = loss_fns
    else:
        self.loss_fns = {
            'L2': l2_loss,                     # L2 loss
            'ADV': GANLoss(hparams.gan_mode),  # adversarial loss
            'G': l2_loss,                      # goal achievement loss
            'GCE': nn.CrossEntropyLoss(),      # goal cross-entropy loss
        }
    # init loss weights (one scalar per loss term, read from the config)
    self.loss_weights = {
        'L2': hparams.w_L2,
        'ADV': hparams.w_ADV,   # adversarial loss
        'G': hparams.w_G,       # goal achievement loss
        'GCE': hparams.w_GCE,   # goal cross-entropy loss
    }
    self.current_batch_idx = -1
    self.plot_val = True
    # A truthy batch-size scheduler overrides the static batch size.
    self.batch_size = self.hparams.batch_size_scheduler or self.hparams.batch_size
def __init__(self, generator_instantiator, discriminator=None, text_encoder_instantiator=None):
    """Assemble a GAN wrapper: trainable generator, a frozen running-average
    (EMA-style) generator copy, an optional discriminator and text encoders.

    Args:
        generator_instantiator: zero-arg callable returning a fresh generator.
        discriminator: optional discriminator module.
        text_encoder_instantiator: optional zero-arg callable returning a text encoder.
    """
    super().__init__()
    # Two generator copies: one trained, one kept as a frozen running average.
    self.generator = generator_instantiator()
    self.generator_running_avg = generator_instantiator()
    # Same initial weights for both copies.
    self.generator_running_avg.load_state_dict(self.generator.state_dict())
    for param in self.generator_running_avg.parameters():
        param.requires_grad = False
    self.discriminator = discriminator
    # NOTE(review): `args` is not a parameter here — it is read from module
    # scope; confirm it is defined before this constructor runs.
    self.criterion_gan = GANLoss(args.loss, tensor=torch.cuda.FloatTensor).cuda()
    if text_encoder_instantiator is not None:
        self.text_encoder_g = text_encoder_instantiator()
        total_params = sum(p.nelement() for p in self.text_encoder_g.parameters())
        print('TextEncoder parameters: {:.2f}M'.format(total_params / 1000000))
        if not args.evaluate:
            if not args.text_train_encoder:
                # G and D use the same text encoder instance
                self.text_encoder_d = self.text_encoder_g
            else:
                # Different instances for G and D
                self.text_encoder_d = text_encoder_instantiator()
def __init__(self, generator, train_dset, val_dset, cfg: DictConfig = None, loss_fns=None):
    """Set up the pre-training module: generator, datasets, losses and weights.

    Args:
        generator: the generator model to train.
        train_dset: training dataset.
        val_dset: validation dataset.
        cfg: experiment configuration (gan_mode, loss weights, batch sizes).
        loss_fns: optional mapping of loss names to callables; defaults are
            built from ``cfg`` when omitted.
    """
    super().__init__()
    self.cfg = cfg
    self.generator = generator
    self.generator.gen()  # presumably switches the model into generation mode — TODO confirm
    # init loss functions: use the caller's mapping when given, else defaults.
    if loss_fns:
        self.loss_fns = loss_fns
    else:
        self.loss_fns = {
            'L2': l2_loss,                  # L2 loss
            'ADV': GANLoss(cfg.gan_mode),   # adversarial loss
            'G': l2_loss,                   # goal achievement loss
            'GCE': nn.CrossEntropyLoss(),   # goal cross-entropy loss
        }
    # init loss weights (scalars from the config, one per loss term)
    self.loss_weights = {
        'L2': cfg.w_L2,
        'ADV': cfg.w_ADV,   # adversarial loss
        'G': cfg.w_G,       # goal achievement loss
        'GCE': cfg.w_GCE,   # goal cross-entropy loss
    }
    self.train_dset = train_dset
    self.val_dset = val_dset
    self.plot_val = True
    # A truthy pre-training batch-size scheduler overrides the static size.
    self.batch_size = self.cfg.pretraining.batch_size_scheduler or self.cfg.batch_size
# GAN training script fragment: optimizers, data pipeline, and the start of
# the train loop. NOTE(review): original indentation was lost; the nesting
# below was re-inferred from statement order — verify against the original.
optG = get_optimizer(cfg.optimizer.generator)(params=netG.parameters())
optD = get_optimizer(cfg.optimizer.discriminator)(params=netD.parameters())
# set dataset, dataloader
dataset = get_dataset(cfg)
transform = get_transform(cfg)
trainset = dataset(root=dataset_cfg.path, train=True, transform=transform.gan_training)
trainloader = DataLoader(trainset, batch_size=dataset_cfg.batch_size, num_workers=0, shuffle=True)
# set gan loss
gan_loss = GANLoss(gan_mode=train_cfg.loss_type)
# training, visualizing, saving
iters = 0
for epoch in range(cfg.num_epochs):
    for i, data in enumerate(trainloader):
        start_time = time.time()
        real_imgs = data[0].cuda()
        # update discriminator
        # n_dis discriminator steps per generator step (common GAN schedule).
        for _ in range(train_cfg.args.n_dis):
            optD.zero_grad()
            z = torch.randn(real_imgs.size(0), cfg.generator.args.z_dim).cuda()
            # detach: no gradient flows into the generator during the D update.
            gen_imgs = netG(z).detach()
            real_pred = netD(real_imgs)
            fake_pred = netD(gen_imgs)
            # NOTE(review): chunk truncated here — the discriminator loss,
            # backward and step continue outside this view.
# Relighting training setup fragment: optimizers, loss modules, datasets.
# NOTE(review): original indentation was lost; formatting re-inferred.
# weight_decay=0 explicitly disables L2 regularization for both optimizers;
# no learning rate is passed, so Adam's default is used — TODO confirm intended.
optimizerG = torch.optim.Adam(list(generator.parameters()) + list(illumination.parameters()), weight_decay=0)
optimizerD = torch.optim.Adam(discriminator.parameters(), weight_decay=0)
generator.train()
discriminator.train()
print(generator)
print(discriminator)
# Losses
reconstruction_loss = ReconstructionLoss().to(device)
scene_latent_loss = SceneLatentLoss().to(device)
light_latent_loss = LightLatentLoss().to(device)
color_prediction_loss = ColorPredictionLoss().to(device)
gan_loss = GANLoss().to(device)
fool_gan_loss = FoolGANLoss().to(device)
# Configure dataloader
# Fixed overfit-style subset: one scene, S->N light direction, 2500K->6500K color.
train_dataset = InputTargetGroundtruthDataset(
    transform=transforms.Resize(SIZE),
    data_path=TRAIN_DATA_PATH,
    locations=['scene_abandonned_city_54'],
    input_directions=["S"],
    target_directions=["N"],
    input_colors=["2500"],
    target_colors=["6500"])
# NOTE(review): the call below is truncated in this chunk — its remaining
# keyword arguments continue outside this view.
test_dataset = InputTargetGroundtruthDataset(transform=transforms.Resize(SIZE),
                                             data_path=VALIDATION_DATA_PATH,
                                             input_directions=["S"],
                                             target_directions=["N"],
# set optimizers optE = get_optimizer(cfg.optimizer.encoder)(params=netE.parameters()) # set loss function image_criterion = RecLoss(cfg.train.image_loss.type) latent_criterion = RecLoss(cfg.train.latent_loss.type) # set weights w_image = cfg.train.image_loss.weight w_latent = cfg.train.latent_loss.weight if use_gan_loss: netD = get_discriminator(cfg, cfg.discriminator) netD.cuda() netD.train() optD = get_optimizer(cfg.optimizer.discriminator)(params=netD.parameters()) gan_criterion = GANLoss(gan_mode=cfg.train.adv_loss.type) w_adv = cfg.train.adv_loss.weight for iters in range(cfg.max_iters): start_time = time.time() if use_gan_loss: optD.zero_grad() z = torch.randn(dataset_cfg.batch_size, cfg.generator.args.z_dim).cuda() x_real = netG(z).detach() if cfg.train.denoise: noise = torch.randn_like(x_real) * 0.1 x_in = x_real + noise
def main(config):
    """Train (and periodically evaluate) a relighting GAN, driven entirely by `config`.

    Builds the generator (plus optional illumination predicter and
    discriminator), pickle-cached train/test datasets, tensorboard logging,
    then runs a batch-counted train loop with periodic visualization and a
    nested test loop.

    NOTE(review): reconstructed from a whitespace-mangled source; indentation
    was re-inferred from statement order — verify against the original.
    """
    # Device to use
    device = setup_device(config["gpus"])
    # Configure training objects
    # Generator
    model_name = config["model"]
    generator = get_generator_model(model_name)().to(device)
    weight_decay = config["L2_regularization_generator"]
    if config["use_illumination_predicter"]:
        # The illumination predicter is optimized jointly with the generator.
        light_latent_size = get_light_latent_size(model_name)
        illumination_predicter = IlluminationPredicter(
            in_size=light_latent_size).to(device)
        optimizerG = torch.optim.Adam(
            list(generator.parameters()) +
            list(illumination_predicter.parameters()),
            weight_decay=weight_decay)
    else:
        optimizerG = torch.optim.Adam(generator.parameters(),
                                      weight_decay=weight_decay)
    # Discriminator
    if config["use_discriminator"]:
        if config["discriminator_everything_as_input"]:
            raise NotImplementedError  # TODO
        else:
            discriminator = NLayerDiscriminator().to(device)
            optimizerD = torch.optim.Adam(
                discriminator.parameters(),
                weight_decay=config["L2_regularization_discriminator"])
    # Losses
    reconstruction_loss = ReconstructionLoss().to(device)
    if config["use_illumination_predicter"]:
        color_prediction_loss = ColorPredictionLoss().to(device)
        direction_prediction_loss = DirectionPredictionLoss().to(device)
    if config["use_discriminator"]:
        gan_loss = GANLoss().to(device)
        fool_gan_loss = FoolGANLoss().to(device)
    # Metrics — each evaluator is only built when listed in config["metrics"].
    if "scene_latent" in config["metrics"]:
        scene_latent_loss = SceneLatentLoss().to(device)
    if "light_latent" in config["metrics"]:
        light_latent_loss = LightLatentLoss().to(device)
    if "LPIPS" in config["metrics"]:
        lpips_loss = LPIPS(
            net_type='alex',  # choose a network type from ['alex', 'squeeze', 'vgg']
            version='0.1'  # Currently, v0.1 is supported
        ).to(device)
    # Configure dataloader
    size = config['image_resize']
    # train — datasets are cached as pickles keyed by overfit flag and size.
    try:
        file = open(
            'traindataset' + str(config['overfit_test']) + str(size) +
            '.pickle', 'rb')
        print("Restoring train dataset from pickle file")
        train_dataset = pickle.load(file)
        file.close()
        print("Restored train dataset from pickle file")
    # NOTE(review): bare except hides real errors, not just a missing cache file.
    except:
        train_dataset = InputTargetGroundtruthDataset(
            transform=transforms.Resize(size),
            data_path=TRAIN_DATA_PATH,
            locations=['scene_abandonned_city_54']
            if config['overfit_test'] else None,
            input_directions=["S", "E"] if config['overfit_test'] else None,
            target_directions=["S", "E"] if config['overfit_test'] else None,
            input_colors=["2500", "6500"] if config['overfit_test'] else None,
            target_colors=["2500", "6500"] if config['overfit_test'] else None)
        file = open(
            "traindataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'wb')
        pickle.dump(train_dataset, file)
        file.close()
        print("saved traindataset" + str(config['overfit_test']) + str(size) +
              '.pickle')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config['train_batch_size'],
                                  shuffle=config['shuffle_data'],
                                  num_workers=config['train_num_workers'])
    # test — same pickle-cache pattern as the train dataset.
    try:
        file = open(
            "testdataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'rb')
        print("Restoring full test dataset from pickle file")
        test_dataset = pickle.load(file)
        file.close()
        print("Restored full test dataset from pickle file")
    except:
        test_dataset = InputTargetGroundtruthDataset(
            transform=transforms.Resize(size),
            data_path=VALIDATION_DATA_PATH,
            locations=["scene_city_24"] if config['overfit_test'] else None,
            input_directions=["S", "E"] if config['overfit_test'] else None,
            target_directions=["S", "E"] if config['overfit_test'] else None,
            input_colors=["2500", "6500"] if config['overfit_test'] else None,
            target_colors=["2500", "6500"] if config['overfit_test'] else None)
        file = open(
            "testdataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'wb')
        pickle.dump(test_dataset, file)
        file.close()
        print("saved testdataset" + str(config['overfit_test']) + str(size) +
              '.pickle')
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config['test_batch_size'],
                                 shuffle=config['shuffle_data'],
                                 num_workers=config['test_num_workers'])
    test_dataloaders = {"full": test_dataloader}
    if config["testing_on_subsets"]:
        # One extra test dataloader per pairing-strategy combination; the
        # remaining combinations are intentionally commented out below.
        additional_pairing_strategies = [[SameLightColor()],
                                         [SameLightDirection()]]
        #[SameScene()],
        #[SameScene(), SameLightColor()],
        #[SameScene(), SameLightDirection()],
        #[SameLightDirection(), SameLightColor()],
        #[SameScene(), SameLightDirection(), SameLightColor()]]
        for pairing_strategies in additional_pairing_strategies:
            try:
                file = open(
                    "testdataset" + str(config['overfit_test']) + str(size) +
                    str(pairing_strategies) + '.pickle', 'rb')
                print("Restoring test dataset " + str(pairing_strategies) +
                      " from pickle file")
                # NOTE(review): `test_dataset` is re-bound here, so after this
                # loop it refers to the LAST subset, not the full test set —
                # the sample-count print below reflects that.
                test_dataset = pickle.load(file)
                file.close()
                print("Restored test dataset " + str(pairing_strategies) +
                      " from pickle file")
            except:
                test_dataset = InputTargetGroundtruthDataset(
                    transform=transforms.Resize(size),
                    data_path=VALIDATION_DATA_PATH,
                    pairing_strategies=pairing_strategies,
                    locations=["scene_city_24"]
                    if config['overfit_test'] else None,
                    input_directions=["S", "E"]
                    if config['overfit_test'] else None,
                    target_directions=["S", "E"]
                    if config['overfit_test'] else None,
                    input_colors=["2500", "6500"]
                    if config['overfit_test'] else None,
                    target_colors=["2500", "6500"]
                    if config['overfit_test'] else None)
                file = open(
                    "testdataset" + str(config['overfit_test']) + str(size) +
                    str(pairing_strategies) + '.pickle', 'wb')
                pickle.dump(test_dataset, file)
                file.close()
                print("saved testdataset" + str(config['overfit_test']) +
                      str(size) + str(pairing_strategies) + '.pickle')
            test_dataloader = DataLoader(
                test_dataset,
                batch_size=config['test_batch_size'],
                shuffle=config['shuffle_data'],
                num_workers=config['test_num_workers'])
            test_dataloaders[str(pairing_strategies)] = test_dataloader
    print(
        f'Dataset contains {len(train_dataset)} train samples and {len(test_dataset)} test samples.'
    )
    print(
        f'{config["shown_samples_grid"]} samples will be visualized every {config["testing_frequence"]} batches.'
    )
    print(
        f'Evaluation will be made every {config["testing_frequence"]} batches on {config["batches_for_testing"]} batches'
    )
    # Configure tensorboard
    writer = tensorboard.setup_summary_writer(config['name'])
    tensorboard_process = tensorboard.start_tensorboard_process(
    )  # TODO: config["tensorboard_port"]
    # Train loop
    # Init train scalars — running sums, reset at each visualization point.
    (train_generator_loss, train_discriminator_loss, train_score, train_lpips,
     train_ssim, train_psnr, train_scene_latent_loss_input_gt,
     train_scene_latent_loss_input_target, train_scene_latent_loss_gt_target,
     train_light_latent_loss_input_gt, train_light_latent_loss_input_target,
     train_light_latent_loss_gt_target, train_color_prediction_loss,
     train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                         0., 0., 0., 0., 0.)
    # Init train loop
    train_dataloader_iter = iter(train_dataloader)
    train_batches_counter = 0
    print(f'Running for {config["train_duration"]} batches.')
    last_save_t = 0
    # Train loop
    while train_batches_counter < config['train_duration']:
        # Store trained model every checkpoint_period seconds of wall time.
        t = time.time()
        if t - last_save_t > config["checkpoint_period"]:
            last_save_t = t
            save_trained(generator, "generator" + config['name'] + str(t))
            if config["use_illumination_predicter"]:
                save_trained(
                    illumination_predicter,
                    "illumination_predicter" + config['name'] + str(t))
            if config["use_discriminator"]:
                save_trained(discriminator,
                             "discriminator" + config['name'] + str(t))
        #with torch.autograd.detect_anomaly():
        # Load batch
        if config["debug"]:
            print('Load batch', get_gpu_memory_map())
        with torch.no_grad():
            train_batch, train_dataloader_iter = next_batch(
                train_dataloader_iter, train_dataloader)
            (input_image, target_image, groundtruth_image, input_color,
             target_color, groundtruth_color, input_direction,
             target_direction, groundtruth_direction) = extract_from_batch(
                 train_batch, device)
        # Generator
        # Generator: Forward
        if config["debug"]:
            print('Generator: Forward', get_gpu_memory_map())
        output = generator(input_image, target_image, groundtruth_image)
        (relit_image, input_light_latent, target_light_latent,
         groundtruth_light_latent, input_scene_latent, target_scene_latent,
         groundtruth_scene_latent) = output
        r = reconstruction_loss(relit_image, groundtruth_image)
        generator_loss = config['generator_loss_reconstruction_l2_factor'] * r
        if config["use_illumination_predicter"]:
            # Predicted illumination: column 0 is color, column 1 is direction
            # (as used consistently below) — each supervised on all three images.
            input_illumination = illumination_predicter(input_light_latent)
            target_illumination = illumination_predicter(target_light_latent)
            groundtruth_illumination = illumination_predicter(
                groundtruth_light_latent)
            c = (1 / 3 * color_prediction_loss(input_illumination[:, 0],
                                               input_color) +
                 1 / 3 * color_prediction_loss(target_illumination[:, 0],
                                               target_color) +
                 1 / 3 * color_prediction_loss(groundtruth_illumination[:, 0],
                                               groundtruth_color))
            d = (1 / 3 * direction_prediction_loss(input_illumination[:, 1],
                                                   input_direction) +
                 1 / 3 * direction_prediction_loss(target_illumination[:, 1],
                                                   target_direction) +
                 1 / 3 * direction_prediction_loss(
                     groundtruth_illumination[:, 1], groundtruth_direction))
            generator_loss += config['generator_loss_color_l2_factor'] * c
            generator_loss += config['generator_loss_direction_l2_factor'] * d
            train_color_prediction_loss += c.item()
            train_direction_prediction_loss += d.item()
        train_generator_loss += generator_loss.item()
        # Score: input-vs-gt error over relit-vs-gt error (>1 = improvement).
        train_score += reconstruction_loss(
            input_image, groundtruth_image).item() / reconstruction_loss(
                relit_image, groundtruth_image).item()
        if "scene_latent" in config["metrics"]:
            train_scene_latent_loss_input_gt += scene_latent_loss(
                input_image, groundtruth_image).item()
            train_scene_latent_loss_input_target += scene_latent_loss(
                input_image, target_image).item()
            train_scene_latent_loss_gt_target += scene_latent_loss(
                target_image, groundtruth_image).item()
        if "light_latent" in config["metrics"]:
            train_light_latent_loss_input_gt += light_latent_loss(
                input_image, groundtruth_image).item()
            train_light_latent_loss_input_target += light_latent_loss(
                input_image, target_image).item()
            train_light_latent_loss_gt_target += light_latent_loss(
                target_image, groundtruth_image).item()
        if "LPIPS" in config["metrics"]:
            train_lpips += lpips_loss(relit_image, groundtruth_image).item()
        if "SSIM" in config["metrics"]:
            train_ssim += ssim(relit_image, groundtruth_image).item()
        if "PSNR" in config["metrics"]:
            train_psnr += psnr(relit_image, groundtruth_image).item()
        # Generator: Backward
        if config["debug"]:
            print('Generator: Backward', get_gpu_memory_map())
        optimizerG.zero_grad()
        if config["use_discriminator"]:
            optimizerD.zero_grad()
        if config["use_discriminator"]:
            discriminator.zero_grad()
        generator_loss.backward(
        )  # use requires_grad = False for speed? And to drop these duplicated zero_grad calls!
        # Generator: Update parameters
        if config["debug"]:
            print('Generator: Update parameters', get_gpu_memory_map())
        optimizerG.step()
        # Discriminator
        if config["use_discriminator"]:
            if config["debug"]:
                print('Discriminator', get_gpu_memory_map())
            # Discriminator : Forward
            # Second generator forward pass: fresh fake after the G update.
            output = generator(input_image, target_image, groundtruth_image)
            (relit_image, input_light_latent, target_light_latent,
             groundtruth_light_latent, input_scene_latent,
             target_scene_latent, groundtruth_scene_latent) = output
            disc_out_fake = discriminator(relit_image)
            disc_out_real = discriminator(groundtruth_image)
            discriminator_loss = config[
                'discriminator_loss_gan_factor'] * gan_loss(
                    disc_out_fake, disc_out_real)
            train_discriminator_loss += discriminator_loss.item()
            # Discriminator : Backward
            # NOTE(review): zero_grad is called on both optimizers AND both
            # modules, some twice — redundant but harmless; see comment above.
            optimizerD.zero_grad()
            discriminator.zero_grad()
            optimizerG.zero_grad()
            generator.zero_grad()
            discriminator_loss.backward()
            generator.zero_grad()
            optimizerG.zero_grad()
            # Discriminator : Update parameters
            optimizerD.step()
        # Update train_batches_counter
        train_batches_counter += 1
        # If it is time to do so, test and visualize current progress
        step, modulo = divmod(train_batches_counter,
                              config['testing_frequence'])
        if modulo == 0:
            with torch.no_grad():
                # Visualize train
                if config["debug"]:
                    print('Visualize train', get_gpu_memory_map())
                write_images(
                    writer=writer,
                    header="Train",
                    step=step,
                    inputs=input_image[:config['shown_samples_grid']],
                    input_light_latents=input_light_latent[:config[
                        'shown_samples_grid']],
                    targets=target_image[:config['shown_samples_grid']],
                    target_light_latents=target_light_latent[:config[
                        'shown_samples_grid']],
                    groundtruths=groundtruth_image[:config[
                        'shown_samples_grid']],
                    groundtruth_light_latents=groundtruth_light_latent[:config[
                        'shown_samples_grid']],
                    relits=relit_image[:config['shown_samples_grid']])
                # Running sums are averaged over the testing_frequence window.
                write_measures(
                    writer=writer,
                    header="Train",
                    step=step,
                    generator_loss=train_generator_loss /
                    config['testing_frequence'],
                    discriminator_loss=train_discriminator_loss /
                    config['testing_frequence'],
                    score=train_score / config['testing_frequence'],
                    ssim=train_ssim / config['testing_frequence'],
                    lpips=train_lpips / config['testing_frequence'],
                    psnr=train_psnr / config['testing_frequence'],
                    scene_input_gt=train_scene_latent_loss_input_gt /
                    config['testing_frequence'],
                    scene_input_target=train_scene_latent_loss_input_target /
                    config['testing_frequence'],
                    scene_gt_target=train_scene_latent_loss_gt_target /
                    config['testing_frequence'],
                    light_input_gt=train_light_latent_loss_input_gt /
                    config['testing_frequence'],
                    light_input_target=train_light_latent_loss_input_target /
                    config['testing_frequence'],
                    light_gt_target=train_light_latent_loss_gt_target /
                    config['testing_frequence'],
                    color_prediction=train_color_prediction_loss /
                    config['testing_frequence'],
                    direction_prediction=train_direction_prediction_loss /
                    config['testing_frequence'])
                print('Train', 'Loss:',
                      train_generator_loss / config['testing_frequence'],
                      'Score:', train_score / config['testing_frequence'])
                if config["debug_memory"]:
                    print(get_gpu_memory_map())
                # del generator_loss
                # torch.cuda.empty_cache()
                # print(get_gpu_memory_map())
                # Reset train scalars
                if config["debug"]:
                    print('Reset train scalars', get_gpu_memory_map())
                (train_generator_loss, train_discriminator_loss, train_score,
                 train_lpips, train_ssim, train_psnr,
                 train_scene_latent_loss_input_gt,
                 train_scene_latent_loss_input_target,
                 train_scene_latent_loss_gt_target,
                 train_light_latent_loss_input_gt,
                 train_light_latent_loss_input_target,
                 train_light_latent_loss_gt_target,
                 train_color_prediction_loss,
                 train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0.,
                                                     0., 0., 0., 0., 0., 0.,
                                                     0., 0.)
                # Test loop — one pass per registered test dataloader.
                if config["debug"]:
                    print('Test loop', get_gpu_memory_map())
                for header, test_dataloader in test_dataloaders.items():
                    # Init test scalars
                    if config["debug"]:
                        print('Init test scalars', get_gpu_memory_map())
                    (test_generator_loss, test_discriminator_loss, test_score,
                     test_lpips, test_ssim, test_psnr,
                     test_scene_latent_loss_input_gt,
                     test_scene_latent_loss_input_target,
                     test_scene_latent_loss_gt_target,
                     test_light_latent_loss_input_gt,
                     test_light_latent_loss_input_target,
                     test_light_latent_loss_gt_target,
                     test_color_prediction_loss,
                     test_direction_prediction_loss) = (0., 0., 0., 0., 0.,
                                                        0., 0., 0., 0., 0.,
                                                        0., 0., 0., 0.)
                    # Init test loop
                    if config["debug"]:
                        print('Init test loop', get_gpu_memory_map())
                    test_dataloader_iter = iter(test_dataloader)
                    testing_batches_counter = 0
                    while testing_batches_counter < config[
                            'batches_for_testing']:
                        # Load batch
                        if config["debug"]:
                            print('Load batch', get_gpu_memory_map())
                        test_batch, test_dataloader_iter = next_batch(
                            test_dataloader_iter, test_dataloader)
                        (input_image, target_image, groundtruth_image,
                         input_color, target_color, groundtruth_color,
                         input_direction, target_direction,
                         groundtruth_direction) = extract_from_batch(
                             test_batch, device)
                        # Forward
                        # Generator — mirrors the train-side forward/metrics.
                        if config["debug"]:
                            print('Generator', get_gpu_memory_map())
                        output = generator(input_image, target_image,
                                           groundtruth_image)
                        (relit_image, input_light_latent,
                         target_light_latent, groundtruth_light_latent,
                         input_scene_latent, target_scene_latent,
                         groundtruth_scene_latent) = output
                        r = reconstruction_loss(relit_image,
                                                groundtruth_image)
                        generator_loss = config[
                            'generator_loss_reconstruction_l2_factor'] * r
                        if config["use_illumination_predicter"]:
                            input_illumination = illumination_predicter(
                                input_light_latent)
                            target_illumination = illumination_predicter(
                                target_light_latent)
                            groundtruth_illumination = illumination_predicter(
                                groundtruth_light_latent)
                            c = (1 / 3 * color_prediction_loss(
                                input_illumination[:, 0], input_color) +
                                 1 / 3 * color_prediction_loss(
                                     target_illumination[:, 0], target_color)
                                 + 1 / 3 * color_prediction_loss(
                                     groundtruth_illumination[:, 0],
                                     groundtruth_color))
                            d = (1 / 3 * direction_prediction_loss(
                                input_illumination[:, 1], input_direction) +
                                 1 / 3 * direction_prediction_loss(
                                     target_illumination[:, 1],
                                     target_direction) +
                                 1 / 3 * direction_prediction_loss(
                                     groundtruth_illumination[:, 1],
                                     groundtruth_direction))
                            generator_loss += config[
                                'generator_loss_color_l2_factor'] * c
                            generator_loss += config[
                                'generator_loss_direction_l2_factor'] * d
                            test_color_prediction_loss += c.item()
                            test_direction_prediction_loss += d.item()
                        test_generator_loss += generator_loss.item()
                        test_score += reconstruction_loss(
                            input_image,
                            groundtruth_image).item() / reconstruction_loss(
                                relit_image, groundtruth_image).item()
                        if "scene_latent" in config["metrics"]:
                            test_scene_latent_loss_input_gt += scene_latent_loss(
                                input_image, groundtruth_image).item()
                            test_scene_latent_loss_input_target += scene_latent_loss(
                                input_image, target_image).item()
                            test_scene_latent_loss_gt_target += scene_latent_loss(
                                target_image, groundtruth_image).item()
                        if "light_latent" in config["metrics"]:
                            test_light_latent_loss_input_gt += light_latent_loss(
                                input_image, groundtruth_image).item()
                            test_light_latent_loss_input_target += light_latent_loss(
                                input_image, target_image).item()
                            test_light_latent_loss_gt_target += light_latent_loss(
                                target_image, groundtruth_image).item()
                        if "LPIPS" in config["metrics"]:
                            test_lpips += lpips_loss(
                                relit_image, groundtruth_image).item()
                        if "SSIM" in config["metrics"]:
                            test_ssim += ssim(relit_image,
                                              groundtruth_image).item()
                        if "PSNR" in config["metrics"]:
                            test_psnr += psnr(relit_image,
                                              groundtruth_image).item()
                        # Discriminator — evaluation only, no parameter update.
                        if config["debug"]:
                            print('Discriminator', get_gpu_memory_map())
                        if config["use_discriminator"]:
                            disc_out_fake = discriminator(relit_image)
                            disc_out_real = discriminator(groundtruth_image)
                            discriminator_loss = config[
                                'discriminator_loss_gan_factor'] * gan_loss(
                                    disc_out_fake, disc_out_real)
                            test_discriminator_loss += discriminator_loss.item(
                            )
                        # Update testing_batches_counter
                        if config["debug"]:
                            print('Update testing_batches_counter',
                                  get_gpu_memory_map())
                        testing_batches_counter += 1
                    # Visualize test
                    if config["debug"]:
                        print('Visualize test', get_gpu_memory_map())
                    write_images(
                        writer=writer,
                        header="Test-" + header,
                        step=step,
                        inputs=input_image[:config['shown_samples_grid']],
                        input_light_latents=input_light_latent[:config[
                            'shown_samples_grid']],
                        targets=target_image[:config['shown_samples_grid']],
                        target_light_latents=target_light_latent[:config[
                            'shown_samples_grid']],
                        groundtruths=groundtruth_image[:config[
                            'shown_samples_grid']],
                        groundtruth_light_latents=groundtruth_light_latent[:config['shown_samples_grid']],
                        relits=relit_image[:config['shown_samples_grid']])
                    write_measures(
                        writer=writer,
                        header="Test-" + header,
                        step=step,
                        generator_loss=test_generator_loss /
                        config['batches_for_testing'],
                        discriminator_loss=test_discriminator_loss /
                        config['batches_for_testing'],
                        score=test_score / config['batches_for_testing'],
                        ssim=test_ssim / config['batches_for_testing'],
                        lpips=test_lpips / config['batches_for_testing'],
                        psnr=test_psnr / config['batches_for_testing'],
                        scene_input_gt=test_scene_latent_loss_input_gt /
                        config['batches_for_testing'],
                        scene_input_target=test_scene_latent_loss_input_target
                        / config['batches_for_testing'],
                        scene_gt_target=test_scene_latent_loss_gt_target /
                        config['batches_for_testing'],
                        light_input_gt=test_light_latent_loss_input_gt /
                        config['batches_for_testing'],
                        light_input_target=test_light_latent_loss_input_target
                        / config['batches_for_testing'],
                        light_gt_target=test_light_latent_loss_gt_target /
                        config['batches_for_testing'],
                        color_prediction=test_color_prediction_loss /
                        config['batches_for_testing'],
                        direction_prediction=test_direction_prediction_loss /
                        config['batches_for_testing'])
                    # NOTE(review): this print divides by testing_frequence
                    # while write_measures above divides by batches_for_testing
                    # — likely a bug; confirm which average is intended.
                    print('Test-' + header, 'Loss:',
                          test_generator_loss / config['testing_frequence'],
                          'Score:',
                          test_score / config['testing_frequence'])
                    if config["debug_memory"]:
                        print(get_gpu_memory_map())