Example #1
0
    def fit(self, batch_size=128, nb_epochs=100, save_history=True, history_fn="Model History.txt") -> Model:
        """
        Standard method to train any of the models.

        Creates the model on first use, then trains it with ``fit_generator``
        on batches from ``img_utils.image_generator(train_path, ...)`` and
        validates on batches from ``img_utils.image_generator(validation_path, ...)``.

        Args:
            batch_size: number of images per generated batch.
            nb_epochs: number of training epochs.
            save_history: if True, also log training history to ``history_fn``
                and, on the TensorFlow backend, write batch-level TensorBoard
                logs to ``./<model_name>_logs/``.
            history_fn: file name handed to ``HistoryCheckpoint``.

        Returns:
            The trained model (``self.model``).
        """

        samples_per_epoch = img_utils.image_count()
        val_count = img_utils.val_image_count()
        if self.model is None:
            self.create_model(batch_size=batch_size)

        # Ceiling division: one step per (possibly partial) batch. The former
        # ``count // batch_size + 1`` ran one extra batch whenever count was an
        # exact multiple of batch_size. Keep at least one step, as before.
        steps_per_epoch = max(1, -(-samples_per_epoch // batch_size))
        validation_steps = max(1, -(-val_count // batch_size))

        # Save only the weights with the highest validation PSNR metric.
        callback_list = [callbacks.ModelCheckpoint(self.weight_path, monitor='val_PSNRLoss', save_best_only=True,
                                                   mode='max', save_weights_only=True, verbose=2)]
        if save_history:
            callback_list.append(HistoryCheckpoint(history_fn))

            if K.backend() == 'tensorflow':
                log_dir = './%s_logs/' % self.model_name
                tensorboard = TensorBoardBatch(log_dir, batch_size=batch_size)
                callback_list.append(tensorboard)

        print("Training model : %s" % (self.__class__.__name__))
        self.model.fit_generator(img_utils.image_generator(train_path, scale_factor=self.scale_factor,
                                                           small_train_images=self.type_true_upscaling,
                                                           batch_size=batch_size),
                                 steps_per_epoch=steps_per_epoch,
                                 epochs=nb_epochs, callbacks=callback_list,
                                 validation_data=img_utils.image_generator(validation_path,
                                                                           scale_factor=self.scale_factor,
                                                                           small_train_images=self.type_true_upscaling,
                                                                           batch_size=batch_size),
                                 validation_steps=validation_steps)

        return self.model
    def fit(self, batch_size=128, nb_epochs=100, save_history=True, history_fn="Model History.txt") -> Model:
        """
        Standard method to train any of the models.

        Creates the model on first use, then trains it with ``fit_generator``
        on batches from ``img_utils.image_generator(train_path, ...)`` and
        validates on ``img_utils.image_generator(validation_path, ...)``.
        Uses the Keras 1.x ``fit_generator`` argument names
        (``samples_per_epoch``, ``nb_epoch``, ``nb_val_samples``).

        Args:
            batch_size: number of images per generated batch.
            nb_epochs: number of training epochs.
            save_history: if True, also log training history to ``history_fn``.
            history_fn: file name handed to ``HistoryCheckpoint``.

        Returns:
            The trained model (``self.model``).
        """

        samples_per_epoch = img_utils.image_count()
        val_count = img_utils.val_image_count()
        if self.model is None:
            self.create_model(batch_size=batch_size)

        # Save only the weights with the highest validation PSNR metric.
        callback_list = [callbacks.ModelCheckpoint(self.weight_path, monitor='val_PSNRLoss', save_best_only=True,
                                                   mode='max', save_weights_only=True)]
        if save_history:
            callback_list.append(HistoryCheckpoint(history_fn))

        print("Training model : %s" % (self.__class__.__name__))
        self.model.fit_generator(img_utils.image_generator(train_path, scale_factor=self.scale_factor,
                                                           small_train_images=self.type_true_upscaling,
                                                           batch_size=batch_size),
                                 samples_per_epoch=samples_per_epoch,
                                 nb_epoch=nb_epochs, callbacks=callback_list,
                                 validation_data=img_utils.image_generator(validation_path,
                                                                           scale_factor=self.scale_factor,
                                                                           small_train_images=self.type_true_upscaling,
                                                                           batch_size=batch_size),
                                 nb_val_samples=val_count)

        return self.model
Example #3
0
    def fit(self, scale_factor, weight_fn, batch_size=128, nb_epochs=100, small_train_images=False,
            save_history=True, history_fn="Model History.txt") -> Model:
        """
        Standard method to train any of the models.

        Creates the model on first use, then trains it with ``fit_generator``
        (Keras 1.x argument names) on batches from
        ``img_utils.image_generator(train_path, ...)``, validating on
        ``img_utils.image_generator(validation_path, ...)``.

        Args:
            scale_factor: scale factor forwarded to both image generators.
            weight_fn: path the best weights (highest ``val_PSNRLoss``) are saved to.
            batch_size: number of images per generated batch.
            nb_epochs: number of training epochs.
            small_train_images: forwarded to the generators as ``small_train_images``.
            save_history: if True, also log training history to ``history_fn``.
            history_fn: file name handed to ``HistoryCheckpoint``.

        Returns:
            The trained model (``self.model``).
        """
        samples_per_epoch = img_utils.image_count()
        val_count = img_utils.val_image_count()
        if self.model is None:
            self.create_model(batch_size=batch_size)

        # Save only the weights with the highest validation PSNR metric.
        callback_list = [callbacks.ModelCheckpoint(weight_fn, monitor='val_PSNRLoss', save_best_only=True,
                                                   mode='max', save_weights_only=True)]
        if save_history:
            callback_list.append(HistoryCheckpoint(history_fn))

        print("Training model : %s" % (self.__class__.__name__))
        self.model.fit_generator(img_utils.image_generator(train_path, scale_factor=scale_factor,
                                                           small_train_images=small_train_images,
                                                           batch_size=batch_size),
                                 samples_per_epoch=samples_per_epoch,
                                 nb_epoch=nb_epochs, callbacks=callback_list,
                                 validation_data=img_utils.image_generator(validation_path,
                                                                           scale_factor=scale_factor,
                                                                           small_train_images=small_train_images,
                                                                           batch_size=batch_size),
                                 nb_val_samples=val_count)

        return self.model
Example #4
0
    def fit(self, nb_pretrain_samples=5000, batch_size=128, nb_epochs=100, disc_train_flip=0.1,
            save_history=True, history_fn="GAN SRCNN History.txt"):
        """
        Train the GAN (generator ``self.gen_model``, discriminator
        ``self.disc_model``, combined model ``self.model``).

        If saved generator and discriminator weights both exist they are
        loaded; otherwise both networks are pre-trained on one batch of
        ``nb_pretrain_samples`` images. Then, for ``nb_epochs`` epochs, each
        training batch is used to fit the discriminator on generated vs.
        "real" images, and then to fit the combined model towards the "real"
        label.

        Args:
            nb_pretrain_samples: size of the single batch drawn for
                pre-training (only used when no saved weights are found).
            batch_size: batch size for the training generator and Keras fits.
            nb_epochs: number of adversarial training epochs.
            disc_train_flip: probability that a discriminator batch receives
                swapped (noisy) labels.
            save_history: currently unused; kept for interface compatibility.
            history_fn: currently unused; kept for interface compatibility.

        Returns:
            The combined model ``self.model``.
        """

        def _to_tanh_range(images):
            # [-1, 1] scale conversion from [0, 1]
            return ((images * 255) - 127.5) / 127.5

        def _discriminator_labels(nb_samples):
            # Soft and noisy labels for ``np.concatenate((generated, real))``:
            # normally generated -> 0 and real -> 1, but with probability
            # ``disc_train_flip`` the two halves are swapped.
            if np.random.uniform() > disc_train_flip:
                # give correct classifications
                labels = [0] * nb_samples + [1] * nb_samples
            else:
                # give wrong classifications (noisy labels)
                labels = [1] * nb_samples + [0] * nb_samples

            labels = np.asarray(labels, dtype=np.float32).reshape(-1, 1)
            labels = to_categorical(labels, nb_classes=2)
            return img_utils.smooth_gan_labels(labels)

        def _generator_targets(nb_samples):
            # Soft "real" (class 1) targets for training the combined model.
            # ``dtype=int`` replaces the ``np.int`` alias removed in NumPy 1.24.
            targets = np.asarray([1] * nb_samples, dtype=int).reshape(-1, 1)
            targets = to_categorical(targets, nb_classes=2)
            return img_utils.smooth_gan_labels(targets)

        samples_per_epoch = img_utils.image_count()
        # Per-channel means are taken over every axis except the channel axis
        # (axis 1 for 'th' ordering, the last axis otherwise).
        meanaxis = (0, 2, 3) if K.image_dim_ordering() == 'th' else (0, 1, 2)

        if self.model is None:
            self.create_model(mode='train', batch_size=batch_size)

        if os.path.exists(self.gen_weight_path) and os.path.exists(self.disc_weight_path):
            self.gen_model.load_weights(self.gen_weight_path)
            self.disc_model.load_weights(self.disc_weight_path)
            print("Pre-trained Generator and Discriminator network weights loaded")
        else:
            print('Pre-training on %d images' % (nb_pretrain_samples))
            batchX, batchY = next(img_utils.image_generator(train_path, scale_factor=self.scale_factor,
                                                            small_train_images=self.type_true_upscaling,
                                                            batch_size=nb_pretrain_samples))

            batchX = _to_tanh_range(batchX)
            batchY = _to_tanh_range(batchY)

            print("Pre-training Generator network")
            hist = self.gen_model.fit(batchX, batchY, batch_size, nb_epoch=200, verbose=2)
            print("Generator pretrain final PSNR : ", hist.history['PSNRLoss'][-1])

            print("Pre-training Discriminator network")

            genX = self.gen_model.predict(batchX, batch_size=batch_size)

            print('GenX Output mean (per channel) :', np.mean(genX, axis=meanaxis))
            print('BatchX mean (per channel) :', np.mean(batchX, axis=meanaxis))

            X = np.concatenate((genX, batchX))
            # Size labels from the images actually drawn; the generator may
            # yield fewer than nb_pretrain_samples.
            y = _discriminator_labels(len(batchX))

            hist = self.disc_model.fit(X, y, batch_size=batch_size,
                                       nb_epoch=1, verbose=0)

            print('Discriminator History :', hist.history)
            print()

        self.gen_model.save_weights(self.gen_weight_path, overwrite=True)
        self.disc_model.save_weights(self.disc_weight_path, overwrite=True)

        print("Training full model : %s" % (self.__class__.__name__))

        for i in range(nb_epochs):
            print("Epoch : %d" % (i + 1))
            print()

            iteration = 0   # images processed so far in this epoch
            save_index = 1  # batch counter driving the periodic weight save

            for x, _ in img_utils.image_generator(train_path, scale_factor=self.scale_factor,
                                                  small_train_images=self.type_true_upscaling, batch_size=batch_size):
                t1 = time.time()

                x = _to_tanh_range(x)
                # Labels must match the actual batch size (the last batch may
                # be short); they are unrelated to nb_pretrain_samples and must
                # also work when pre-training was skipped.
                nb_batch = len(x)

                X_pred = self.gen_model.predict(x, batch_size)

                print("Input batchX mean (per channel) :", np.mean(x, axis=meanaxis))
                print("X_pred mean (per channel) :", np.mean(X_pred, axis=meanaxis))

                # NOTE(review): the "real" half here is the generator input
                # ``x``; the generator's target image is discarded as ``_``.
                # Confirm this is intended.
                X = np.concatenate((X_pred, x))
                y_disc = _discriminator_labels(nb_batch)

                hist = self.disc_model.fit(X, y_disc, verbose=0, batch_size=batch_size, nb_epoch=1)

                discriminator_loss = hist.history['loss'][0]
                discriminator_acc = hist.history['acc'][0]

                hist = self.model.fit(x, _generator_targets(nb_batch), batch_size, nb_epoch=1, verbose=0)
                generative_loss = hist.history['loss'][0]

                iteration += batch_size
                save_index += 1

                t2 = time.time()

                print("Iter : %d / %d | Time required : %0.2f seconds | Discriminator Loss / Acc : %0.6f / %0.3f | "
                      "Generative Loss : %0.6f" % (iteration, samples_per_epoch, t2 - t1,
                                                   discriminator_loss, discriminator_acc, generative_loss))

                epoch_done = iteration >= samples_per_epoch

                # Validate at end of epoch
                if epoch_done:
                    print("Evaluating generator model...")
                    self.evaluate('val_images/')

                # Save generator weights every 100 batches
                if save_index % 100 == 0:
                    print("Saving generator weights")
                    self.gen_model.save_weights(self.weight_path, overwrite=True)

                if epoch_done:
                    break

        return self.model
                 student_output_tensor))  # l2 norm of difference
# Attach the weighted teacher/student L2 term to the joint model's losses;
# `l2_weight` and `teacher_student_loss` are defined earlier in this script.
joint_model.add_loss(l2_weight * teacher_student_loss)

# perceptual loss
# Sum of squared differences between gram_matrix(teacher output) and
# gram_matrix(student output). `gram_matrix` is defined elsewhere; presumably
# it computes a Gram (feature-correlation) matrix — TODO confirm.
with K.name_scope('perceptual_loss'):
    perceptual_weight = 2.
    perceptual_loss = K.sum(
        K.square(
            gram_matrix(teacher_output_tensor) -
            gram_matrix(student_output_tensor)))
joint_model.add_loss(perceptual_weight * perceptual_loss)

# The training objectives are attached via add_loss() above, so this must run
# after them. `zero_loss` is defined elsewhere — presumably a placeholder that
# contributes nothing to the total loss; verify.
joint_model.compile(optimizer='adam', loss=zero_loss)

# train student model using teacher model
# Image counts reported by img_utils for the training and validation sets.
samples_per_epoch = img_utils.image_count()
val_count = img_utils.val_image_count()

# Checkpoint path embeds the teacher model's name and the scale factor
# (formatted with %d, i.e. as an integer).
weight_path = 'weights/joint_model (%s) %dX.h5' % (teacher_model.model_name,
                                                   scale_factor)
history_fn = 'Joint_model_training.txt'

# Training / validation data roots; path_X and path_Y are the "X/" and "y/"
# subfolders of the training output path (presumably inputs and targets — confirm).
train_path = img_utils.output_path
validation_path = img_utils.validation_output_path
path_X = img_utils.output_path + "X/"
path_Y = img_utils.output_path + "y/"
callback_list = [
    ModelCheckpoint(weight_path,
                    monitor='val_loss',
                    save_best_only=True,