Example #1
0
    def test(self,
             refer_model=None,
             batch_size=4,
             datapath_test='./images/val_dir',
             crops_per_image=1,
             log_test_path="./images/test/"):
        """Plot the test images using the current model; no training is done.

        A ``DataLoader`` is built over ``datapath_test`` and handed to
        ``plot_test_images`` together with a fixed epoch index of 0.

        :param refer_model: forwarded unchanged to ``plot_test_images``
        :param int batch_size: batch size given to the data loader
        :param str datapath_test: path for the image files to load and plot
        :param int crops_per_image: crops per image given to the data loader
        :param str log_test_path: log path forwarded to ``plot_test_images``
        """
        # Loader over the test images, sized from the model's HR geometry
        test_loader = DataLoader(
            datapath_test,
            batch_size,
            self.height_hr,
            self.width_hr,
            self.upscaling_factor,
            crops_per_image,
        )

        print(">> Ploting test images")

        # There is no training loop here, so the epoch index is pinned to 0
        epoch_index = 0
        plot_test_images(
            self, test_loader, datapath_test, log_test_path, epoch_index,
            refer_model=refer_model,
        )
Example #2
0
    def train_srgan(self,
        epochs, batch_size,
        dataname,
        datapath_train,
        datapath_validation=None,
        steps_per_validation=10,
        datapath_test=None,
        workers=40, max_queue_size=100,
        first_epoch=0,
        print_frequency=2,
        crops_per_image=2,
        log_weight_frequency=1000,
        log_weight_path='./data/weights/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_name='ESRGAN',
        log_tensorboard_update_freq=500,
        log_test_frequency=500,
        log_test_path="./images/samples/",
        ):
        """Train the ESRGAN network.

        Each "epoch" is a single update iteration: one batch is pulled from
        the training enqueuer, the generator predicts on it, the relativistic
        discriminator (``self.RaGAN``) is updated, and then the combined model
        (``self.srgan``) is updated on the same batch.

        :param int epochs: how many epochs (update iterations) to train the network for
        :param int batch_size: batch size given to the data loaders
        :param str dataname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for validation images; a loader is built when given, but validation inference is disabled in this method
        :param int steps_per_validation: currently unused (validation inference is disabled)
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: number of workers used by the batch enqueuer
        :param int max_queue_size: maximum queue size of the batch enqueuer
        :param int first_epoch: index of the first epoch (the loop runs ``epochs`` iterations from here)
        :param int print_frequency: how often (in epochs) to print averaged losses to terminal. None/0 for never
        :param int crops_per_image: crops per image given to the data loaders
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: currently unused (TensorBoard logging is disabled in this method)
        :param str log_tensorboard_name: currently unused (TensorBoard logging is disabled in this method)
        :param int log_tensorboard_update_freq: currently unused (TensorBoard logging is disabled in this method)
        :param int log_test_frequency: how often (in epochs) should test images be plotted. None/0 for never
        :param str log_test_path: where should test results be saved
        """

        # Create train data loader
        loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image
        )

        # Validation data loader. Validation inference is disabled in this
        # method, so the loader is built but not consumed below.
        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image
            )
        print("Picture Loaders has been ready.")

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            loader,
            use_multiprocessing=False,
            shuffle=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)

        # NOTE: TensorBoard logging and the plain SRGAN real/fake discriminator
        # targets are intentionally not used here; only the relativistic
        # discriminator losses are trained.
        try:
            output_generator = enqueuer.get()
            print("Data Enqueuer has been ready.")

            # Each epoch == "update iteration" as defined in the paper
            print_losses = {"G": [], "D": []}
            window_start = datetime.datetime.now()

            # Loop through epochs / iterations
            for epoch in range(first_epoch, epochs + first_epoch):
                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)

                # Train Relativistic Discriminator
                discriminator_loss = self.RaGAN.train_on_batch([imgs_hr, generated_hr], None)

                # Train generator
                generator_loss = self.srgan.train_on_batch([imgs_lr, imgs_hr], None)

                # Save losses
                print_losses['G'].append(generator_loss)
                print_losses['D'].append(discriminator_loss)

                # Show the progress
                if print_frequency and epoch % print_frequency == 0:
                    # atleast_1d: train_on_batch may return a bare scalar, which
                    # would otherwise not be iterable in the zip() calls below.
                    g_avg_loss = np.atleast_1d(np.array(print_losses['G']).mean(axis=0))
                    d_avg_loss = np.atleast_1d(np.array(print_losses['D']).mean(axis=0))
                    print(self.srgan.metrics_names, g_avg_loss)
                    print(self.RaGAN.metrics_names, d_avg_loss)
                    print("\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}".format(
                        epoch, epochs + first_epoch,
                        (datetime.datetime.now() - window_start).seconds,
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.srgan.metrics_names, g_avg_loss)]),
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.RaGAN.metrics_names, d_avg_loss)])
                    ))
                    print_losses = {"G": [], "D": []}
                    # Restart the timer here so every report covers exactly the
                    # iterations since the previous report, for any frequency.
                    window_start = datetime.datetime.now()

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and log_test_frequency and epoch % log_test_frequency == 0:
                    print(">> Ploting test images")
                    plot_test_images(self, loader, datapath_test, log_test_path, epoch, refer_model=self.refer_model)

                # Check if we should save the network weights
                if log_weight_frequency and epoch % log_weight_frequency == 0:
                    print(">> Saving the network weights")
                    self.save_weights(os.path.join(log_weight_path, dataname), epoch)
        finally:
            # Always release the enqueuer's workers, even if training fails
            enqueuer.stop()
Example #3
0
    def train_generator(self,
        epochs, batch_size,
        workers=1,
        dataname='doctor',
        datapath_train='./images/train_dir',
        datapath_validation='./images/val_dir',
        datapath_test='./images/val_dir',
        steps_per_epoch=1000,
        steps_per_validation=1000,
        crops_per_image=2,
        log_weight_path='./data/weights/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_name='SR-RRDB-D',
        log_tensorboard_update_freq=1,
        log_test_path="./images/samples-d/"
        ):
        """Train the generator alone, in rounds of 10 epochs.

        Training runs ``epochs // 10`` rounds. Before every round the generator
        is recompiled via ``self.compile_generator``; after every round the full
        model is saved and ``self.gen_lr`` is divided by 1.149. Epochs beyond
        the last full multiple of 10 are not run.

        :param int epochs: total number of epochs; only full rounds of 10 are trained
        :param int batch_size: batch size given to the data loaders
        :param int workers: workers for ``fit_generator``; multiprocessing is enabled when > 1
        :param str dataname: name prefix for the best-weights checkpoint file
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for the validation images; None disables the validation loader
        :param str datapath_test: path passed to ``plot_test_images``; None disables test plotting
        :param int steps_per_epoch: training steps per epoch
        :param int steps_per_validation: validation steps per epoch
        :param int crops_per_image: crops per image given to the data loaders
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent; falsy disables TensorBoard
        :param str log_tensorboard_name: what folder should tf logs be saved under
        :param int log_tensorboard_update_freq: ``update_freq`` passed to TensorBoard
        :param str log_test_path: where should test results be saved
        """

        # Create data loaders
        train_loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image
        )
        # NOTE: this loader is built from datapath_validation and is used both
        # as validation data and as the loader handed to plot_test_images.
        test_loader = None
        if datapath_validation is not None:
            test_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image
            )

        epochs_per_round = 10

        # Callback: save the best weights seen so far. PSNR is better when
        # higher, so the mode must be 'max' (the automatic mode would treat it
        # like a loss and keep the lowest value). Created once so that the best
        # score is remembered across rounds instead of being reset every round.
        modelcheckpoint = ModelCheckpoint(
            os.path.join(log_weight_path, dataname + '_{}X.h5'.format(self.upscaling_factor)),
            monitor='PSNR',
            mode='max',
            save_best_only=True,
            save_weights_only=True
        )

        self.gen_lr = 3.2e-5
        for step in range(epochs // epochs_per_round):
            # Recompile every round (self.gen_lr has been updated at the end of
            # the previous round)
            self.compile_generator(self.generator)

            # Callback: tensorboard
            callbacks = []
            if log_tensorboard_path:
                tensorboard = TensorBoard(
                    log_dir=os.path.join(log_tensorboard_path, log_tensorboard_name),
                    histogram_freq=0,
                    batch_size=batch_size,
                    write_graph=False,
                    write_grads=False,
                    update_freq=log_tensorboard_update_freq
                )
                callbacks.append(tensorboard)
            else:
                print(">> Not logging to tensorboard since no log_tensorboard_path is set")

            callbacks.append(modelcheckpoint)

            # Callback: test images plotting
            if datapath_test is not None:
                # Bind this round's epoch offset as a default argument so the
                # callback does not depend on the loop variable's later value.
                testplotting = LambdaCallback(
                    on_epoch_end=lambda epoch, logs, offset=step * epochs_per_round: plot_test_images(
                        self,
                        test_loader,
                        datapath_test,
                        log_test_path,
                        epoch + offset,
                        name='RRDB-D'
                    )
                )
                callbacks.append(testplotting)

            # Fit the model
            self.generator.fit_generator(
                train_loader,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs_per_round,
                validation_data=test_loader,
                validation_steps=steps_per_validation,
                callbacks=callbacks,
                use_multiprocessing=workers > 1,
                workers=workers
            )

            # Snapshot the full model after each round, in the configured
            # weight directory (same location as before for the default path)
            epochs_done = step * epochs_per_round + epochs_per_round
            self.generator.save(
                os.path.join(log_weight_path, 'Doctor_gan(Step %dK).h5' % epochs_done)
            )
            self.gen_lr /= 1.149
            print(step, self.gen_lr)
Example #4
0
    def train_esrgan(self,
        epochs=None, batch_size=16,
        modelname=None,
        datapath_train=None,
        datapath_validation=None,
        steps_per_validation=1000,
        datapath_test=None,
        workers=4, max_queue_size=10,
        first_epoch=0,
        print_frequency=1,
        crops_per_image=2,
        log_weight_frequency=None,
        log_weight_path='./model/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_update_freq=10,
        log_test_frequency=500,
        log_test_path="./images/samples/",
        media_type='i'
    ):
        """Train the ESRGAN network

        Each "epoch" is one update iteration: one batch is used to update the
        relativistic discriminator, and a second batch is used to update the
        combined model (``self.esrgan``) with the discriminators frozen.

        :param int epochs: how many epochs to train the network for
        :param int batch_size: batch size for the train / validation loaders
        :param str modelname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for the image files to use for validation. None for no validation
        :param int steps_per_validation: number of steps for each validation run
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: number of workers used by the batch enqueuer and validation
        :param int max_queue_size: maximum queue size of the batch enqueuer
        :param int first_epoch: index of the first epoch
        :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
        :param int crops_per_image: crops per image for the train / validation loaders
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent. Falsy disables TensorBoard logging
        :param int log_tensorboard_update_freq: ``update_freq`` passed to TensorBoard
        :param int log_test_frequency: how often (in epochs) should test images be plotted. None for never
        :param str log_test_path: where should test results be saved
        :param media_type: forwarded unchanged to every ``DataLoader``
        """

        # Create data loaders
        train_loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image,
            media_type,
            self.channels,
            self.colorspace
        )

        # Validation data loader
        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image,
                media_type,
                self.channels,
                self.colorspace
            )

        # Test data loader: one image, one crop at a time
        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(
                datapath_test, 1,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                1,
                media_type,
                self.channels,
                self.colorspace
            )

        # Callback: tensorboard (stays None when logging is disabled, so the
        # training loop can skip it instead of failing on an undefined name)
        tensorboard = None
        if log_tensorboard_path:
            tensorboard = TensorBoard(
                log_dir=os.path.join(log_tensorboard_path, modelname),
                histogram_freq=0,
                batch_size=batch_size,
                write_graph=True,
                write_grads=True,
                update_freq=log_tensorboard_update_freq
            )
            tensorboard.set_model(self.esrgan)
        else:
            print(">> Not logging to tensorboard since no log_tensorboard_path is set")

        # Learning rate scheduler: halve the rate at fixed iteration counts
        def lr_scheduler(epoch, lr):
            factor = 0.5
            decay_step = [50000, 100000, 200000, 300000]
            if epoch in decay_step and epoch:
                return lr * factor
            return lr

        # One scheduler per model; only the combined model's one is verbose.
        # The order matches the order in which they are triggered each epoch.
        lr_schedulers = []
        for model, verbosity in (
            (self.esrgan, 1),
            (self.ra_discriminator, 0),
            (self.discriminator, 0),
            (self.generator, 0),
        ):
            scheduler = LearningRateScheduler(lr_scheduler, verbose=verbosity)
            scheduler.set_model(model)
            lr_schedulers.append(scheduler)

        # Callback: format input value
        def named_logs(model, logs):
            """Transform train_on_batch return value to dict expected by on_batch_end callback"""
            return dict(zip(model.metrics_names, logs))

        # Shape of output from discriminator
        disciminator_output_shape = list(self.ra_discriminator.output_shape)
        disciminator_output_shape[0] = batch_size
        disciminator_output_shape = tuple(disciminator_output_shape)

        # VALID / FAKE targets for discriminator
        real = np.ones(disciminator_output_shape)
        fake = np.zeros(disciminator_output_shape)

        # Use one consistent integer for both the loop bound and the progress line
        last_epoch = int(epochs) + first_epoch

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            train_loader,
            use_multiprocessing=True,
            shuffle=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        try:
            output_generator = enqueuer.get()

            # Each epoch == "update iteration" as defined in the paper
            print_losses = {"GAN": [], "D": []}
            start_epoch = datetime.datetime.now()

            # Loop through epochs / iterations
            for epoch in range(first_epoch, last_epoch):
                for scheduler in lr_schedulers:
                    scheduler.on_epoch_begin(epoch)

                # Start epoch time
                if epoch % print_frequency == 0:
                    print("\nEpoch {}/{}:".format(epoch + 1, last_epoch))
                    start_epoch = datetime.datetime.now()

                # Train discriminator
                self.discriminator.trainable = True
                self.ra_discriminator.trainable = True

                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)

                real_loss = self.ra_discriminator.train_on_batch([imgs_hr, generated_hr], real)
                fake_loss = self.ra_discriminator.train_on_batch([generated_hr, imgs_hr], fake)
                discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

                # Train generator (on a fresh batch, discriminators frozen)
                self.discriminator.trainable = False
                self.ra_discriminator.trainable = False

                imgs_lr, imgs_hr = next(output_generator)
                gan_loss = self.esrgan.train_on_batch([imgs_lr, imgs_hr], [imgs_hr, real, imgs_hr])

                # Callbacks
                if tensorboard is not None:
                    tensorboard.on_epoch_end(epoch, named_logs(self.esrgan, gan_loss))

                # Save losses
                print_losses['GAN'].append(gan_loss)
                print_losses['D'].append(discriminator_loss)

                # Show the progress
                if epoch % print_frequency == 0:
                    # atleast_1d keeps the zip() calls below working even if a
                    # model reports a single scalar loss
                    g_avg_loss = np.atleast_1d(np.array(print_losses['GAN']).mean(axis=0))
                    d_avg_loss = np.atleast_1d(np.array(print_losses['D']).mean(axis=0))
                    # The discriminator losses were produced by ra_discriminator,
                    # so they are labelled with that model's metric names
                    print(">> Time: {}s\n>> GAN: {}\n>> Discriminator: {}".format(
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.esrgan.metrics_names, g_avg_loss)]),
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.ra_discriminator.metrics_names, d_avg_loss)])
                    ))
                    print_losses = {"GAN": [], "D": []}

                    # Run validation inference if specified
                    if datapath_validation:
                        validation_losses = self.generator.evaluate_generator(
                            validation_loader,
                            steps=steps_per_validation,
                            use_multiprocessing=workers > 1,
                            workers=workers
                        )
                        print(">> Validation Losses: {}".format(
                            ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.generator.metrics_names, validation_losses)])
                        ))

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and log_test_frequency and epoch % log_test_frequency == 0:
                    plot_test_images(
                        self.generator, test_loader, datapath_test, log_test_path, epoch, modelname,
                        channels=self.channels, colorspace=self.colorspace
                    )

                # Check if we should save the network weights
                if log_weight_frequency and epoch % log_weight_frequency == 0:
                    self.save_weights(os.path.join(log_weight_path, modelname))
        finally:
            # Always shut the worker processes down, even if training fails
            enqueuer.stop()
Example #5
0
    def train_generator(self,
        epochs=None, batch_size=None,
        workers=None,
        max_queue_size=None,
        modelname=None,
        datapath_train=None,
        datapath_validation='../',
        datapath_test='../',
        steps_per_epoch=None,
        steps_per_validation=None,
        crops_per_image=None,
        print_frequency=None,
        log_weight_path='./model/',
        log_tensorboard_path='./logs/',
        log_tensorboard_update_freq=None,
        log_test_path="./test/",
        media_type='i'
    ):
        """Train the generator alone with ``fit_generator``.

        Monitors ``val_loss`` for early stopping, best-weights checkpointing
        and learning-rate reduction, and additionally halves the learning rate
        every 100 epochs.

        :param int epochs: how many epochs to train for
        :param int batch_size: batch size for the train / validation loaders
        :param int workers: number of workers used by the batch enqueuer
        :param int max_queue_size: maximum queue size of the batch enqueuer
        :param str modelname: name used for the tensorboard folder and the weights file
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for the validation images. None for no validation loader
        :param str datapath_test: path for the test images. None disables test plotting
        :param int steps_per_epoch: training steps per epoch
        :param int steps_per_validation: validation steps per epoch
        :param int crops_per_image: crops per image for the train / validation loaders
        :param int print_frequency: how often (in epochs) to plot test images. None/0 for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent. Falsy disables TensorBoard
        :param int log_tensorboard_update_freq: ``update_freq`` passed to TensorBoard
        :param str log_test_path: where should test results be saved
        :param media_type: forwarded unchanged to every ``DataLoader``
        """

        # Create data loaders
        train_loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image,
            media_type,
            self.channels,
            self.colorspace
        )

        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image,
                media_type,
                self.channels,
                self.colorspace
            )

        # Test data loader: one image, one crop at a time
        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(
                datapath_test, 1,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                1,
                media_type,
                self.channels,
                self.colorspace
            )

        # Callback: tensorboard
        callbacks = []
        if log_tensorboard_path:
            tensorboard = TensorBoard(
                log_dir=os.path.join(log_tensorboard_path, modelname),
                histogram_freq=0,
                batch_size=batch_size,
                write_graph=True,
                write_grads=True,
                update_freq=log_tensorboard_update_freq
            )
            callbacks.append(tensorboard)
        else:
            print(">> Not logging to tensorboard since no log_tensorboard_path is set")

        # Callback: Stop training when a monitored quantity has stopped improving
        earlystopping = EarlyStopping(
            monitor='val_loss',
            patience=500, verbose=1,
            restore_best_weights=True
        )
        callbacks.append(earlystopping)

        # Callback: save the best weights (registered once; a second identical
        # checkpoint would only write the same file twice)
        modelcheckpoint = ModelCheckpoint(
            os.path.join(log_weight_path, modelname + '_{}X.h5'.format(self.upscaling_factor)),
            monitor='val_loss',
            save_best_only=True,
            save_weights_only=True
        )
        callbacks.append(modelcheckpoint)

        # Callback: Reduce lr when a monitored quantity has stopped improving
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                      patience=50, min_lr=1e-5, verbose=1)
        callbacks.append(reduce_lr)

        # Learning rate scheduler
        def halve_lr_on_schedule(epoch, lr):
            factor = 0.5
            decay_step = 100  # 100 epochs * 2000 step per epoch = 2x1e5
            if epoch % decay_step == 0 and epoch:
                return lr * factor
            return lr
        callbacks.append(LearningRateScheduler(halve_lr_on_schedule, verbose=1))

        # Callback: test images plotting. Only registered when there is a test
        # path AND a usable frequency, so neither a missing path nor the default
        # print_frequency=None can raise.
        if datapath_test is not None and print_frequency:
            def plot_on_schedule(epoch, logs):
                if (epoch + 1) % print_frequency == 0:
                    plot_test_images(
                        self.generator,
                        test_loader,
                        datapath_test,
                        log_test_path,
                        epoch + 1,
                        name=modelname,
                        channels=self.channels,
                        colorspace=self.colorspace)
            callbacks.append(LambdaCallback(on_epoch_end=plot_on_schedule))

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            train_loader,
            use_multiprocessing=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        try:
            output_generator = enqueuer.get()

            # Fit the model. Batches are already prepared by the enqueuer, so
            # fit_generator itself runs single-worker.
            self.generator.fit_generator(
                output_generator,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs,
                validation_data=validation_loader,
                validation_steps=steps_per_validation,
                callbacks=callbacks,
                use_multiprocessing=False,  # workers>1 because single gpu
                workers=1
            )
        finally:
            # Always shut the worker processes down, even if training fails
            enqueuer.stop()
Example #6
0
    def train(
            self,
            epochs=50,
            batch_size=8,
            steps_per_epoch=5,
            steps_per_validation=5,
            crops_per_image=4,
            print_frequency=5,
            log_tensorboard_update_freq=10,
            workers=4,
            max_queue_size=5,
            model_name='ESPCN',
            datapath_train='../../../videos_harmonic/MYANMAR_2160p/train/',
            datapath_validation='../../../videos_harmonic/MYANMAR_2160p/validation/',
            datapath_test='../../../videos_harmonic/MYANMAR_2160p/test/',
            log_weight_path='../model/',
            log_tensorboard_path='../logs/',
            log_test_path='../test/',
            media_type='i'):
        """Train ``self.model`` with ``fit_generator``.

        Monitors ``val_loss`` for early stopping, learning-rate reduction and
        best-weights checkpointing, and periodically plots test images.

        :param int epochs: how many epochs to train for
        :param int batch_size: batch size for the train / validation loaders
        :param int steps_per_epoch: training steps per epoch
        :param int steps_per_validation: validation steps per epoch
        :param int crops_per_image: crops per image for the train / validation loaders
        :param int print_frequency: how often (in epochs) to plot test images. None/0 for never
        :param int log_tensorboard_update_freq: ``update_freq`` passed to TensorBoard
        :param int workers: number of workers used by the batch enqueuer
        :param int max_queue_size: maximum queue size of the batch enqueuer
        :param str model_name: name used for the tensorboard folder and the weights file
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for the validation images. None for no validation loader
        :param str datapath_test: path for the test images. None disables test plotting
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent. Falsy disables TensorBoard
        :param str log_test_path: where should test results be saved
        :param media_type: forwarded unchanged to every ``DataLoader``
        """

        # Create data loaders
        train_loader = DataLoader(datapath_train, batch_size, self.height_hr,
                                  self.width_hr, self.upscaling_factor,
                                  crops_per_image, media_type, self.channels,
                                  self.colorspace)

        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image, media_type,
                                           self.channels, self.colorspace)

        # Test data loader: one image, one crop at a time
        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(datapath_test, 1, self.height_hr,
                                     self.width_hr, self.upscaling_factor, 1,
                                     media_type, self.channels,
                                     self.colorspace)

        # Callback: tensorboard
        callbacks = []
        if log_tensorboard_path:
            tensorboard = TensorBoard(
                log_dir=os.path.join(log_tensorboard_path, model_name),
                histogram_freq=0,
                batch_size=batch_size,
                write_graph=True,
                write_grads=True,
                update_freq=log_tensorboard_update_freq)
            callbacks.append(tensorboard)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        # Callback: Stop training when a monitored quantity has stopped improving
        earlystopping = EarlyStopping(monitor='val_loss',
                                      patience=100,
                                      verbose=1,
                                      restore_best_weights=True)
        callbacks.append(earlystopping)

        # Callback: Reduce lr when a monitored quantity has stopped improving
        # NOTE(review): this patience equals the early-stopping patience above,
        # so a reduction may never get a chance to take effect - confirm intent.
        reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                      factor=0.5,
                                      patience=100,
                                      min_lr=1e-4,
                                      verbose=1)
        callbacks.append(reduce_lr)

        # Callback: save the best weights
        modelcheckpoint = ModelCheckpoint(
            os.path.join(log_weight_path,
                         model_name + '_{}X.h5'.format(self.upscaling_factor)),
            monitor='val_loss',
            save_best_only=True,
            save_weights_only=True)
        callbacks.append(modelcheckpoint)

        # Callback: test images plotting. Only registered when there is a test
        # path and a usable frequency; appending it unconditionally would fail
        # on an undefined name when datapath_test is None.
        if datapath_test is not None and print_frequency:
            def plot_on_schedule(epoch, logs):
                if (epoch + 1) % print_frequency == 0:
                    plot_test_images(
                        self.model,
                        test_loader,
                        datapath_test,
                        log_test_path,
                        epoch + 1,
                        name=model_name,
                        channels=self.channels,
                        colorspace=self.colorspace)
            callbacks.append(LambdaCallback(on_epoch_end=plot_on_schedule))

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(train_loader, use_multiprocessing=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        try:
            output_generator = enqueuer.get()

            # Batches are already prepared by the enqueuer, so fit_generator
            # itself runs single-worker.
            self.model.fit_generator(
                output_generator,
                steps_per_epoch=steps_per_epoch,
                epochs=epochs,
                validation_data=validation_loader,
                validation_steps=steps_per_validation,
                callbacks=callbacks,
                use_multiprocessing=False,
                workers=1
            )
        finally:
            # Always shut the worker processes down, even if training fails
            enqueuer.stop()
Example #7
0
    def train(self,
              epochs,
              dataname,
              datapath,
              batch_size=1,
              test_images=None,
              test_frequency=50,
              test_path="./images/samples/",
              weight_frequency=None,
              weight_path='./data/weights/',
              print_frequency=1):
        """Train the SRGAN network

        Each "epoch" here is a single update iteration: one discriminator
        update followed by one generator update, each on a freshly loaded
        mini-batch.

        :param int epochs: how many epochs to train the network for
        :param str dataname: name to use for storing model weights etc.
        :param str datapath: path for the image files to use for training
        :param int batch_size: how large mini-batches to use
        :param list test_images: list of image paths to perform testing on
        :param int test_frequency: how often (in epochs) should testing be performed. None or 0 for never
        :param str test_path: where should test results be saved
        :param int weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str weight_path: where should network weights be saved
        :param int print_frequency: how often (in epochs) to print progress to terminal. None or 0 for never
        """

        # Create data loader
        loader = DataLoader(datapath, self.height_hr, self.width_hr,
                            self.height_lr, self.width_lr,
                            self.upscaling_factor)

        # Targets for the discriminator share its output shape, with the
        # leading (batch) dimension replaced by the mini-batch size
        discriminator_output_shape = (
            (batch_size,) + tuple(self.discriminator.output_shape[1:]))

        # VALID / FAKE targets for discriminator
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        # Each epoch == "update iteration" as defined in the paper
        losses = []

        # Timer for the progress report; restarted after every report so the
        # printed time always covers the interval since the previous report
        start_epoch = datetime.datetime.now()

        for epoch in range(epochs):

            # Train discriminator
            imgs_hr, imgs_lr = loader.load_batch(batch_size)
            generated_hr = self.generator.predict(imgs_lr)
            real_loss = self.discriminator.train_on_batch(imgs_hr, real)
            fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator (on a new batch, through the combined model)
            imgs_hr, imgs_lr = loader.load_batch(batch_size)
            features_hr = self.vgg.predict(imgs_hr)
            generator_loss = self.srgan.train_on_batch([imgs_lr, imgs_hr],
                                                       [real, features_hr])

            # Save losses
            losses.append({
                'generator': generator_loss,
                'discriminator': discriminator_loss
            })

            # Plot the progress (a falsy frequency disables reporting instead
            # of raising ZeroDivisionError)
            if print_frequency and epoch % print_frequency == 0:
                elapsed = datetime.datetime.now() - start_epoch
                generator_report = ", ".join(
                    "{}={:.3e}".format(k, v)
                    for k, v in zip(self.srgan.metrics_names, generator_loss))
                discriminator_report = ", ".join(
                    "{}={:.3e}".format(k, v)
                    for k, v in zip(self.discriminator.metrics_names,
                                    discriminator_loss))
                print(
                    "Epoch {}/{} | Time: {}s\n>> Generator: {}\n>> Discriminator: {}\n"
                    .format(epoch, epochs,
                            # total_seconds() also counts whole days, which
                            # timedelta.seconds silently drops
                            int(elapsed.total_seconds()),
                            generator_report, discriminator_report))
                start_epoch = datetime.datetime.now()

            # If test images are supplied, show them to the user
            if test_images and test_frequency and epoch % test_frequency == 0:
                plot_test_images(self, loader, test_images, test_path, epoch)

            # Check if we should save the network weights
            if weight_frequency and epoch % weight_frequency == 0:

                # Save the network weights
                self.save_weights(os.path.join(weight_path, dataname))

                # Save the recorded losses; the context manager guarantees the
                # file is flushed and closed even if pickling fails
                losses_path = os.path.join(weight_path,
                                           dataname + '_losses.p')
                with open(losses_path, 'wb') as losses_file:
                    pickle.dump(losses, losses_file)
Example #8
0
    def train_srgan(self,
        epochs, batch_size,
        dataname,
        datapath_train,
        datapath_validation=None,
        steps_per_validation=10,
        datapath_test=None,
        workers=40, max_queue_size=100,
        first_epoch=0,
        print_frequency=2,
        crops_per_image=2,
        log_weight_frequency=1000,
        log_weight_path='./data/weights/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_name='SRGAN',
        log_tensorboard_update_freq=500,
        log_test_frequency=500,
        log_test_path="./images/samples/",
        ):
        """Train the relativistic discriminator and the generator alternately

        Each "epoch" here is a single update iteration: one batch is pulled
        from the data enqueuer, the relativistic discriminator (self.RaGAN) is
        updated on it, and then the combined model (self.srgan) is updated on
        the same batch.

        :param int epochs: how many iterations to run
        :param int batch_size: batch size handed to the data loader
        :param str dataname: name to use for storing model weights
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: accepted for interface compatibility; not used by this loop
        :param int steps_per_validation: accepted for interface compatibility; not used by this loop
        :param str datapath_test: path for the image files to use for testing / plotting. None for no plotting
        :param int workers: number of workers used by the data enqueuer
        :param int max_queue_size: maximum number of batches queued by the enqueuer
        :param int first_epoch: iteration number to start counting from
        :param int print_frequency: how often (in iterations) to print averaged losses. None or 0 for never
        :param int crops_per_image: crops-per-image setting handed to the data loader
        :param int log_weight_frequency: how often (in iterations) should network weights be saved. None or 0 for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: accepted for interface compatibility; not used by this loop
        :param str log_tensorboard_name: accepted for interface compatibility; not used by this loop
        :param int log_tensorboard_update_freq: accepted for interface compatibility; not used by this loop
        :param int log_test_frequency: how often (in iterations) to plot test images. None or 0 for never
        :param str log_test_path: where should test results be saved
        """

        # Create train data loader
        loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image
        )
        print("Picture Loaders has been ready.")

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            loader,
            use_multiprocessing=False,
            shuffle=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)

        # From here on the workers are running; always shut them down again,
        # including when training is interrupted by an exception
        try:
            output_generator = enqueuer.get()
            print("Data Enqueuer has been ready.")

            # Losses accumulated since the last progress report
            print_losses = {"G": [], "D": []}

            # Timer for the progress report; restarted after every report so
            # the printed time covers the interval since the previous report
            start_epoch = datetime.datetime.now()

            # Loop through epochs / iterations
            for epoch in range(first_epoch, epochs + first_epoch):
                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)

                # Train Relativistic Discriminator
                discriminator_loss = self.RaGAN.train_on_batch(
                    [imgs_hr, generated_hr], None)

                # Train generator
                generator_loss = self.srgan.train_on_batch(
                    [imgs_lr, imgs_hr], None)

                # Save losses
                print_losses['G'].append(generator_loss)
                print_losses['D'].append(discriminator_loss)

                # Show the progress (a falsy frequency disables reporting
                # instead of raising ZeroDivisionError)
                if print_frequency and epoch % print_frequency == 0:
                    g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                    d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                    elapsed = datetime.datetime.now() - start_epoch
                    print(self.srgan.metrics_names, g_avg_loss)
                    print(self.RaGAN.metrics_names, d_avg_loss)
                    print("\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}".format(
                        epoch, epochs + first_epoch,
                        # total_seconds() also counts whole days, which
                        # timedelta.seconds silently drops
                        int(elapsed.total_seconds()),
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.srgan.metrics_names, g_avg_loss)]),
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.RaGAN.metrics_names, d_avg_loss)])
                    ))
                    print_losses = {"G": [], "D": []}
                    start_epoch = datetime.datetime.now()

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and log_test_frequency and epoch % log_test_frequency == 0:
                    print(">> Ploting test images")
                    plot_test_images(self, loader, datapath_test, log_test_path, epoch, refer_model=self.refer_model)

                # Check if we should save the network weights
                if log_weight_frequency and epoch % log_weight_frequency == 0:
                    # Save the network weights
                    print(">> Saving the network weights")
                    self.save_weights(os.path.join(log_weight_path, dataname), epoch)
        finally:
            enqueuer.stop()