Beispiel #1
0
 def __init__(self, img_shape, SRscale):
     """Build the super-resolver, its discriminator and the combined model.

     Args:
         img_shape: Shape of the low-resolution input. It is indexed as
             ``(dim0, dim1, channels)`` when the target shape is derived.
         SRscale: Factor applied to the first two dimensions of
             ``img_shape`` to obtain the target shape.
     """
     self.img_shape = img_shape
     self.SRscale = SRscale
     # First two dimensions are scaled, the channel count is kept.
     self.target_shape = (
         SRscale * img_shape[0],
         SRscale * img_shape[1],
         img_shape[2],
     )

     # Configure data loader
     self.data_loader = DataLoader(
         img_res=(img_shape[0], img_shape[1]), SRscale=SRscale)

     # Instantiate the models
     self.SR = SR(self.img_shape)
     self.D = D2(self.target_shape)

     # Compile the discriminator on its own, before it gets frozen below.
     self.D.compile(loss='mse', loss_weights=[1], optimizer=Adam(0.0002))
     # summary() prints the table itself and returns None, so wrapping it
     # in print() only appended a stray "None" line.
     self.D.summary()

     # Create the generator graph: input -> SR -> D
     y = Input(shape=self.img_shape)
     fake_Y = self.SR(y)
     valid_Y = self.D(fake_Y)

     # Freeze the discriminator before the combined model is compiled.
     self.D.trainable = False
     # The combined model has two outputs (discriminator verdict and the
     # super-resolved image), each with an equally weighted 'mse' loss.
     self.Generator = Model(inputs=y, outputs=[valid_Y, fake_Y])
     self.Generator.compile(
         loss=['mse', 'mse'], loss_weights=[1, 1], optimizer=Adam(0.0002))
     self.Generator.summary()
Beispiel #2
0
    def enhance(self, img_path, model_name, reference=True):
        """Super-resolve one image with saved weights and optionally plot it.

        A fresh ``SR`` model is built for the exact shape of the loaded
        image, the weights ``models/<model_name>`` are loaded into it, and
        the prediction time plus the min/max of the output are printed.

        Args:
            img_path: Path handed to ``self.data_loader.load_img``.
            model_name: File name of the weights inside ``models/``.
            reference: When true, show a bicubic upscaling of the input
                next to the network output.
        """
        # load_img's result is indexed with [0] to obtain the single image.
        phone_image = self.data_loader.load_img(img_path)[0]
        # The model is built for this exact image size.
        img_shape = phone_image.shape

        self.model_name = model_name
        self.model = SR(scale=self.SRscale,
                        input_shape=img_shape,
                        n_feats=128,
                        n_resblocks=16,
                        name="Test Super-resolver")
        self.model.load_weights("models/%s" % (model_name))

        start = time.time()
        fake_dslr_image = self.model.predict(
            np.expand_dims(phone_image, axis=0))
        end = time.time()
        print("TIME TAKEN: ", end - start)
        print(np.amax(fake_dslr_image))
        print(np.amin(fake_dslr_image))

        if not reference:
            return

        _, axs = plt.subplots(1, 2)

        # cv.resize expects dsize as (width, height), whereas an array shape
        # is (rows, cols, ...); the two axes therefore have to be swapped,
        # otherwise non-square images come out with transposed dimensions.
        upscaled_size = (img_shape[1] * self.SRscale,
                         img_shape[0] * self.SRscale)
        bi_cubic_upscaling = cv.resize(
            phone_image, upscaled_size, interpolation=cv.INTER_CUBIC)
        # Both images are clipped to [0, 1] before being displayed.
        bi_cubic_upscaling = np.clip(bi_cubic_upscaling, 0, 1)

        # Titles report the configured scale instead of a hard-coded "2x".
        axs[0].imshow(bi_cubic_upscaling)
        axs[0].set_title("%dx Bi-cubic upscaling" % self.SRscale)

        axs[1].imshow(np.clip(fake_dslr_image[0], 0, 1))
        axs[1].set_title("%dx SR upscaling" % self.SRscale)
        plt.show()
Beispiel #3
0
class EDSR():
    """Adversarially trained super-resolver (``SR``) with discriminator ``D2``."""

    def __init__(self, img_shape, SRscale):
        """Build the super-resolver, its discriminator and the combined model.

        Args:
            img_shape: Shape of the low-resolution input. It is indexed as
                ``(dim0, dim1, channels)`` when the target shape is derived.
            SRscale: Factor applied to the first two dimensions of
                ``img_shape`` to obtain the target shape.
        """
        self.img_shape = img_shape
        self.SRscale = SRscale
        # First two dimensions are scaled, the channel count is kept.
        self.target_shape = (
            SRscale * img_shape[0],
            SRscale * img_shape[1],
            img_shape[2],
        )

        # Configure data loader
        self.data_loader = DataLoader(
            img_res=(img_shape[0], img_shape[1]), SRscale=SRscale)

        # Instantiate the models
        self.SR = SR(self.img_shape)
        self.D = D2(self.target_shape)

        # Compile the discriminator on its own, before it gets frozen below.
        self.D.compile(loss='mse', loss_weights=[1], optimizer=Adam(0.0002))
        # summary() prints the table itself and returns None, so wrapping it
        # in print() only appended a stray "None" line.
        self.D.summary()

        # Create the generator graph: input -> SR -> D
        y = Input(shape=self.img_shape)
        fake_Y = self.SR(y)
        valid_Y = self.D(fake_Y)

        # Freeze the discriminator before the combined model is compiled.
        self.D.trainable = False
        # Two outputs (discriminator verdict, super-resolved image), each
        # with an equally weighted 'mse' loss.
        self.Generator = Model(inputs=y, outputs=[valid_Y, fake_Y])
        self.Generator.compile(
            loss=['mse', 'mse'], loss_weights=[1, 1], optimizer=Adam(0.0002))
        self.Generator.summary()

    def train(self, epochs, batch_size=10):
        """Alternate discriminator and generator updates over the data set.

        Every 25th batch (batch 0 excluded) the super-resolver is saved
        under ``pretrained models/<epoch>_<batch>.h5`` and the evaluator's
        perceptual and objective tests are run.

        Args:
            epochs: Number of passes over ``self.data_loader.load_batch``.
            batch_size: Batch size requested from the data loader; it also
                fixes the leading dimension of the adversarial targets.
        """
        import os  # local import: only needed for the checkpoint directory

        # Adversarial ground truths.
        # NOTE(review): assumes D outputs a map of shape
        # target_shape[:2] + (1,) - confirm against the D2 architecture.
        out_shape = (batch_size, self.target_shape[0], self.target_shape[1], 1)
        valid_D = np.ones(out_shape)
        fake_D = np.zeros(out_shape)

        # Evaluator object used to monitor the progress of the training
        dynamic_evaluator = evaluator(
            img_res=self.img_shape, SRscale=self.SRscale)

        for epoch in range(epochs):
            batches = self.data_loader.load_batch(batch_size)
            for batch, (_, img_y, img_Y) in enumerate(batches):
                # Translate domain y to domain Y
                fake_img_Y = self.SR.predict(img_y)

                # Train the discriminator: img_Y is real, SR output is fake
                D_loss_real = self.D.train_on_batch(img_Y, valid_D)
                D_loss_fake = self.D.train_on_batch(fake_img_Y, fake_D)
                D_loss = 0.5 * np.add(D_loss_real, D_loss_fake)

                # Train the generator: D's verdict is pushed towards the
                # "valid" target and the SR output towards img_Y.
                G_loss = self.Generator.train_on_batch(
                    [img_y], [valid_D, img_Y])

                # G_loss[1] and G_loss[2] are the two per-output losses.
                print("[Epoch %d/%d] [Batch %d/%d]--[D: %.3f] -- [G_adv: %.3f] [G_rec: %.3f]" % (epoch, epochs,
                      batch, self.data_loader.n_batches, D_loss, G_loss[1], G_loss[2]))

                if batch == 0 or batch % 25 != 0:
                    continue

                # Save the model; create the directory first so the save
                # does not fail on a fresh checkout.
                os.makedirs("pretrained models", exist_ok=True)
                model_name = "{}_{}.h5".format(epoch, batch)
                self.SR.save("pretrained models/" + model_name)
                print("Epoch: {} --- Batch: {} ---- saved".format(epoch, batch))

                dynamic_evaluator.model = self.SR
                dynamic_evaluator.epoch = epoch
                dynamic_evaluator.batch = batch
                dynamic_evaluator.perceptual_test(5)

                sample_mean_ssim = dynamic_evaluator.objective_test(batch_size=250)
                print("Sample mean SSIM: -------------------  %05f   -------------------" % (sample_mean_ssim))
Beispiel #4
0
    def __init__(self, img_shape, SRscale=2):
        """Instantiate all sub-networks, compile them and reset the loss logs.

        Training takes place in two steps: first the combined models, which
        contain all the generator mappings, are updated; then the
        discriminators, which are compiled separately, are updated.

        Args:
            img_shape: Shape of the low-resolution input. It is indexed as
                ``(dim0, dim1, channels)`` when the target shape is derived.
            SRscale: Factor applied to the first two dimensions of
                ``img_shape`` to obtain ``self.target_res``.
        """
        # Input shape
        self.SRscale = SRscale
        self.img_shape = img_shape
        self.target_res = (
            self.SRscale * self.img_shape[0],
            self.SRscale * self.img_shape[1],
            self.img_shape[2],
        )

        # Configure data loader
        self.data_loader = DataLoader(
            img_res=(img_shape[0], img_shape[1]), SRscale=SRscale)

        # Sub-networks; the SR network starts from previously saved weights.
        self.G1 = G1(img_shape)
        self.D1 = D1(img_shape)
        self.G2 = G2(img_shape)
        self.blur = blur(img_shape)
        self.SR = SR(img_shape)
        self.SR.load_weights('pretrained models/128_8.h5')
        self.D2 = D2(self.target_res)
        self.G3 = G3(self.target_res)

        # Compile the discriminators. Each one gets its own optimizer
        # instance: an optimizer keeps per-variable state, and sharing one
        # instance between two models is rejected by newer Keras versions.
        self.D1.compile(loss='mse', loss_weights=[1],
                        optimizer=Adam(0.0002), metrics=['accuracy'])
        self.D2.compile(loss='mse', loss_weights=[1],
                        optimizer=Adam(0.0002), metrics=['accuracy'])

        # -------------------------
        # Construct Computational
        #   Graph of Generators
        #    (combined model)
        # -------------------------
        # summary() prints the table itself and returns None, so it is
        # called directly instead of being wrapped in print().
        self.cyclic1 = self.combined_model(cycle=1)
        self.cyclic1.summary()
        self.cyclic2 = self.combined_model(cycle=2)
        self.cyclic2.summary()

        # Logger settings: one list per tracked quantity.
        self.training = []

        self.D1_loss = []
        self.D2_loss = []
        self.G1_adv = []
        self.SR_adv = []

        self.cyc1 = []
        self.blur1 = []
        self.tv1 = []

        self.cyc2 = []
        self.tv2 = []

        self.ssim_eval_time = []
        self.ssim = []
Beispiel #5
0
class CINGAN():
    """Two chained cycles: a denoising cycle (G1, G2, blur, D1) and a
    super-resolution cycle (SR, G3, D2), trained adversarially."""

    def __init__(self, img_shape, SRscale=2):
        """Instantiate all sub-networks, compile them and reset the loss logs.

        Training takes place in two steps: first the combined models, which
        contain all the generator mappings, are updated; then the
        discriminators, which are compiled separately, are updated.

        Args:
            img_shape: Shape of the low-resolution input. It is indexed as
                ``(dim0, dim1, channels)`` when the target shape is derived.
            SRscale: Factor applied to the first two dimensions of
                ``img_shape`` to obtain ``self.target_res``.
        """
        # Input shape
        self.SRscale = SRscale
        self.img_shape = img_shape
        self.target_res = (
            self.SRscale * self.img_shape[0],
            self.SRscale * self.img_shape[1],
            self.img_shape[2],
        )

        # Configure data loader
        self.data_loader = DataLoader(
            img_res=(img_shape[0], img_shape[1]), SRscale=SRscale)

        # Sub-networks; the SR network starts from previously saved weights.
        self.G1 = G1(img_shape)
        self.D1 = D1(img_shape)
        self.G2 = G2(img_shape)
        self.blur = blur(img_shape)
        self.SR = SR(img_shape)
        self.SR.load_weights('pretrained models/128_8.h5')
        self.D2 = D2(self.target_res)
        self.G3 = G3(self.target_res)

        # Compile the discriminators. Each one gets its own optimizer
        # instance: an optimizer keeps per-variable state, and sharing one
        # instance between two models is rejected by newer Keras versions.
        self.D1.compile(loss='mse', loss_weights=[1],
                        optimizer=Adam(0.0002), metrics=['accuracy'])
        self.D2.compile(loss='mse', loss_weights=[1],
                        optimizer=Adam(0.0002), metrics=['accuracy'])

        # -------------------------
        # Construct Computational
        #   Graph of Generators
        #    (combined model)
        # -------------------------
        # summary() prints the table itself and returns None, so it is
        # called directly instead of being wrapped in print().
        self.cyclic1 = self.combined_model(cycle=1)
        self.cyclic1.summary()
        self.cyclic2 = self.combined_model(cycle=2)
        self.cyclic2.summary()

        # Logger settings: one list per tracked quantity.
        self.training = []

        self.D1_loss = []
        self.D2_loss = []
        self.G1_adv = []
        self.SR_adv = []

        self.cyc1 = []
        self.blur1 = []
        self.tv1 = []

        self.cyc2 = []
        self.tv2 = []

        self.ssim_eval_time = []
        self.ssim = []

    def combined_model(self, cycle):
        """Build and compile the combined generator model of one cycle.

        Args:
            cycle: ``1`` for the denoising cycle (x -> G1 -> D1 / G2 / blur),
                ``2`` for the super-resolution cycle (y -> SR -> D2 / G3).

        Returns:
            The compiled ``Model``.

        Raises:
            ValueError: If ``cycle`` is neither 1 nor 2 (previously ``None``
                was returned silently and the caller failed later).
        """
        if cycle == 1:
            x = Input(shape=self.img_shape)

            # -------------denoising network---------------
            # Translate images to the other domain
            fake_y = self.G1(x)  # L_tv(1)

            # Cycle-consistent image
            cyc_x = self.G2(fake_y)  # L_cyc(1)

            # Pass fake_y through the blurring kernel
            blur_fake_y = self.blur(fake_y)  # L_blur(1)

            # Pass fake_y through discriminator D1
            valid_y = self.D1(fake_y)  # L_GAN(1)

            # Freeze the discriminator before this model is compiled
            self.D1.trainable = False

            # Denoising network parameters
            w1 = 10  # relative importance of cycle consistency
            w2 = 20  # relative importance of conservation of color distribution
            w3 = 1   # relative importance of total variation

            model = Model(inputs=[x],
                          outputs=[valid_y, cyc_x, blur_fake_y, fake_y],
                          name="Cyclic 1")
            model.compile(loss=['mse', 'mse', 'mse', total_variation],
                          loss_weights=[1, w1, w2, w3],
                          optimizer=Adam(0.0002))
            return model

        if cycle == 2:
            fake_y = Input(self.img_shape)

            # ------------------SR network----------------------
            SR_fake_Y = self.SR(fake_y)  # L_tv(2)
            valid_Y = self.D2(SR_fake_Y)  # L_GAN(2)
            cyc_x_2 = self.G3(SR_fake_Y)  # L_cyc(2)

            # SR network parameters
            l1 = 10  # relative importance of cycle consistency
            l2 = 1   # relative importance of total variation

            # Freeze the discriminator before this model is compiled
            self.D2.trainable = False

            model = Model(inputs=[fake_y],
                          outputs=[valid_Y, cyc_x_2, SR_fake_Y],
                          name="cyclic 2")
            # This cycle uses a ten times smaller learning rate than cycle 1.
            model.compile(loss=['mse', 'mse', total_variation],
                          loss_weights=[1, l1, l2],
                          optimizer=Adam(0.1 * 0.0002))
            return model

        raise ValueError("cycle must be 1 or 2, got %r" % (cycle,))

    def log(self):
        """Plot the logged losses and the SSIM history into ``progress/``.

        Writes ``progress/log.png`` (a 2x3 grid of loss curves, of which
        five panels are used) and ``progress/ssim_evolution.png``, then
        closes all figures.
        """
        import os  # local import: only needed for the output directory

        # savefig does not create missing directories.
        os.makedirs("progress", exist_ok=True)

        fig, axs = plt.subplots(2, 3)

        ax = axs[0, 0]
        ax.plot(self.training, self.D1_loss, label="D1 adv loss")
        ax.plot(self.training, self.G1_adv, label="G1 adv loss")
        ax.legend()
        ax.set_title("Adv losses (1st)")

        ax = axs[1, 0]
        ax.plot(self.training, self.D2_loss, label="D2 adv loss")
        ax.plot(self.training, self.SR_adv, label="SR adv loss")
        ax.legend()
        ax.set_title("Adv losses (2nd)")

        ax = axs[0, 1]
        ax.plot(self.training, self.cyc1, label="cyc1")
        ax.plot(self.training, self.cyc2, label="cyc2")
        ax.legend()
        ax.set_title("cyclic losses")

        ax = axs[1, 1]
        ax.plot(self.training, self.tv1, label="TV 1")
        ax.plot(self.training, self.tv2, label="TV 2")
        # The two curves are labelled, so show the legend like every other
        # panel does (it was missing here).
        ax.legend()
        ax.set_title("Total Variation losses")

        ax = axs[0, 2]
        ax.plot(self.training, self.blur1, label="blur loss 1")
        ax.legend()
        ax.set_title("Blur loss 1")
        fig.savefig("progress/log.png")

        fig, ax = plt.subplots(1, 1)
        ax.plot(self.ssim_eval_time, self.ssim)
        ax.set_title("SSIM evolution")
        fig.savefig("progress/ssim_evolution.png")

        plt.close("all")

    def train(self, epochs, batch_size=10, sample_interval=50):
        """Alternate discriminator and generator updates and log the losses.

        Every ``sample_interval`` batches (the very first batch of the run
        excluded) the SR model is saved under ``models/<epoch>_<batch>.h5``,
        the evaluator's perceptual and objective tests are run and the
        progress plots are refreshed. The interval was previously hard-coded
        to 20 and this parameter was ignored.

        Args:
            epochs: Number of passes over ``self.data_loader.load_batch``.
            batch_size: Batch size requested from the data loader; it also
                fixes the leading dimension of the adversarial targets.
            sample_interval: Number of batches between two checkpoints.

        Raises:
            ValueError: If ``sample_interval`` is not positive.
        """
        import os  # local import: only needed for the checkpoint directory

        if sample_interval <= 0:
            raise ValueError("sample_interval must be a positive integer")

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths for the discriminators.
        # NOTE(review): assumes D1 / D2 output maps with the same first two
        # dimensions as their inputs and a single channel - confirm against
        # the discriminator architectures.
        self.D1_out_shape = (self.img_shape[0], self.img_shape[1], 1)
        valid_D1 = np.ones((batch_size,) + self.D1_out_shape)
        fake_D1 = np.zeros((batch_size,) + self.D1_out_shape)
        print("valid D1: ", valid_D1.shape)

        # Similarly for D2, based on the target resolution
        self.D2_out_shape = (self.target_res[0], self.target_res[1], 1)
        valid_D2 = np.ones((batch_size,) + self.D2_out_shape)
        fake_D2 = np.zeros((batch_size,) + self.D2_out_shape)

        # Evaluator object used to monitor the progress of the training
        dynamic_evaluator = evaluator(
            img_res=self.img_shape, SRscale=self.SRscale)

        for epoch in range(epochs):
            batches = self.data_loader.load_batch(batch_size)
            for batch, (img_x, img_y, img_Y) in enumerate(batches):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Generator translations needed for discriminator training
                fake_y = self.G1.predict(img_x)   # x -> denoised x
                fake_Y = self.SR.predict(fake_y)  # denoised x -> super-resolved x

                # Original images = real / translated images = fake
                D1_loss_real = self.D1.train_on_batch(img_y, valid_D1)
                D1_loss_fake = self.D1.train_on_batch(fake_y, fake_D1)
                D1_loss = 0.5 * np.add(D1_loss_real, D1_loss_fake)

                D2_loss_real = self.D2.train_on_batch(img_Y, valid_D2)
                D2_loss_fake = self.D2.train_on_batch(fake_Y, fake_D2)
                D2_loss = 0.5 * np.add(D2_loss_real, D2_loss_fake)

                # ------------------
                #  Train Generators
                # ------------------

                # Pass img_x through the blurring kernel to provide the
                # ground truth of the blur output.
                blur_img_x = self.blur.predict(img_x)

                # Train the combined models. The last target of each call is
                # paired with the total_variation output.
                # NOTE(review): whether total_variation reads its target is
                # not visible here - confirm.
                cyclic_loss_1 = self.cyclic1.train_on_batch(
                    [img_x], [valid_D1, img_x, blur_img_x, fake_y])
                cyclic_loss_2 = self.cyclic2.train_on_batch(
                    [fake_y], [valid_D2, img_x, fake_Y])

                # Update the log values.
                # Training point, measured in (fractional) epochs
                progress = round(epoch + batch / self.data_loader.n_batches, 3)
                self.training.append(progress)

                # Adversarial losses (index 0 of a discriminator result is
                # its loss; index 1 is the 'accuracy' metric)
                self.D1_loss.append(D1_loss[0])
                self.D2_loss.append(D2_loss[0])
                self.G1_adv.append(cyclic_loss_1[1])
                self.SR_adv.append(cyclic_loss_2[1])

                # 1st cycle losses
                self.cyc1.append(cyclic_loss_1[2])
                self.blur1.append(cyclic_loss_1[3])
                self.tv1.append(cyclic_loss_1[4])

                # 2nd cycle losses
                self.cyc2.append(cyclic_loss_2[2])
                self.tv2.append(cyclic_loss_2[3])

                # Drop the microseconds so the elapsed time prints compactly
                elapsed_time = datetime.datetime.now() - start_time
                elapsed_time -= datetime.timedelta(
                    microseconds=elapsed_time.microseconds)
                print("[elapsed time: %s][Epoch %d/%d] [Batch %d/%d] -- [D1_adv: %.3f D2_adv: %.3f] -- [G1_adv: %.3f - SR_adv: %.3f - cyc1: %.4f - cyc2: %.4f]" % (elapsed_time, epoch, epochs,
                      batch, self.data_loader.n_batches, D1_loss[0], D2_loss[0], cyclic_loss_1[1], cyclic_loss_2[1], cyclic_loss_1[2], cyclic_loss_2[2]))

                first_batch_of_run = batch == 0 and epoch == 0
                if batch % sample_interval != 0 or first_batch_of_run:
                    continue

                # Save the model; create the directory first so the save
                # does not fail on a fresh checkout.
                os.makedirs("models", exist_ok=True)
                model_name = "{}_{}.h5".format(epoch, batch)
                self.SR.save("models/" + model_name)
                print("Epoch: {} --- Batch: {} ---- saved".format(epoch, batch))

                dynamic_evaluator.model = [self.G1, self.SR]
                dynamic_evaluator.epoch = epoch
                dynamic_evaluator.batch = batch
                dynamic_evaluator.perceptual_test(5)

                sample_mean_ssim = dynamic_evaluator.objective_test(batch_size=250)
                print("Sample mean SSIM: -------------------  %05f   -------------------" % (sample_mean_ssim))
                self.ssim_eval_time.append(progress)
                self.ssim.append(sample_mean_ssim)

                self.log()
Beispiel #6
0
import numpy as np
from model_architectures import SR
import matplotlib.pyplot as plt
from preprocessing import NormalizeData
import cv2 as cv
#this file is used to test the performance of a saved generator model

# Directory holding the test images.
# NOTE(review): this is a raw string, so every doubled backslash stays as two
# backslash characters in the resulting path.
main_path = r"C:\\Users\\Giorgos\\Documents\\data\\dped\\iphone\\full_size_test_images"
# NOTE(review): `DataLoader` is not among the imports visible at the top of
# this script - confirm where it is supposed to come from.
data_loader = DataLoader(test_data_dir = main_path)

# Load a batch of ten 128x128 test patches from domain "A".
imgs=data_loader.load_data(domain="A", batch_size=10, patch_dimension = (128,128), is_testing=True)
print(imgs.shape)
# NOTE(review): `K` is not among the visible imports either, and
# `imgs_tensor` is not used in the visible part of the script.
imgs_tensor = K.variable(imgs)

# Build the generator for the shape of a single patch, then load the saved
# weights into it.
generator = SR(scale = 2, input_shape = imgs[0].shape, n_feats=256, n_resblocks=8, name = "Test_Generator")

model_path="C:\\Users\\Giorgos\\Documents\\Github\\USISR saved models\\2\\6_800.h5"
generator.load_weights(model_path)



# Run the generator on every patch, one at a time.
for i in range(imgs.shape[0]):
    # Add a leading batch axis so that a single patch can be predicted.
    image=np.expand_dims(imgs[i,:,:,:], axis=0)
    fake_B_image = generator.predict(image)
    # Normalize the first (only) prediction, then restore the batch axis.
    fake_B_image = NormalizeData(fake_B_image[0])
    fake_B_image=np.expand_dims(fake_B_image, axis=0)
    print(fake_B_image.shape)
    #plt.figure()
    # Normalized copy of the input patch (presumably for a side-by-side
    # comparison further down - the rest of the loop is not visible here).
    original = NormalizeData(imgs[i,:,:,:])