Ejemplo n.º 1
0
	def __init__(self, hparams: DictConfig = None, args: Namespace = None, loss_fns = None,):
		"""Create the generator/discriminator pair and set up losses, weights and batch size.

		Args:
			hparams: Configuration object. It is handed to ``get_model`` and read
				here for ``gan_mode``, the ``w_L2``/``w_ADV``/``w_G``/``w_GCE``
				loss weights, ``batch_size_scheduler`` and ``batch_size``.
			args: Extra arguments, stored unchanged on ``self.args``.
			loss_fns: Optional mapping from loss name to loss callable. When it
				is missing or empty, a default mapping with the keys 'L2',
				'ADV', 'G' and 'GCE' is built instead.
		"""
		super().__init__()

		self.args = args
		self.hparams = hparams

		# Both networks come from the project factory; they are printed once
		# so the architecture shows up in the run output.
		self.generator, self.discriminator = get_model(self.hparams)
		print(self.generator)
		print(self.discriminator)

		# Loss functions: a caller-supplied mapping wins, otherwise build the defaults.
		if loss_fns:
			self.loss_fns = loss_fns
		else:
			self.loss_fns = {
				'L2': l2_loss,  # L2 loss
				'ADV': GANLoss(hparams.gan_mode),  # adversarial loss
				'G': l2_loss,  # goal achievement loss
				'GCE': nn.CrossEntropyLoss(),  # goal cross entropy loss
			}

		# Loss weights, using the same names as the default loss functions.
		self.loss_weights = {
			'L2': hparams.w_L2,
			'ADV': hparams.w_ADV,  # adversarial loss
			'G': hparams.w_G,  # goal achievement loss
			'GCE': hparams.w_GCE,  # goal cross entropy loss
		}

		self.current_batch_idx = -1
		self.plot_val = True

		# A truthy batch_size_scheduler value takes precedence over batch_size.
		self.batch_size = self.hparams.batch_size_scheduler or self.hparams.batch_size
Ejemplo n.º 2
0
    def __init__(self,
                 generator_instantiator,
                 discriminator=None,
                 text_encoder_instantiator=None):
        """Instantiate the generators, the GAN criterion and optional text encoders.

        Args:
            generator_instantiator: Zero-argument callable returning a new
                generator; it is called twice (main generator and the
                ``generator_running_avg`` copy).
            discriminator: Discriminator object, stored as given (may be None).
            text_encoder_instantiator: Optional zero-argument callable
                returning a text encoder. When None, no text encoder
                attributes are set here.

        Note:
            Reads ``args`` (``loss``, ``evaluate``, ``text_train_encoder``),
            which is not a parameter and must be defined outside this method.
        """
        super().__init__()

        self.generator = generator_instantiator()

        # Second generator instance: starts from the same weights as the main
        # one and has gradients disabled on every parameter.
        self.generator_running_avg = generator_instantiator()
        running_avg = self.generator_running_avg
        running_avg.load_state_dict(self.generator.state_dict())
        for weight in running_avg.parameters():
            weight.requires_grad = False

        self.discriminator = discriminator
        self.criterion_gan = GANLoss(
            args.loss, tensor=torch.cuda.FloatTensor).cuda()

        if text_encoder_instantiator is not None:
            self.text_encoder_g = text_encoder_instantiator()

            # Report the encoder size in millions of parameters.
            n_params = sum(
                w.nelement() for w in self.text_encoder_g.parameters())
            print('TextEncoder parameters: {:.2f}M'.format(n_params / 1000000))

            if not args.evaluate:
                if args.text_train_encoder:
                    # Different instances for G and D
                    self.text_encoder_d = text_encoder_instantiator()
                else:
                    # G and D use the same text encoder instance
                    self.text_encoder_d = self.text_encoder_g
Ejemplo n.º 3
0
    def __init__(self,
                 generator,
                 train_dset,
                 val_dset,
                 cfg: DictConfig = None,
                 loss_fns=None):
        """Store the generator and datasets and set up losses, weights and batch size.

        Args:
            generator: Generator object; its ``gen()`` method is called once
                during construction.
            train_dset: Training dataset, stored on ``self.train_dset``.
            val_dset: Validation dataset, stored on ``self.val_dset``.
            cfg: Configuration object, read here for ``gan_mode``, the
                ``w_L2``/``w_ADV``/``w_G``/``w_GCE`` loss weights,
                ``pretraining.batch_size_scheduler`` and ``batch_size``.
            loss_fns: Optional mapping from loss name to loss callable. When
                it is missing or empty, a default mapping with the keys
                'L2', 'ADV', 'G' and 'GCE' is built instead.
        """
        super().__init__()

        self.cfg = cfg
        self.generator = generator
        self.generator.gen()

        # Loss functions: a caller-supplied mapping wins, otherwise build the defaults.
        if loss_fns:
            self.loss_fns = loss_fns
        else:
            self.loss_fns = {
                'L2': l2_loss,  # L2 loss
                'ADV': GANLoss(cfg.gan_mode),  # adversarial loss
                'G': l2_loss,  # goal achievement loss
                'GCE': nn.CrossEntropyLoss(),  # goal cross entropy loss
            }

        # Loss weights, using the same names as the default loss functions.
        self.loss_weights = {
            'L2': cfg.w_L2,
            'ADV': cfg.w_ADV,  # adversarial loss
            'G': cfg.w_G,  # goal achievement loss
            'GCE': cfg.w_GCE,  # goal cross entropy loss
        }

        self.train_dset = train_dset
        self.val_dset = val_dset

        self.plot_val = True

        # A truthy pretraining batch_size_scheduler value takes precedence
        # over the plain batch size.
        self.batch_size = (self.cfg.pretraining.batch_size_scheduler
                           or self.cfg.batch_size)
# GAN training script excerpt. cfg, dataset_cfg, train_cfg, netG and netD are
# defined outside this excerpt.

# One optimizer per network: get_optimizer(...) returns a callable that is
# invoked with the parameters to optimize.
optG = get_optimizer(cfg.optimizer.generator)(params=netG.parameters())
optD = get_optimizer(cfg.optimizer.discriminator)(params=netD.parameters())

# Dataset and dataloader: get_dataset(cfg) returns a callable used as the
# dataset constructor; the training split uses the transform's
# `gan_training` attribute.
dataset = get_dataset(cfg)
transform = get_transform(cfg)
trainset = dataset(root=dataset_cfg.path,
                   train=True,
                   transform=transform.gan_training)
trainloader = DataLoader(trainset,
                         batch_size=dataset_cfg.batch_size,
                         num_workers=0,
                         shuffle=True)

# GAN loss, configured by the loss type from the training config.
gan_loss = GANLoss(gan_mode=train_cfg.loss_type)

# Training loop (visualizing and saving are not part of this excerpt).
iters = 0  # iteration counter; not updated within the visible lines
for epoch in range(cfg.num_epochs):
    for i, data in enumerate(trainloader):
        start_time = time.time()
        # First element of the batch is taken as the real images and moved to the GPU.
        real_imgs = data[0].cuda()

        # Update the discriminator n_dis times per batch.
        for _ in range(train_cfg.args.n_dis):
            optD.zero_grad()
            # One latent vector per real image in the batch.
            z = torch.randn(real_imgs.size(0), cfg.generator.args.z_dim).cuda()
            # detach(): the generated images carry no graph back into netG.
            gen_imgs = netG(z).detach()
            real_pred = netD(real_imgs)
            fake_pred = netD(gen_imgs)
# Training setup excerpt. generator, illumination, discriminator, device,
# SIZE and TRAIN_DATA_PATH are defined outside this excerpt.

# A single Adam optimizer updates the generator and illumination parameters
# together; the discriminator gets its own. No learning rate is passed, so
# Adam's default is used.
optimizerG = torch.optim.Adam(list(generator.parameters()) +
                              list(illumination.parameters()),
                              weight_decay=0)
optimizerD = torch.optim.Adam(discriminator.parameters(), weight_decay=0)
# Put both networks in training mode.
generator.train()
discriminator.train()

# Print both networks to the run output.
print(generator)
print(discriminator)

# Losses: each loss object is instantiated and moved to the target device.
reconstruction_loss = ReconstructionLoss().to(device)
scene_latent_loss = SceneLatentLoss().to(device)
light_latent_loss = LightLatentLoss().to(device)
color_prediction_loss = ColorPredictionLoss().to(device)
gan_loss = GANLoss().to(device)
fool_gan_loss = FoolGANLoss().to(device)

# Configure dataloader
# The training dataset is restricted to one location, one input/target
# direction pair ("S" -> "N") and one input/target color pair
# ("2500" -> "6500"); images are resized to SIZE.
train_dataset = InputTargetGroundtruthDataset(
    transform=transforms.Resize(SIZE),
    data_path=TRAIN_DATA_PATH,
    locations=['scene_abandonned_city_54'],
    input_directions=["S"],
    target_directions=["N"],
    input_colors=["2500"],
    target_colors=["6500"])
test_dataset = InputTargetGroundtruthDataset(transform=transforms.Resize(SIZE),
                                             data_path=VALIDATION_DATA_PATH,
                                             input_directions=["S"],
                                             target_directions=["N"],
Ejemplo n.º 6
0
# Encoder training script excerpt. cfg, dataset_cfg, netE, netG and
# use_gan_loss are defined outside this excerpt.

# Optimizer for the encoder: get_optimizer(...) returns a callable that is
# invoked with the parameters to optimize.
optE = get_optimizer(cfg.optimizer.encoder)(params=netE.parameters())
# Reconstruction criteria for the image and latent terms, each configured by
# its own loss type.
image_criterion = RecLoss(cfg.train.image_loss.type)
latent_criterion = RecLoss(cfg.train.latent_loss.type)
# Weights applied to the image and latent loss terms.
w_image = cfg.train.image_loss.weight
w_latent = cfg.train.latent_loss.weight

# The discriminator, its optimizer, the adversarial criterion and its weight
# are only created when the adversarial loss is enabled.
if use_gan_loss:
    netD = get_discriminator(cfg, cfg.discriminator)
    netD.cuda()
    netD.train()
    optD = get_optimizer(cfg.optimizer.discriminator)(params=netD.parameters())
    gan_criterion = GANLoss(gan_mode=cfg.train.adv_loss.type)
    w_adv = cfg.train.adv_loss.weight

# Main loop: one iteration per step, up to cfg.max_iters.
for iters in range(cfg.max_iters):
    start_time = time.time()

    if use_gan_loss:
        optD.zero_grad()
        # Sample a batch of latent vectors on the GPU.
        z = torch.randn(dataset_cfg.batch_size,
                        cfg.generator.args.z_dim).cuda()

        # Detached generator output, stored as x_real.
        x_real = netG(z).detach()

        if cfg.train.denoise:
            # Add Gaussian noise scaled by 0.1 to form the input x_in.
            noise = torch.randn_like(x_real) * 0.1
            x_in = x_real + noise
Ejemplo n.º 7
0
def main(config):
    # Device to use
    device = setup_device(config["gpus"])

    # Configure training objects
    # Generator
    model_name = config["model"]
    generator = get_generator_model(model_name)().to(device)
    weight_decay = config["L2_regularization_generator"]
    if config["use_illumination_predicter"]:
        light_latent_size = get_light_latent_size(model_name)
        illumination_predicter = IlluminationPredicter(
            in_size=light_latent_size).to(device)
        optimizerG = torch.optim.Adam(
            list(generator.parameters()) +
            list(illumination_predicter.parameters()),
            weight_decay=weight_decay)
    else:
        optimizerG = torch.optim.Adam(generator.parameters(),
                                      weight_decay=weight_decay)
    # Discriminator
    if config["use_discriminator"]:
        if config["discriminator_everything_as_input"]:
            raise NotImplementedError  # TODO
        else:
            discriminator = NLayerDiscriminator().to(device)
        optimizerD = torch.optim.Adam(
            discriminator.parameters(),
            weight_decay=config["L2_regularization_discriminator"])

    # Losses
    reconstruction_loss = ReconstructionLoss().to(device)
    if config["use_illumination_predicter"]:
        color_prediction_loss = ColorPredictionLoss().to(device)
        direction_prediction_loss = DirectionPredictionLoss().to(device)
    if config["use_discriminator"]:
        gan_loss = GANLoss().to(device)
        fool_gan_loss = FoolGANLoss().to(device)

    # Metrics
    if "scene_latent" in config["metrics"]:
        scene_latent_loss = SceneLatentLoss().to(device)
    if "light_latent" in config["metrics"]:
        light_latent_loss = LightLatentLoss().to(device)
    if "LPIPS" in config["metrics"]:
        lpips_loss = LPIPS(
            net_type=
            'alex',  # choose a network type from ['alex', 'squeeze', 'vgg']
            version='0.1'  # Currently, v0.1 is supported
        ).to(device)

    # Configure dataloader
    size = config['image_resize']
    # train
    try:
        file = open(
            'traindataset' + str(config['overfit_test']) + str(size) +
            '.pickle', 'rb')
        print("Restoring train dataset from pickle file")
        train_dataset = pickle.load(file)
        file.close()
        print("Restored train dataset from pickle file")
    except:
        train_dataset = InputTargetGroundtruthDataset(
            transform=transforms.Resize(size),
            data_path=TRAIN_DATA_PATH,
            locations=['scene_abandonned_city_54']
            if config['overfit_test'] else None,
            input_directions=["S", "E"] if config['overfit_test'] else None,
            target_directions=["S", "E"] if config['overfit_test'] else None,
            input_colors=["2500", "6500"] if config['overfit_test'] else None,
            target_colors=["2500", "6500"] if config['overfit_test'] else None)
        file = open(
            "traindataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'wb')
        pickle.dump(train_dataset, file)
        file.close()
        print("saved traindataset" + str(config['overfit_test']) + str(size) +
              '.pickle')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config['train_batch_size'],
                                  shuffle=config['shuffle_data'],
                                  num_workers=config['train_num_workers'])
    # test
    try:
        file = open(
            "testdataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'rb')
        print("Restoring full test dataset from pickle file")
        test_dataset = pickle.load(file)
        file.close()
        print("Restored full test dataset from pickle file")
    except:
        test_dataset = InputTargetGroundtruthDataset(
            transform=transforms.Resize(size),
            data_path=VALIDATION_DATA_PATH,
            locations=["scene_city_24"] if config['overfit_test'] else None,
            input_directions=["S", "E"] if config['overfit_test'] else None,
            target_directions=["S", "E"] if config['overfit_test'] else None,
            input_colors=["2500", "6500"] if config['overfit_test'] else None,
            target_colors=["2500", "6500"] if config['overfit_test'] else None)
        file = open(
            "testdataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'wb')
        pickle.dump(test_dataset, file)
        file.close()
        print("saved testdataset" + str(config['overfit_test']) + str(size) +
              '.pickle')
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config['test_batch_size'],
                                 shuffle=config['shuffle_data'],
                                 num_workers=config['test_num_workers'])
    test_dataloaders = {"full": test_dataloader}
    if config["testing_on_subsets"]:
        additional_pairing_strategies = [[SameLightColor()],
                                         [SameLightDirection()]]
        #[SameScene()],
        #[SameScene(), SameLightColor()],
        #[SameScene(), SameLightDirection()],
        #[SameLightDirection(), SameLightColor()],
        #[SameScene(), SameLightDirection(), SameLightColor()]]
        for pairing_strategies in additional_pairing_strategies:
            try:
                file = open(
                    "testdataset" + str(config['overfit_test']) + str(size) +
                    str(pairing_strategies) + '.pickle', 'rb')
                print("Restoring test dataset " + str(pairing_strategies) +
                      " from pickle file")
                test_dataset = pickle.load(file)
                file.close()
                print("Restored test dataset " + str(pairing_strategies) +
                      " from pickle file")
            except:
                test_dataset = InputTargetGroundtruthDataset(
                    transform=transforms.Resize(size),
                    data_path=VALIDATION_DATA_PATH,
                    pairing_strategies=pairing_strategies,
                    locations=["scene_city_24"]
                    if config['overfit_test'] else None,
                    input_directions=["S", "E"]
                    if config['overfit_test'] else None,
                    target_directions=["S", "E"]
                    if config['overfit_test'] else None,
                    input_colors=["2500", "6500"]
                    if config['overfit_test'] else None,
                    target_colors=["2500", "6500"]
                    if config['overfit_test'] else None)
                file = open(
                    "testdataset" + str(config['overfit_test']) + str(size) +
                    str(pairing_strategies) + '.pickle', 'wb')
                pickle.dump(test_dataset, file)
                file.close()
                print("saved testdataset" + str(config['overfit_test']) +
                      str(size) + str(pairing_strategies) + '.pickle')
            test_dataloader = DataLoader(
                test_dataset,
                batch_size=config['test_batch_size'],
                shuffle=config['shuffle_data'],
                num_workers=config['test_num_workers'])
            test_dataloaders[str(pairing_strategies)] = test_dataloader
    print(
        f'Dataset contains {len(train_dataset)} train samples and {len(test_dataset)} test samples.'
    )
    print(
        f'{config["shown_samples_grid"]} samples will be visualized every {config["testing_frequence"]} batches.'
    )
    print(
        f'Evaluation will be made every {config["testing_frequence"]} batches on {config["batches_for_testing"]} batches'
    )

    # Configure tensorboard
    writer = tensorboard.setup_summary_writer(config['name'])
    tensorboard_process = tensorboard.start_tensorboard_process(
    )  # TODO: config["tensorboard_port"]

    # Train loop

    # Init train scalars
    (train_generator_loss, train_discriminator_loss, train_score, train_lpips,
     train_ssim, train_psnr, train_scene_latent_loss_input_gt,
     train_scene_latent_loss_input_target, train_scene_latent_loss_gt_target,
     train_light_latent_loss_input_gt, train_light_latent_loss_input_target,
     train_light_latent_loss_gt_target, train_color_prediction_loss,
     train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                         0., 0., 0., 0., 0.)

    # Init train loop
    train_dataloader_iter = iter(train_dataloader)
    train_batches_counter = 0
    print(f'Running for {config["train_duration"]} batches.')

    last_save_t = 0
    # Train loop
    while train_batches_counter < config['train_duration']:
        # Store trained model
        t = time.time()
        if t - last_save_t > config["checkpoint_period"]:
            last_save_t = t
            save_trained(generator, "generator" + config['name'] + str(t))
            if config["use_illumination_predicter"]:
                save_trained(
                    illumination_predicter,
                    "illumination_predicter" + config['name'] + str(t))
            if config["use_discriminator"]:
                save_trained(discriminator,
                             "discriminator" + config['name'] + str(t))

        #with torch.autograd.detect_anomaly():
        # Load batch
        if config["debug"]: print('Load batch', get_gpu_memory_map())
        with torch.no_grad():
            train_batch, train_dataloader_iter = next_batch(
                train_dataloader_iter, train_dataloader)
            (input_image, target_image, groundtruth_image, input_color,
             target_color, groundtruth_color, input_direction,
             target_direction,
             groundtruth_direction) = extract_from_batch(train_batch, device)

        # Generator
        # Generator: Forward
        if config["debug"]:
            print('Generator: Forward', get_gpu_memory_map())
        output = generator(input_image, target_image, groundtruth_image)
        (relit_image, input_light_latent, target_light_latent,
         groundtruth_light_latent, input_scene_latent, target_scene_latent,
         groundtruth_scene_latent) = output
        r = reconstruction_loss(relit_image, groundtruth_image)
        generator_loss = config['generator_loss_reconstruction_l2_factor'] * r
        if config["use_illumination_predicter"]:
            input_illumination = illumination_predicter(input_light_latent)
            target_illumination = illumination_predicter(target_light_latent)
            groundtruth_illumination = illumination_predicter(
                groundtruth_light_latent)
            c = (1 / 3 *
                 color_prediction_loss(input_illumination[:, 0], input_color) +
                 1 / 3 *
                 color_prediction_loss(target_illumination[:, 0], target_color)
                 + 1 / 3 * color_prediction_loss(
                     groundtruth_illumination[:, 0], groundtruth_color))
            d = (1 / 3 * direction_prediction_loss(input_illumination[:, 1],
                                                   input_direction) +
                 1 / 3 * direction_prediction_loss(target_illumination[:, 1],
                                                   target_direction) +
                 1 / 3 * direction_prediction_loss(
                     groundtruth_illumination[:, 1], groundtruth_direction))
            generator_loss += config['generator_loss_color_l2_factor'] * c
            generator_loss += config['generator_loss_direction_l2_factor'] * d
            train_color_prediction_loss += c.item()
            train_direction_prediction_loss += d.item()
        train_generator_loss += generator_loss.item()
        train_score += reconstruction_loss(
            input_image, groundtruth_image).item() / reconstruction_loss(
                relit_image, groundtruth_image).item()
        if "scene_latent" in config["metrics"]:
            train_scene_latent_loss_input_gt += scene_latent_loss(
                input_image, groundtruth_image).item()
            train_scene_latent_loss_input_target += scene_latent_loss(
                input_image, target_image).item()
            train_scene_latent_loss_gt_target += scene_latent_loss(
                target_image, groundtruth_image).item()
        if "light_latent" in config["metrics"]:
            train_light_latent_loss_input_gt += light_latent_loss(
                input_image, groundtruth_image).item()
            train_light_latent_loss_input_target += light_latent_loss(
                input_image, target_image).item()
            train_light_latent_loss_gt_target += light_latent_loss(
                target_image, groundtruth_image).item()
        if "LPIPS" in config["metrics"]:
            train_lpips += lpips_loss(relit_image, groundtruth_image).item()
        if "SSIM" in config["metrics"]:
            train_ssim += ssim(relit_image, groundtruth_image).item()
        if "PSNR" in config["metrics"]:
            train_psnr += psnr(relit_image, groundtruth_image).item()

        # Generator: Backward
        if config["debug"]:
            print('Generator: Backward', get_gpu_memory_map())
        optimizerG.zero_grad()
        if config["use_discriminator"]:
            optimizerD.zero_grad()
        if config["use_discriminator"]:
            discriminator.zero_grad()
        generator_loss.backward(
        )  # use requires_grad = False for speed? Et pour enlever ces zero_grad en double!
        # Generator: Update parameters
        if config["debug"]:
            print('Generator: Update parameters', get_gpu_memory_map())
        optimizerG.step()

        # Discriminator
        if config["use_discriminator"]:
            if config["debug"]:
                print('Discriminator', get_gpu_memory_map())
            # Discriminator : Forward
            output = generator(input_image, target_image, groundtruth_image)
            (relit_image, input_light_latent, target_light_latent,
             groundtruth_light_latent, input_scene_latent, target_scene_latent,
             groundtruth_scene_latent) = output
            disc_out_fake = discriminator(relit_image)
            disc_out_real = discriminator(groundtruth_image)
            discriminator_loss = config[
                'discriminator_loss_gan_factor'] * gan_loss(
                    disc_out_fake, disc_out_real)
            train_discriminator_loss += discriminator_loss.item()
            # Discriminator : Backward
            optimizerD.zero_grad()
            discriminator.zero_grad()
            optimizerG.zero_grad()
            generator.zero_grad()
            discriminator_loss.backward()
            generator.zero_grad()
            optimizerG.zero_grad()
            # Discriminator : Update parameters
            optimizerD.step()

        # Update train_batches_counter
        train_batches_counter += 1

        # If it is time to do so, test and visualize current progress
        step, modulo = divmod(train_batches_counter,
                              config['testing_frequence'])
        if modulo == 0:
            with torch.no_grad():

                # Visualize train
                if config["debug"]:
                    print('Visualize train', get_gpu_memory_map())
                write_images(
                    writer=writer,
                    header="Train",
                    step=step,
                    inputs=input_image[:config['shown_samples_grid']],
                    input_light_latents=input_light_latent[:config[
                        'shown_samples_grid']],
                    targets=target_image[:config['shown_samples_grid']],
                    target_light_latents=target_light_latent[:config[
                        'shown_samples_grid']],
                    groundtruths=groundtruth_image[:config[
                        'shown_samples_grid']],
                    groundtruth_light_latents=groundtruth_light_latent[:config[
                        'shown_samples_grid']],
                    relits=relit_image[:config['shown_samples_grid']])
                write_measures(
                    writer=writer,
                    header="Train",
                    step=step,
                    generator_loss=train_generator_loss /
                    config['testing_frequence'],
                    discriminator_loss=train_discriminator_loss /
                    config['testing_frequence'],
                    score=train_score / config['testing_frequence'],
                    ssim=train_ssim / config['testing_frequence'],
                    lpips=train_lpips / config['testing_frequence'],
                    psnr=train_psnr / config['testing_frequence'],
                    scene_input_gt=train_scene_latent_loss_input_gt /
                    config['testing_frequence'],
                    scene_input_target=train_scene_latent_loss_input_target /
                    config['testing_frequence'],
                    scene_gt_target=train_scene_latent_loss_gt_target /
                    config['testing_frequence'],
                    light_input_gt=train_light_latent_loss_input_gt /
                    config['testing_frequence'],
                    light_input_target=train_light_latent_loss_input_target /
                    config['testing_frequence'],
                    light_gt_target=train_light_latent_loss_gt_target /
                    config['testing_frequence'],
                    color_prediction=train_color_prediction_loss /
                    config['testing_frequence'],
                    direction_prediction=train_direction_prediction_loss /
                    config['testing_frequence'])
                print('Train', 'Loss:',
                      train_generator_loss / config['testing_frequence'],
                      'Score:', train_score / config['testing_frequence'])
                if config["debug_memory"]:
                    print(get_gpu_memory_map())
                    # del generator_loss
                    # torch.cuda.empty_cache()
                    # print(get_gpu_memory_map())

                # Reset train scalars
                if config["debug"]:
                    print('Reset train scalars', get_gpu_memory_map())
                (train_generator_loss, train_discriminator_loss, train_score,
                 train_lpips, train_ssim, train_psnr,
                 train_scene_latent_loss_input_gt,
                 train_scene_latent_loss_input_target,
                 train_scene_latent_loss_gt_target,
                 train_light_latent_loss_input_gt,
                 train_light_latent_loss_input_target,
                 train_light_latent_loss_gt_target,
                 train_color_prediction_loss,
                 train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0.,
                                                     0., 0., 0., 0., 0., 0.,
                                                     0., 0.)

                # Test loop

                if config["debug"]: print('Test loop', get_gpu_memory_map())
                for header, test_dataloader in test_dataloaders.items():

                    # Init test scalars
                    if config["debug"]:
                        print('Init test scalars', get_gpu_memory_map())
                    (test_generator_loss, test_discriminator_loss, test_score,
                     test_lpips, test_ssim, test_psnr,
                     test_scene_latent_loss_input_gt,
                     test_scene_latent_loss_input_target,
                     test_scene_latent_loss_gt_target,
                     test_light_latent_loss_input_gt,
                     test_light_latent_loss_input_target,
                     test_light_latent_loss_gt_target,
                     test_color_prediction_loss,
                     test_direction_prediction_loss) = (0., 0., 0., 0., 0., 0.,
                                                        0., 0., 0., 0., 0., 0.,
                                                        0., 0.)

                    # Init test loop
                    if config["debug"]:
                        print('Init test loop', get_gpu_memory_map())
                    test_dataloader_iter = iter(test_dataloader)
                    testing_batches_counter = 0

                    while testing_batches_counter < config[
                            'batches_for_testing']:

                        # Load batch
                        if config["debug"]:
                            print('Load batch', get_gpu_memory_map())
                        test_batch, test_dataloader_iter = next_batch(
                            test_dataloader_iter, test_dataloader)
                        (input_image, target_image, groundtruth_image,
                         input_color, target_color, groundtruth_color,
                         input_direction, target_direction,
                         groundtruth_direction) = extract_from_batch(
                             test_batch, device)

                        # Forward

                        # Generator
                        if config["debug"]:
                            print('Generator', get_gpu_memory_map())
                        output = generator(input_image, target_image,
                                           groundtruth_image)
                        (relit_image, input_light_latent, target_light_latent,
                         groundtruth_light_latent, input_scene_latent,
                         target_scene_latent,
                         groundtruth_scene_latent) = output
                        r = reconstruction_loss(relit_image, groundtruth_image)
                        generator_loss = config[
                            'generator_loss_reconstruction_l2_factor'] * r
                        if config["use_illumination_predicter"]:
                            input_illumination = illumination_predicter(
                                input_light_latent)
                            target_illumination = illumination_predicter(
                                target_light_latent)
                            groundtruth_illumination = illumination_predicter(
                                groundtruth_light_latent)
                            c = (1 / 3 * color_prediction_loss(
                                input_illumination[:, 0], input_color) +
                                 1 / 3 * color_prediction_loss(
                                     target_illumination[:, 0], target_color) +
                                 1 / 3 * color_prediction_loss(
                                     groundtruth_illumination[:, 0],
                                     groundtruth_color))
                            d = (1 / 3 * direction_prediction_loss(
                                input_illumination[:, 1], input_direction) +
                                 1 / 3 * direction_prediction_loss(
                                     target_illumination[:, 1],
                                     target_direction) +
                                 1 / 3 * direction_prediction_loss(
                                     groundtruth_illumination[:, 1],
                                     groundtruth_direction))
                            generator_loss += config[
                                'generator_loss_color_l2_factor'] * c
                            generator_loss += config[
                                'generator_loss_direction_l2_factor'] * d
                            test_color_prediction_loss += c.item()
                            test_direction_prediction_loss += d.item()
                        test_generator_loss += generator_loss.item()
                        test_score += reconstruction_loss(
                            input_image,
                            groundtruth_image).item() / reconstruction_loss(
                                relit_image, groundtruth_image).item()
                        if "scene_latent" in config["metrics"]:
                            test_scene_latent_loss_input_gt += scene_latent_loss(
                                input_image, groundtruth_image).item()
                            test_scene_latent_loss_input_target += scene_latent_loss(
                                input_image, target_image).item()
                            test_scene_latent_loss_gt_target += scene_latent_loss(
                                target_image, groundtruth_image).item()
                        if "light_latent" in config["metrics"]:
                            test_light_latent_loss_input_gt += light_latent_loss(
                                input_image, groundtruth_image).item()
                            test_light_latent_loss_input_target += light_latent_loss(
                                input_image, target_image).item()
                            test_light_latent_loss_gt_target += light_latent_loss(
                                target_image, groundtruth_image).item()
                        if "LPIPS" in config["metrics"]:
                            test_lpips += lpips_loss(relit_image,
                                                     groundtruth_image).item()
                        if "SSIM" in config["metrics"]:
                            test_ssim += ssim(relit_image,
                                              groundtruth_image).item()
                        if "PSNR" in config["metrics"]:
                            test_psnr += psnr(relit_image,
                                              groundtruth_image).item()

                        # Discriminator
                        if config["debug"]:
                            print('Discriminator', get_gpu_memory_map())
                        if config["use_discriminator"]:
                            disc_out_fake = discriminator(relit_image)
                            disc_out_real = discriminator(groundtruth_image)
                            discriminator_loss = config[
                                'discriminator_loss_gan_factor'] * gan_loss(
                                    disc_out_fake, disc_out_real)
                            test_discriminator_loss += discriminator_loss.item(
                            )

                        # Update testing_batches_counter
                        if config["debug"]:
                            print('Update testing_batches_counter',
                                  get_gpu_memory_map())
                        testing_batches_counter += 1

                    # Visualize test
                    if config["debug"]:
                        print('Visualize test', get_gpu_memory_map())
                    write_images(
                        writer=writer,
                        header="Test-" + header,
                        step=step,
                        inputs=input_image[:config['shown_samples_grid']],
                        input_light_latents=input_light_latent[:config[
                            'shown_samples_grid']],
                        targets=target_image[:config['shown_samples_grid']],
                        target_light_latents=target_light_latent[:config[
                            'shown_samples_grid']],
                        groundtruths=groundtruth_image[:config[
                            'shown_samples_grid']],
                        groundtruth_light_latents=
                        groundtruth_light_latent[:
                                                 config['shown_samples_grid']],
                        relits=relit_image[:config['shown_samples_grid']])
                    write_measures(
                        writer=writer,
                        header="Test-" + header,
                        step=step,
                        generator_loss=test_generator_loss /
                        config['batches_for_testing'],
                        discriminator_loss=test_discriminator_loss /
                        config['batches_for_testing'],
                        score=test_score / config['batches_for_testing'],
                        ssim=test_ssim / config['batches_for_testing'],
                        lpips=test_lpips / config['batches_for_testing'],
                        psnr=test_psnr / config['batches_for_testing'],
                        scene_input_gt=test_scene_latent_loss_input_gt /
                        config['batches_for_testing'],
                        scene_input_target=test_scene_latent_loss_input_target
                        / config['batches_for_testing'],
                        scene_gt_target=test_scene_latent_loss_gt_target /
                        config['batches_for_testing'],
                        light_input_gt=test_light_latent_loss_input_gt /
                        config['batches_for_testing'],
                        light_input_target=test_light_latent_loss_input_target
                        / config['batches_for_testing'],
                        light_gt_target=test_light_latent_loss_gt_target /
                        config['batches_for_testing'],
                        color_prediction=test_color_prediction_loss /
                        config['batches_for_testing'],
                        direction_prediction=test_direction_prediction_loss /
                        config['batches_for_testing'])
                    print('Test-' + header, 'Loss:',
                          test_generator_loss / config['testing_frequence'],
                          'Score:', test_score / config['testing_frequence'])

                    if config["debug_memory"]:
                        print(get_gpu_memory_map())