Пример #1
0
def lerp_img_evaluation(hr_img, lerp_img):
    """Score an interpolated image against its high-resolution reference.

    Both arguments are PIL images (YCbCr expected); only the first channel
    (luminance) is compared.  Returns a ``(psnr, nrmse, ssim)`` tuple
    computed by the module-level ``metrics`` helpers.
    """
    to_tensor = ToTensor()

    def _y_as_tensor(img):
        # Take the luminance channel and reshape to NCHW.  PIL's ``size``
        # is (width, height), hence the reversed order in the view.
        y_channel = img.split()[0]
        width, height = y_channel.size
        return to_tensor(y_channel).view(1, -1, height, width)

    lerp_tensor = _y_as_tensor(lerp_img)
    hr_tensor = _y_as_tensor(hr_img)

    return (metrics.psnr(lerp_tensor, hr_tensor),
            metrics.nrmse(lerp_tensor, hr_tensor),
            metrics.ssim(lerp_tensor, hr_tensor))
Пример #2
0
    def validating(self, model, dataset):
        """
          input:
            model: (object) pytorch model
            batch_size: (int)
            dataset : (object) dataset
          return [val_mse, val_loss]
        """
        args = self.args
        """
        metrics
        """
        val_loss, val_psnr, val_nrmse, val_ssim = 0, 0, 0, 0
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=args.valbatch_size,
                                 num_workers=args.threads,
                                 shuffle=False)
        batch_iterator = iter(data_loader)
        steps = len(dataset) // args.valbatch_size
        #        model.eval()
        start = time.time()
        for step in range(steps):
            x, y = next(batch_iterator)
            x = x.to(self.device)
            y = y.to(self.device)
            # calculate pixel accuracy of generator
            gen_y = model(x)
            """
            metrics
            """
            val_loss += F.mse_loss(gen_y, y).item()
            val_psnr += metrics.psnr(gen_y.data, y.data)
            val_nrmse += metrics.nrmse(gen_y.data, y.data)
            val_ssim += metrics.ssim(gen_y.data, y.data)
#            val_vifp += metrics.vifp(gen_y.data, y.data)

        _time = time.time() - start
        nb_samples = steps * args.valbatch_size
        """
        metrics
        """
        val_log = [
            val_loss / steps, val_psnr / steps, val_nrmse / steps,
            val_ssim / steps, _time, nb_samples / _time
        ]
        self.val_log = [round(x, 3) for x in val_log]
Пример #3
0
    def Y_to_RGB(self):
        """
        Super-resolve the Y channel of ``self.img_path`` and rebuild RGB.

        The image is loaded as YCbCr; only the luminance (Y) channel is fed
        through ``self.model``, while the Cb/Cr channels are reused from the
        original high-resolution image.

          return:
            (sr_img, (psnr, ssim, nrmse)): super-resolved RGB PIL image and
            metrics of the SR luminance against the HR luminance.
        """
        args = self.args
        img_path = self.img_path
        model = self.model

        img_to_tensor = ToTensor()
        hr_img = Image.open(img_path).convert('YCbCr')
        hr_img_y = hr_img.split()[0]
        # PIL ``size`` is (width, height); view as NCHW = (1, C, H, W).
        hr_img_y_tensor = img_to_tensor(hr_img_y).view(1, -1, hr_img_y.size[1],
                                                       hr_img_y.size[0])
        hr_img_Cb = hr_img.split()[1]
        hr_img_Cr = hr_img.split()[2]

        if args.interpolation:
            # Interpolation baseline: feed the model the full-size image.
            args.upscale_factor = 1

        # BUGFIX: torchvision's TF.resize expects (height, width), but PIL's
        # ``size`` is (width, height).  The original passed
        # (size[0] // f, size[1] // f), transposing non-square images.
        lr_img_y = TF.resize(hr_img_y,
                             (hr_img_y.size[1] // args.upscale_factor,
                              hr_img_y.size[0] // args.upscale_factor))
        lr_img_y_tensor = img_to_tensor(lr_img_y).view(1, -1, lr_img_y.size[1],
                                                       lr_img_y.size[0])
        model_input = lr_img_y_tensor  # renamed: ``input`` shadows a builtin

        if args.cuda:
            model = model.cuda()
            model_input = model_input.cuda()
        sr_img_y_tensor = model(model_input)
        sr_img_y_tensor = sr_img_y_tensor.cpu()
        """
        metrics
        """
        psnr = metrics.psnr(sr_img_y_tensor, hr_img_y_tensor)
        nrmse = metrics.nrmse(sr_img_y_tensor, hr_img_y_tensor)
        ssim = metrics.ssim(sr_img_y_tensor, hr_img_y_tensor)

        # Convert the SR tensor back to an 8-bit grayscale PIL image.
        sr_img_y = sr_img_y_tensor[0].detach().numpy()
        sr_img_y *= 255.0
        sr_img_y = sr_img_y.clip(0, 255)
        sr_img_y = Image.fromarray(np.uint8(sr_img_y[0]), mode='L')
        # Chrominance is taken unchanged from the HR source.
        sr_img_Cb = hr_img_Cb
        sr_img_Cr = hr_img_Cr
        sr_img = Image.merge('YCbCr',
                             [sr_img_y, sr_img_Cb, sr_img_Cr]).convert('RGB')

        return sr_img, (psnr, ssim, nrmse)
Пример #4
0
    def define_graph(self):
        """
        Set up the model graph

        Builds the input placeholders, a U-Net-style prediction network that
        outputs per-plane blending weights and alpha maps, the blended
        color prediction, the VGG-loss training op, and PSNR / sharp-diff /
        SSIM error summaries (TF1 graph-mode API).
        """
        with tf.name_scope('data'):
            # NHWC placeholders; multi_plane packs c.NUM_PLANE RGB images
            # along the channel axis (3 channels per plane).
            self.ref_image = tf.placeholder(tf.float32, shape=[None,128,128,3], name='ref_image')
            self.multi_plane = tf.placeholder(tf.float32, shape=[None,128,128,3*c.NUM_PLANE])
            self.gt = tf.placeholder(tf.float32,shape=[None,128,128,3], name='gt')

        self.summaries=[]
        
        # NOTE(review): scope name 'predection' is a typo for 'prediction';
        # kept as-is since checkpoints/summaries may reference this scope.
        with tf.name_scope('predection'):
            def prediction(ref_image,multi_plane):
                """Encoder-decoder producing blending weights and alpha maps."""
                net_in = tf.concat([ref_image,multi_plane],axis=-1)

                # Encoder: each stride-2 conv_block halves spatial resolution.
                conv1_1 = conv_block(net_in,64,3,1)
                conv1_2 = conv_block(conv1_1,128,3,2)

                conv2_1 = conv_block(conv1_2,128,3,1)
                conv2_2 = conv_block(conv2_1,256,3,2)

                conv3_1 = conv_block(conv2_2,256,3,1)
                conv3_2 = conv_block(conv3_1,256,3,1)
                conv3_3 = conv_block(conv3_2,512,3,2)

                # weight3_1 = tf.Variable(tf.random_normal([3, 3, 512]))
                # weight3_2 = tf.Variable(tf.random_normal([3, 3, 512]))
                # weight3_3 = tf.Variable(tf.random_normal([3, 3, 512]))

                # conv4_1 = tf.nn.dilation2d(conv3_3,weight3_1,[1,1,1,1],[1,2,2,1],'SAME')
                # conv4_2 = tf.nn.dilation2d(conv4_1,weight3_2,[1,1,1,1],[1,2,2,1],'SAME')
                # conv4_3 = tf.nn.dilation2d(conv4_2,weight3_3,[1,1,1,1],[1,2,2,1],'SAME')

                # Bottleneck: dilated convs widen the receptive field without
                # further downsampling (replacing the dilation2d attempt above).
                conv4_1 = tf.layers.conv2d(conv3_3,512,(3,3),(1,1),'SAME',dilation_rate=(2,2))
                conv4_2 = tf.layers.conv2d(conv4_1,512,(3,3),(1,1),'SAME',dilation_rate=(2,2))
                conv4_3 = tf.layers.conv2d(conv4_2,512,(3,3),(1,1),'SAME',dilation_rate=(2,2))

                # Decoder: deconv_block upsamples by 2; skip connections
                # concatenate encoder features of the matching resolution.
                conv5_1 = deconv_block(tf.concat([conv4_3,conv3_3],axis=-1),256,4,2)
                conv5_2 = conv_block(conv5_1,256,3,1)
                conv5_3 = conv_block(conv5_2,256,3,1)

                conv6_1 = deconv_block(tf.concat([conv5_3,conv2_2],axis=-1),128,4,2)
                conv6_2 = conv_block(conv6_1,128,3,1)
                
                conv7_1 = deconv_block(tf.concat([conv6_2,conv1_2],axis=-1),64,4,2)
                conv7_2 = conv_block(conv7_1,64,3,1)
                # 62 output channels; the split below assumes
                # 2 * c.NUM_PLANE == 62 (i.e. NUM_PLANE == 31) -- TODO confirm.
                conv7_3 = tf.layers.conv2d(conv7_2,62,(1,1),(1,1),'SAME')
                conv7_3 = tf.nn.tanh(conv7_3)

                blending_weights, alpha_images = tf.split(conv7_3,[c.NUM_PLANE,c.NUM_PLANE],axis=-1)
                # Blending weights normalized by tensor_norm; alpha maps
                # compete across planes via channel-wise softmax.
                blending_weights = tensor_norm(blending_weights)
                #alpha_images = tensor_norm(alpha_images)
                alpha_images = tf.nn.softmax(alpha_images,axis=-1)
               
                # Intermediate activations exposed for inspection/debugging.
                feature_maps = {
                    'conv1_1':conv1_1,
                    'conv1_2':conv1_2,
                    'conv2_1':conv2_1,
                    'conv2_2':conv2_2,
                    'conv3_1':conv3_1,
                    'conv3_2':conv3_2,
                    'conv3_3':conv3_3,
                    'conv4_1':conv4_1,
                    'conv4_2':conv4_2,
                    'conv4_3':conv4_3,
                    'conv5_1':conv5_1,
                    'conv6_1':conv6_1,
                    'conv6_2':conv6_2,
                    'conv7_1':conv7_1,
                    'conv7_2':conv7_2,
                    'conv7_3':conv7_3
                }

                return blending_weights, alpha_images, feature_maps
            
            
            self.blending_weights, self.alpha_images, self.feature_maps = prediction(self.ref_image,self.multi_plane)
            # Per-plane color: lerp between the reference image and that
            # plane's RGB slice using the predicted blending weight.
            self.color_images = []
            for i in range(c.NUM_PLANE):
                tmp_weights = tf.expand_dims(self.blending_weights[:,:,:,i],axis=-1)
                #tmp_weights = self.blending_weights[:,:,:,i]
                self.color_images.append(
                    tf.multiply(tmp_weights,self.ref_image) + 
                    tf.multiply(1-tmp_weights,self.multi_plane[:,:,:,3*i:3*(i+1)]))
            
            # Final prediction: alpha-weighted sum of the per-plane colors.
            self.preds = []
            for i in range(c.NUM_PLANE):
                tmp_alpha = tf.expand_dims(self.alpha_images[:,:,:,i],axis=-1)
                self.preds.append(tf.multiply(tmp_alpha, self.color_images[i]))
            self.preds = tf.accumulate_n(self.preds)
            #self.preds = inception_model(self.preds,6)

        with tf.name_scope('train'):
            # VGG perceptual loss between prediction and ground truth.
            self.loss = VGG_loss(self.preds,self.gt)
            self.global_step = tf.Variable(0, trainable=False)
            self.optimizer = tf.train.AdamOptimizer(learning_rate=c.LRATE, name='optimizer')
            self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step, name='train_op')
            loss_summary = tf.summary.scalar('train_loss', self.loss)
            self.summaries.append(loss_summary)

        with tf.name_scope('error'):
            self.psnr = psnr(self.preds,self.gt)
            self.sharpdiff = sharp_diff(self.preds,self.gt)
            self.ssim = ssim(self.preds, self.gt)
            summary_psnr = tf.summary.scalar('train_PSNR',self.psnr)
            summary_sharpdiff = tf.summary.scalar('train_SharpDiff',self.sharpdiff)
            # NOTE(review): summary tag 'trian_ssim' is a typo; kept so
            # existing tensorboard logs stay comparable.
            summary_ssim = tf.summary.scalar('trian_ssim',self.ssim)
            self.summaries += [summary_psnr, summary_sharpdiff, summary_ssim]
        self.summaries = tf.summary.merge(self.summaries)
Пример #5
0
    def evaluating(self, model, dataset, split):
        """
        Evaluate overall performance of the model
          input:
            model: (object) pytorch model
            dataset: (object) dataset
            split: (str) split of dataset in ['train', 'val', 'test']
          return [overall_accuracy, precision, recall, f1-score, jaccard, kappa]
        """
        args = self.args
        #        oa, precision, recall, f1, jac, kappa = 0, 0, 0, 0, 0, 0
        """
        metrics
        """
        #        psnr, nrmse, ssim, vifp, fsim
        psnr, nrmse, ssim = 0, 0, 0
        model.eval()
        data_loader = DataLoader(dataset,
                                 args.evalbatch_size,
                                 num_workers=4,
                                 shuffle=False)
        batch_iterator = iter(data_loader)
        steps = len(dataset) // args.evalbatch_size

        start = time.time()
        for step in range(steps):
            x, y = next(batch_iterator)
            if args.cuda:
                x = x.cuda()
                y = y.cuda()
            # calculate pixel accuracy of generator
            """
            metrics
            """
            gen_y = model(x)
            psnr += metrics.psnr(gen_y, y)
            nrmse += metrics.nrmse(gen_y, y)
            ssim += metrics.ssim(gen_y, y)
#            vifp += metrics.vifp(gen_y.data, y.data)
        _time = time.time() - start

        if not os.path.exists(os.path.join(Logs_DIR, 'statistic')):
            os.makedirs(os.path.join(Logs_DIR, 'statistic'))

        # recording performance of the model
        nb_samples = steps * args.evalbatch_size
        fps = nb_samples / _time
        basic_info = [
            self.date, self.method, self.epoch, self.iter, nb_samples, _time,
            fps
        ]
        basic_info_names = [
            'date', 'method', 'epochs', 'iters', 'nb_samples', 'time(sec)',
            'fps'
        ]
        """
        metrics
        """
        perform = [round(idx / steps, 3) for idx in [psnr, nrmse, ssim]]
        perform_names = ['psnr', 'nrmse', 'ssim']
        cur_log = pd.DataFrame([basic_info + perform],
                               columns=basic_info_names + perform_names)
        # save performance
        if os.path.exists(
                os.path.join(Logs_DIR, 'statistic', "{}.csv".format(split))):
            logs = pd.read_csv(
                os.path.join(Logs_DIR, 'statistic', "{}.csv".format(split)))
        else:
            logs = pd.DataFrame([])
        logs = logs.append(cur_log, ignore_index=True)
        logs.to_csv(os.path.join(Logs_DIR, 'statistic',
                                 "{}.csv".format(split)),
                    index=False,
                    float_format='%.3f')
Пример #6
0
    def training(self, net, datasets, verbose=False):
        """
          input:
            net: (object) model & optimizer
            datasets : (list) [train, val] dataset object
            verbose: (bool) print per-step loss/PSNR when True
          side effects:
            updates net parameters, self.iter/self.epoch, self.train_log;
            periodically validates, logs and (optionally) checkpoints
        """
        args = self.args
        # Number of full batches per epoch; a trailing partial batch is dropped.
        steps = len(datasets[0]) // args.batch_size

        # Translate the trigger unit into a total-iteration budget and a
        # logging interval measured in iterations.
        if args.trigger == 'epoch':
            args.epochs = args.nEpochs
            args.iters = steps * args.nEpochs
            args.iter_interval = steps * args.interval
        else:
            # Trigger is iteration-based: nEpochs is actually an iteration
            # count; derive an epoch upper bound that covers it.
            args.iters = args.nEpochs
            args.epochs = args.nEpochs // steps + 1
            args.iter_interval = args.interval

        start = time.time()
        for epoch in range(1, args.epochs + 1):
            self.epoch = epoch
            # setup data loader
            data_loader = DataLoader(dataset=datasets[0],
                                     batch_size=args.batch_size,
                                     num_workers=args.threads,
                                     shuffle=False)
            batch_iterator = iter(data_loader)
            """
            metrics
            """
            epoch_loss, epoch_psnr = 0, 0
            for step in range(steps):
                self.iter += 1
                # Stop once the global iteration budget is exhausted; undo
                # the increment so self.iter reflects completed iterations.
                if self.iter > args.iters:
                    self.iter -= 1
                    break
                x, y = next(batch_iterator)
                x = x.to(self.device)
                y = y.to(self.device)
                # training
                gen_y = net(x)
                loss = F.mse_loss(gen_y, y)
                # Update generator parameters
                net.optimizer.zero_grad()
                loss.backward()
                net.optimizer.step()
                """
                metrics
                """
                epoch_loss += loss.item()
                epoch_psnr += metrics.psnr(gen_y.data, y.data)
                #                epoch_nrmse += metrics.nrmse(gen_y.data, y.data)
                #                epoch_ssim += metrics.ssim(gen_y.data, y.data)
                #                epoch_vifp += metrics.vifp(gen_y.data, y.data)
                if verbose:
                    print(
                        "===> Epoch[{}]({}/{}): Loss: {:.4f}; \t PSNR: {:.4f}".
                        format(epoch, step + 1, steps, loss.item(),
                               metrics.psnr(gen_y.data, y.data)))

                # logging
                if self.iter % args.iter_interval == 0:
                    _time = time.time() - start
                    nb_samples = args.iter_interval * args.batch_size
                    """
                    metrics
                    """
                    # Snapshot metrics from the current batch only (not an
                    # interval average -- see the commented alternative below).
                    loss_log = loss.item()
                    psnr_log = metrics.psnr(gen_y.data, y.data)
                    nrmse_log = metrics.nrmse(gen_y.data, y.data)
                    ssim_log = metrics.ssim(gen_y.data, y.data)
                    #                    vifp_log = metrics.ssim(gen_y.data, y.data)
                    train_log = [
                        loss_log, psnr_log, nrmse_log, ssim_log, _time,
                        nb_samples / _time
                    ]
                    #                    train_log = [log_loss / args.iter_interval, log_psnr /
                    #                                 args.iter_interval, _time, nb_samples / _time]

                    self.train_log = [round(x, 3) for x in train_log]
                    # Validate on the held-out split, then write the log row.
                    self.validating(net, datasets[1])
                    self.logging(verbose=True)
                    if self.args.middle_checkpoint:
                        model_name_dir = "up{}_{}_{}_{}_{}".format(
                            self.args.upscale_factor, self.method,
                            self.args.trigger, self.args.nEpochs, self.date)
                        self.save_middle_checkpoint(net, self.epoch, self.iter,
                                                    model_name_dir)

                    # reinitialize
                    start = time.time()
#                    log_loss, log_psnr = 0, 0
            print(
                "===> Epoch {} Complete: Avg. Loss: {:.4f}; \t Avg. PSNR: {:.4f}"
                .format(epoch, epoch_loss / steps, epoch_psnr / steps))
            """
            metrics
            """
            epoch_loss, epoch_psnr = 0, 0
Пример #7
0
def main(config):
    # Device to use
    device = setup_device(config["gpus"])

    # Configure training objects
    # Generator
    model_name = config["model"]
    generator = get_generator_model(model_name)().to(device)
    weight_decay = config["L2_regularization_generator"]
    if config["use_illumination_predicter"]:
        light_latent_size = get_light_latent_size(model_name)
        illumination_predicter = IlluminationPredicter(
            in_size=light_latent_size).to(device)
        optimizerG = torch.optim.Adam(
            list(generator.parameters()) +
            list(illumination_predicter.parameters()),
            weight_decay=weight_decay)
    else:
        optimizerG = torch.optim.Adam(generator.parameters(),
                                      weight_decay=weight_decay)
    # Discriminator
    if config["use_discriminator"]:
        if config["discriminator_everything_as_input"]:
            raise NotImplementedError  # TODO
        else:
            discriminator = NLayerDiscriminator().to(device)
        optimizerD = torch.optim.Adam(
            discriminator.parameters(),
            weight_decay=config["L2_regularization_discriminator"])

    # Losses
    reconstruction_loss = ReconstructionLoss().to(device)
    if config["use_illumination_predicter"]:
        color_prediction_loss = ColorPredictionLoss().to(device)
        direction_prediction_loss = DirectionPredictionLoss().to(device)
    if config["use_discriminator"]:
        gan_loss = GANLoss().to(device)
        fool_gan_loss = FoolGANLoss().to(device)

    # Metrics
    if "scene_latent" in config["metrics"]:
        scene_latent_loss = SceneLatentLoss().to(device)
    if "light_latent" in config["metrics"]:
        light_latent_loss = LightLatentLoss().to(device)
    if "LPIPS" in config["metrics"]:
        lpips_loss = LPIPS(
            net_type=
            'alex',  # choose a network type from ['alex', 'squeeze', 'vgg']
            version='0.1'  # Currently, v0.1 is supported
        ).to(device)

    # Configure dataloader
    size = config['image_resize']
    # train
    try:
        file = open(
            'traindataset' + str(config['overfit_test']) + str(size) +
            '.pickle', 'rb')
        print("Restoring train dataset from pickle file")
        train_dataset = pickle.load(file)
        file.close()
        print("Restored train dataset from pickle file")
    except:
        train_dataset = InputTargetGroundtruthDataset(
            transform=transforms.Resize(size),
            data_path=TRAIN_DATA_PATH,
            locations=['scene_abandonned_city_54']
            if config['overfit_test'] else None,
            input_directions=["S", "E"] if config['overfit_test'] else None,
            target_directions=["S", "E"] if config['overfit_test'] else None,
            input_colors=["2500", "6500"] if config['overfit_test'] else None,
            target_colors=["2500", "6500"] if config['overfit_test'] else None)
        file = open(
            "traindataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'wb')
        pickle.dump(train_dataset, file)
        file.close()
        print("saved traindataset" + str(config['overfit_test']) + str(size) +
              '.pickle')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config['train_batch_size'],
                                  shuffle=config['shuffle_data'],
                                  num_workers=config['train_num_workers'])
    # test
    try:
        file = open(
            "testdataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'rb')
        print("Restoring full test dataset from pickle file")
        test_dataset = pickle.load(file)
        file.close()
        print("Restored full test dataset from pickle file")
    except:
        test_dataset = InputTargetGroundtruthDataset(
            transform=transforms.Resize(size),
            data_path=VALIDATION_DATA_PATH,
            locations=["scene_city_24"] if config['overfit_test'] else None,
            input_directions=["S", "E"] if config['overfit_test'] else None,
            target_directions=["S", "E"] if config['overfit_test'] else None,
            input_colors=["2500", "6500"] if config['overfit_test'] else None,
            target_colors=["2500", "6500"] if config['overfit_test'] else None)
        file = open(
            "testdataset" + str(config['overfit_test']) + str(size) +
            '.pickle', 'wb')
        pickle.dump(test_dataset, file)
        file.close()
        print("saved testdataset" + str(config['overfit_test']) + str(size) +
              '.pickle')
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config['test_batch_size'],
                                 shuffle=config['shuffle_data'],
                                 num_workers=config['test_num_workers'])
    test_dataloaders = {"full": test_dataloader}
    if config["testing_on_subsets"]:
        additional_pairing_strategies = [[SameLightColor()],
                                         [SameLightDirection()]]
        #[SameScene()],
        #[SameScene(), SameLightColor()],
        #[SameScene(), SameLightDirection()],
        #[SameLightDirection(), SameLightColor()],
        #[SameScene(), SameLightDirection(), SameLightColor()]]
        for pairing_strategies in additional_pairing_strategies:
            try:
                file = open(
                    "testdataset" + str(config['overfit_test']) + str(size) +
                    str(pairing_strategies) + '.pickle', 'rb')
                print("Restoring test dataset " + str(pairing_strategies) +
                      " from pickle file")
                test_dataset = pickle.load(file)
                file.close()
                print("Restored test dataset " + str(pairing_strategies) +
                      " from pickle file")
            except:
                test_dataset = InputTargetGroundtruthDataset(
                    transform=transforms.Resize(size),
                    data_path=VALIDATION_DATA_PATH,
                    pairing_strategies=pairing_strategies,
                    locations=["scene_city_24"]
                    if config['overfit_test'] else None,
                    input_directions=["S", "E"]
                    if config['overfit_test'] else None,
                    target_directions=["S", "E"]
                    if config['overfit_test'] else None,
                    input_colors=["2500", "6500"]
                    if config['overfit_test'] else None,
                    target_colors=["2500", "6500"]
                    if config['overfit_test'] else None)
                file = open(
                    "testdataset" + str(config['overfit_test']) + str(size) +
                    str(pairing_strategies) + '.pickle', 'wb')
                pickle.dump(test_dataset, file)
                file.close()
                print("saved testdataset" + str(config['overfit_test']) +
                      str(size) + str(pairing_strategies) + '.pickle')
            test_dataloader = DataLoader(
                test_dataset,
                batch_size=config['test_batch_size'],
                shuffle=config['shuffle_data'],
                num_workers=config['test_num_workers'])
            test_dataloaders[str(pairing_strategies)] = test_dataloader
    print(
        f'Dataset contains {len(train_dataset)} train samples and {len(test_dataset)} test samples.'
    )
    print(
        f'{config["shown_samples_grid"]} samples will be visualized every {config["testing_frequence"]} batches.'
    )
    print(
        f'Evaluation will be made every {config["testing_frequence"]} batches on {config["batches_for_testing"]} batches'
    )

    # Configure tensorboard
    writer = tensorboard.setup_summary_writer(config['name'])
    tensorboard_process = tensorboard.start_tensorboard_process(
    )  # TODO: config["tensorboard_port"]

    # Train loop

    # Init train scalars
    (train_generator_loss, train_discriminator_loss, train_score, train_lpips,
     train_ssim, train_psnr, train_scene_latent_loss_input_gt,
     train_scene_latent_loss_input_target, train_scene_latent_loss_gt_target,
     train_light_latent_loss_input_gt, train_light_latent_loss_input_target,
     train_light_latent_loss_gt_target, train_color_prediction_loss,
     train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                         0., 0., 0., 0., 0.)

    # Init train loop
    train_dataloader_iter = iter(train_dataloader)
    train_batches_counter = 0
    print(f'Running for {config["train_duration"]} batches.')

    last_save_t = 0
    # Train loop
    while train_batches_counter < config['train_duration']:
        # Store trained model
        t = time.time()
        if t - last_save_t > config["checkpoint_period"]:
            last_save_t = t
            save_trained(generator, "generator" + config['name'] + str(t))
            if config["use_illumination_predicter"]:
                save_trained(
                    illumination_predicter,
                    "illumination_predicter" + config['name'] + str(t))
            if config["use_discriminator"]:
                save_trained(discriminator,
                             "discriminator" + config['name'] + str(t))

        #with torch.autograd.detect_anomaly():
        # Load batch
        if config["debug"]: print('Load batch', get_gpu_memory_map())
        with torch.no_grad():
            train_batch, train_dataloader_iter = next_batch(
                train_dataloader_iter, train_dataloader)
            (input_image, target_image, groundtruth_image, input_color,
             target_color, groundtruth_color, input_direction,
             target_direction,
             groundtruth_direction) = extract_from_batch(train_batch, device)

        # Generator
        # Generator: Forward
        if config["debug"]:
            print('Generator: Forward', get_gpu_memory_map())
        output = generator(input_image, target_image, groundtruth_image)
        (relit_image, input_light_latent, target_light_latent,
         groundtruth_light_latent, input_scene_latent, target_scene_latent,
         groundtruth_scene_latent) = output
        r = reconstruction_loss(relit_image, groundtruth_image)
        generator_loss = config['generator_loss_reconstruction_l2_factor'] * r
        if config["use_illumination_predicter"]:
            input_illumination = illumination_predicter(input_light_latent)
            target_illumination = illumination_predicter(target_light_latent)
            groundtruth_illumination = illumination_predicter(
                groundtruth_light_latent)
            c = (1 / 3 *
                 color_prediction_loss(input_illumination[:, 0], input_color) +
                 1 / 3 *
                 color_prediction_loss(target_illumination[:, 0], target_color)
                 + 1 / 3 * color_prediction_loss(
                     groundtruth_illumination[:, 0], groundtruth_color))
            d = (1 / 3 * direction_prediction_loss(input_illumination[:, 1],
                                                   input_direction) +
                 1 / 3 * direction_prediction_loss(target_illumination[:, 1],
                                                   target_direction) +
                 1 / 3 * direction_prediction_loss(
                     groundtruth_illumination[:, 1], groundtruth_direction))
            generator_loss += config['generator_loss_color_l2_factor'] * c
            generator_loss += config['generator_loss_direction_l2_factor'] * d
            train_color_prediction_loss += c.item()
            train_direction_prediction_loss += d.item()
        train_generator_loss += generator_loss.item()
        train_score += reconstruction_loss(
            input_image, groundtruth_image).item() / reconstruction_loss(
                relit_image, groundtruth_image).item()
        if "scene_latent" in config["metrics"]:
            train_scene_latent_loss_input_gt += scene_latent_loss(
                input_image, groundtruth_image).item()
            train_scene_latent_loss_input_target += scene_latent_loss(
                input_image, target_image).item()
            train_scene_latent_loss_gt_target += scene_latent_loss(
                target_image, groundtruth_image).item()
        if "light_latent" in config["metrics"]:
            train_light_latent_loss_input_gt += light_latent_loss(
                input_image, groundtruth_image).item()
            train_light_latent_loss_input_target += light_latent_loss(
                input_image, target_image).item()
            train_light_latent_loss_gt_target += light_latent_loss(
                target_image, groundtruth_image).item()
        if "LPIPS" in config["metrics"]:
            train_lpips += lpips_loss(relit_image, groundtruth_image).item()
        if "SSIM" in config["metrics"]:
            train_ssim += ssim(relit_image, groundtruth_image).item()
        if "PSNR" in config["metrics"]:
            train_psnr += psnr(relit_image, groundtruth_image).item()

        # Generator: Backward
        if config["debug"]:
            print('Generator: Backward', get_gpu_memory_map())
        optimizerG.zero_grad()
        if config["use_discriminator"]:
            optimizerD.zero_grad()
        if config["use_discriminator"]:
            discriminator.zero_grad()
        generator_loss.backward(
        )  # use requires_grad = False for speed? Et pour enlever ces zero_grad en double!
        # Generator: Update parameters
        if config["debug"]:
            print('Generator: Update parameters', get_gpu_memory_map())
        optimizerG.step()

        # Discriminator
        if config["use_discriminator"]:
            if config["debug"]:
                print('Discriminator', get_gpu_memory_map())
            # Discriminator : Forward
            output = generator(input_image, target_image, groundtruth_image)
            (relit_image, input_light_latent, target_light_latent,
             groundtruth_light_latent, input_scene_latent, target_scene_latent,
             groundtruth_scene_latent) = output
            disc_out_fake = discriminator(relit_image)
            disc_out_real = discriminator(groundtruth_image)
            discriminator_loss = config[
                'discriminator_loss_gan_factor'] * gan_loss(
                    disc_out_fake, disc_out_real)
            train_discriminator_loss += discriminator_loss.item()
            # Discriminator : Backward
            optimizerD.zero_grad()
            discriminator.zero_grad()
            optimizerG.zero_grad()
            generator.zero_grad()
            discriminator_loss.backward()
            generator.zero_grad()
            optimizerG.zero_grad()
            # Discriminator : Update parameters
            optimizerD.step()

        # Update train_batches_counter
        train_batches_counter += 1

        # If it is time to do so, test and visualize current progress
        step, modulo = divmod(train_batches_counter,
                              config['testing_frequence'])
        if modulo == 0:
            with torch.no_grad():

                # Visualize train
                if config["debug"]:
                    print('Visualize train', get_gpu_memory_map())
                write_images(
                    writer=writer,
                    header="Train",
                    step=step,
                    inputs=input_image[:config['shown_samples_grid']],
                    input_light_latents=input_light_latent[:config[
                        'shown_samples_grid']],
                    targets=target_image[:config['shown_samples_grid']],
                    target_light_latents=target_light_latent[:config[
                        'shown_samples_grid']],
                    groundtruths=groundtruth_image[:config[
                        'shown_samples_grid']],
                    groundtruth_light_latents=groundtruth_light_latent[:config[
                        'shown_samples_grid']],
                    relits=relit_image[:config['shown_samples_grid']])
                write_measures(
                    writer=writer,
                    header="Train",
                    step=step,
                    generator_loss=train_generator_loss /
                    config['testing_frequence'],
                    discriminator_loss=train_discriminator_loss /
                    config['testing_frequence'],
                    score=train_score / config['testing_frequence'],
                    ssim=train_ssim / config['testing_frequence'],
                    lpips=train_lpips / config['testing_frequence'],
                    psnr=train_psnr / config['testing_frequence'],
                    scene_input_gt=train_scene_latent_loss_input_gt /
                    config['testing_frequence'],
                    scene_input_target=train_scene_latent_loss_input_target /
                    config['testing_frequence'],
                    scene_gt_target=train_scene_latent_loss_gt_target /
                    config['testing_frequence'],
                    light_input_gt=train_light_latent_loss_input_gt /
                    config['testing_frequence'],
                    light_input_target=train_light_latent_loss_input_target /
                    config['testing_frequence'],
                    light_gt_target=train_light_latent_loss_gt_target /
                    config['testing_frequence'],
                    color_prediction=train_color_prediction_loss /
                    config['testing_frequence'],
                    direction_prediction=train_direction_prediction_loss /
                    config['testing_frequence'])
                print('Train', 'Loss:',
                      train_generator_loss / config['testing_frequence'],
                      'Score:', train_score / config['testing_frequence'])
                if config["debug_memory"]:
                    print(get_gpu_memory_map())
                    # del generator_loss
                    # torch.cuda.empty_cache()
                    # print(get_gpu_memory_map())

                # Reset train scalars
                if config["debug"]:
                    print('Reset train scalars', get_gpu_memory_map())
                (train_generator_loss, train_discriminator_loss, train_score,
                 train_lpips, train_ssim, train_psnr,
                 train_scene_latent_loss_input_gt,
                 train_scene_latent_loss_input_target,
                 train_scene_latent_loss_gt_target,
                 train_light_latent_loss_input_gt,
                 train_light_latent_loss_input_target,
                 train_light_latent_loss_gt_target,
                 train_color_prediction_loss,
                 train_direction_prediction_loss) = (0., 0., 0., 0., 0., 0.,
                                                     0., 0., 0., 0., 0., 0.,
                                                     0., 0.)

                # Test loop

                if config["debug"]: print('Test loop', get_gpu_memory_map())
                for header, test_dataloader in test_dataloaders.items():

                    # Init test scalars
                    if config["debug"]:
                        print('Init test scalars', get_gpu_memory_map())
                    (test_generator_loss, test_discriminator_loss, test_score,
                     test_lpips, test_ssim, test_psnr,
                     test_scene_latent_loss_input_gt,
                     test_scene_latent_loss_input_target,
                     test_scene_latent_loss_gt_target,
                     test_light_latent_loss_input_gt,
                     test_light_latent_loss_input_target,
                     test_light_latent_loss_gt_target,
                     test_color_prediction_loss,
                     test_direction_prediction_loss) = (0., 0., 0., 0., 0., 0.,
                                                        0., 0., 0., 0., 0., 0.,
                                                        0., 0.)

                    # Init test loop
                    if config["debug"]:
                        print('Init test loop', get_gpu_memory_map())
                    test_dataloader_iter = iter(test_dataloader)
                    testing_batches_counter = 0

                    while testing_batches_counter < config[
                            'batches_for_testing']:

                        # Load batch
                        if config["debug"]:
                            print('Load batch', get_gpu_memory_map())
                        test_batch, test_dataloader_iter = next_batch(
                            test_dataloader_iter, test_dataloader)
                        (input_image, target_image, groundtruth_image,
                         input_color, target_color, groundtruth_color,
                         input_direction, target_direction,
                         groundtruth_direction) = extract_from_batch(
                             test_batch, device)

                        # Forward

                        # Generator
                        if config["debug"]:
                            print('Generator', get_gpu_memory_map())
                        output = generator(input_image, target_image,
                                           groundtruth_image)
                        (relit_image, input_light_latent, target_light_latent,
                         groundtruth_light_latent, input_scene_latent,
                         target_scene_latent,
                         groundtruth_scene_latent) = output
                        r = reconstruction_loss(relit_image, groundtruth_image)
                        generator_loss = config[
                            'generator_loss_reconstruction_l2_factor'] * r
                        if config["use_illumination_predicter"]:
                            input_illumination = illumination_predicter(
                                input_light_latent)
                            target_illumination = illumination_predicter(
                                target_light_latent)
                            groundtruth_illumination = illumination_predicter(
                                groundtruth_light_latent)
                            c = (1 / 3 * color_prediction_loss(
                                input_illumination[:, 0], input_color) +
                                 1 / 3 * color_prediction_loss(
                                     target_illumination[:, 0], target_color) +
                                 1 / 3 * color_prediction_loss(
                                     groundtruth_illumination[:, 0],
                                     groundtruth_color))
                            d = (1 / 3 * direction_prediction_loss(
                                input_illumination[:, 1], input_direction) +
                                 1 / 3 * direction_prediction_loss(
                                     target_illumination[:, 1],
                                     target_direction) +
                                 1 / 3 * direction_prediction_loss(
                                     groundtruth_illumination[:, 1],
                                     groundtruth_direction))
                            generator_loss += config[
                                'generator_loss_color_l2_factor'] * c
                            generator_loss += config[
                                'generator_loss_direction_l2_factor'] * d
                            test_color_prediction_loss += c.item()
                            test_direction_prediction_loss += d.item()
                        test_generator_loss += generator_loss.item()
                        test_score += reconstruction_loss(
                            input_image,
                            groundtruth_image).item() / reconstruction_loss(
                                relit_image, groundtruth_image).item()
                        if "scene_latent" in config["metrics"]:
                            test_scene_latent_loss_input_gt += scene_latent_loss(
                                input_image, groundtruth_image).item()
                            test_scene_latent_loss_input_target += scene_latent_loss(
                                input_image, target_image).item()
                            test_scene_latent_loss_gt_target += scene_latent_loss(
                                target_image, groundtruth_image).item()
                        if "light_latent" in config["metrics"]:
                            test_light_latent_loss_input_gt += light_latent_loss(
                                input_image, groundtruth_image).item()
                            test_light_latent_loss_input_target += light_latent_loss(
                                input_image, target_image).item()
                            test_light_latent_loss_gt_target += light_latent_loss(
                                target_image, groundtruth_image).item()
                        if "LPIPS" in config["metrics"]:
                            test_lpips += lpips_loss(relit_image,
                                                     groundtruth_image).item()
                        if "SSIM" in config["metrics"]:
                            test_ssim += ssim(relit_image,
                                              groundtruth_image).item()
                        if "PSNR" in config["metrics"]:
                            test_psnr += psnr(relit_image,
                                              groundtruth_image).item()

                        # Discriminator
                        if config["debug"]:
                            print('Discriminator', get_gpu_memory_map())
                        if config["use_discriminator"]:
                            disc_out_fake = discriminator(relit_image)
                            disc_out_real = discriminator(groundtruth_image)
                            discriminator_loss = config[
                                'discriminator_loss_gan_factor'] * gan_loss(
                                    disc_out_fake, disc_out_real)
                            test_discriminator_loss += discriminator_loss.item(
                            )

                        # Update testing_batches_counter
                        if config["debug"]:
                            print('Update testing_batches_counter',
                                  get_gpu_memory_map())
                        testing_batches_counter += 1

                    # Visualize test
                    if config["debug"]:
                        print('Visualize test', get_gpu_memory_map())
                    write_images(
                        writer=writer,
                        header="Test-" + header,
                        step=step,
                        inputs=input_image[:config['shown_samples_grid']],
                        input_light_latents=input_light_latent[:config[
                            'shown_samples_grid']],
                        targets=target_image[:config['shown_samples_grid']],
                        target_light_latents=target_light_latent[:config[
                            'shown_samples_grid']],
                        groundtruths=groundtruth_image[:config[
                            'shown_samples_grid']],
                        groundtruth_light_latents=
                        groundtruth_light_latent[:
                                                 config['shown_samples_grid']],
                        relits=relit_image[:config['shown_samples_grid']])
                    write_measures(
                        writer=writer,
                        header="Test-" + header,
                        step=step,
                        generator_loss=test_generator_loss /
                        config['batches_for_testing'],
                        discriminator_loss=test_discriminator_loss /
                        config['batches_for_testing'],
                        score=test_score / config['batches_for_testing'],
                        ssim=test_ssim / config['batches_for_testing'],
                        lpips=test_lpips / config['batches_for_testing'],
                        psnr=test_psnr / config['batches_for_testing'],
                        scene_input_gt=test_scene_latent_loss_input_gt /
                        config['batches_for_testing'],
                        scene_input_target=test_scene_latent_loss_input_target
                        / config['batches_for_testing'],
                        scene_gt_target=test_scene_latent_loss_gt_target /
                        config['batches_for_testing'],
                        light_input_gt=test_light_latent_loss_input_gt /
                        config['batches_for_testing'],
                        light_input_target=test_light_latent_loss_input_target
                        / config['batches_for_testing'],
                        light_gt_target=test_light_latent_loss_gt_target /
                        config['batches_for_testing'],
                        color_prediction=test_color_prediction_loss /
                        config['batches_for_testing'],
                        direction_prediction=test_direction_prediction_loss /
                        config['batches_for_testing'])
                    print('Test-' + header, 'Loss:',
                          test_generator_loss / config['testing_frequence'],
                          'Score:', test_score / config['testing_frequence'])

                    if config["debug_memory"]:
                        print(get_gpu_memory_map())
Example #8
0
args = parser.parse_args()

if __name__ == "__main__":
    # Compare two directories of images (predictions vs. ground-truth targets)
    # pairwise and report the mean PSNR and SSIM over the whole set.
    pred_dir = args.preds
    targets_dir = args.targets
    # os.listdir order is arbitrary and platform-dependent: sort both listings
    # so equality means "same set of filenames" and the later zip pairs each
    # prediction with its matching target.
    pred_names = sorted(os.listdir(pred_dir))
    targets_names = sorted(os.listdir(targets_dir))
    if pred_names != targets_names:
        print(
            """There are inconsistent filenames in the preds and targets directories
        The script assumes both sets have the same name for corresponding images
        Please remove any extra or erroneous files in either set, make sure that the same encoding is used in both sets of images, etc."""
        )
    elif not pred_names:
        # Guard against a ZeroDivisionError when both directories are empty.
        print("No images found in the preds and targets directories.")
    else:
        # Open each image from its own directory (joining with os.getcwd()
        # would only work by accident, and would read preds and targets from
        # the same place).
        preds = [
            Image.open(os.path.join(pred_dir, name)) for name in pred_names
        ]
        targets = [
            Image.open(os.path.join(targets_dir, name))
            for name in targets_names
        ]
        psnr_score = sum(
            psnr(pred, target)
            for pred, target in tqdm(zip(preds, targets))) / len(preds)
        print(f"PSNR score: {psnr_score}")
        ssim_score = sum(
            ssim(pred, target)
            for pred, target in tqdm(zip(preds, targets))) / len(preds)
        print(f"SSIM score: {ssim_score}")
Example #9
0
def metrics_eval(model, test_loader, logging_step, writer, args):
    """Evaluate SSIM/PSNR of model samples over the test set and log means.

    For ``args.modeltype == "flow"`` the model is sampled at four noise
    temperatures (eps = 0, 0.5, 0.8, 1); for ``"dlogistic"`` only the
    distribution means are scored. Per-batch scores from the project
    ``metrics`` module are accumulated and their means written to ``writer``
    at ``logging_step``. Returns the writer.
    """
    print(f"Metric evaluation on {args.testset}...")

    temperatures = (0, 0.5, 0.8, 1)
    # One list of per-batch scores per sampling temperature.
    ssim_scores = {eps: [] for eps in temperatures}
    psnr_scores = {eps: [] for eps in temperatures}

    model.eval()
    with torch.no_grad():
        for item in test_loader:
            y, x, orig_shape = item[0], item[1], item[2]
            w, h = orig_shape

            # Move the batch to the GPU.
            y = y.to("cuda")
            x = x.to("cuda")

            if args.modeltype == "flow":
                # Draw one sample per temperature, then score them all
                # (SSIM first, then PSNR, matching the logging order).
                samples = {
                    eps: model._sample(x=x, eps=eps) for eps in temperatures
                }
                for eps in temperatures:
                    ssim_scores[eps].append(
                        metrics.ssim(y, samples[eps], orig_shape))
                for eps in temperatures:
                    psnr_scores[eps].append(
                        metrics.psnr(y, samples[eps], orig_shape))

            elif args.modeltype == "dlogistic":
                # Score the distribution means; store them under eps=0.
                sample, means = model._sample(x=x)
                ssim_scores[0].append(metrics.ssim(y, means, orig_shape))
                psnr_scores[0].append(metrics.psnr(y, means, orig_shape))

                # ---------------------- Visualize Samples-------------
                if args.visual:
                    # Debug-only dumps, cropped back to the original size.
                    for image, filename in ((x, "x.png"), (y, "y.png"),
                                            (means, "dlog_mu.png"),
                                            (sample, "dlog_sample.png")):
                        torchvision.utils.save_image(image[:, :, :h, :w],
                                                     filename,
                                                     nrow=1,
                                                     padding=2,
                                                     normalize=False)

        # Mean scores at eps=0 are logged for both model types.
        writer.add_scalar("ssim_std0", np.mean(ssim_scores[0]), logging_step)
        writer.add_scalar("psnr0", np.mean(psnr_scores[0]), logging_step)

        if args.modeltype == "flow":
            for tag, eps in (("ssim_std05", 0.5), ("ssim_std08", 0.8),
                             ("ssim_std1", 1)):
                writer.add_scalar(tag, np.mean(ssim_scores[eps]),
                                  logging_step)
            for tag, eps in (("psnr05", 0.5), ("psnr08", 0.8), ("psnr1", 1)):
                writer.add_scalar(tag, np.mean(psnr_scores[eps]),
                                  logging_step)

        print("PSNR (GT,mean):", np.mean(psnr_scores[0]))
        print("SSIM (GT,mean):", np.mean(ssim_scores[0]))

        return writer