Example #1
def _test_loop(path, batch_size, datagen, img_height, img_width, iteration,
               large_img_height, large_img_width, model, total_psnr, prefix,
               nb_images):
    for x in datagen.flow_from_directory(path,
                                         class_mode=None,
                                         batch_size=batch_size,
                                         target_size=(large_img_width,
                                                      large_img_height)):
        t1 = time.time()

        # resize images
        x_temp = x.copy()
        x_temp = x_temp.transpose((0, 2, 3, 1))

        x_generator = np.empty((batch_size, img_width, img_height, 3))

        for j in range(batch_size):
            img = imresize(x_temp[j], (img_width, img_height))
            x_generator[j, :, :, :] = img

        x_generator = x_generator.transpose((0, 3, 1, 2))

        output_image_batch = model.predict_on_batch(x_generator)

        average_psnr = 0.0
        for x_i in range(batch_size):
            sample_psnr = psnr(
                x[x_i],
                np.clip(output_image_batch[x_i] * 255, 0, 255) / 255.)
            average_psnr += sample_psnr
            total_psnr += sample_psnr

        average_psnr /= batch_size

        iteration += batch_size
        t2 = time.time()

        print(
            "Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
            % (t2 - t1, batch_size, average_psnr))

        for x_i in range(batch_size):
            real_path = base_test_images + prefix + "_iteration_%d_num_%d_real_.png" % (
                iteration, x_i + 1)
            generated_path = base_test_images + prefix + "_iteration_%d_num_%d_generated.png" % (
                iteration, x_i + 1)

            val_x = x[x_i].copy() * 255.
            val_x = val_x.transpose((1, 2, 0))
            val_x = np.clip(val_x, 0, 255).astype('uint8')

            output_image = output_image_batch[x_i] * 255
            output_image = output_image.transpose((1, 2, 0))
            output_image = np.clip(output_image, 0, 255).astype('uint8')

            imsave(real_path, val_x)
            imsave(generated_path, output_image)

        if iteration >= nb_images:
            break
    return total_psnr
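This test loop relies on a psnr helper that compares two float images scaled to [0, 1]. A minimal NumPy sketch of such a helper, as an assumption rather than the project's exact implementation:

import numpy as np

def psnr(y_true, y_pred):
    # Peak signal-to-noise ratio for images in the [0, 1] range.
    mse = np.mean((np.asarray(y_true, dtype=np.float64) -
                   np.asarray(y_pred, dtype=np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(1.0 / mse)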
Example #2
def _test_loop(path, batch_size, datagen, img_height, img_width, iteration, large_img_height, large_img_width, model,
               total_psnr, prefix, nb_images):
    for x in datagen.flow_from_directory(path, class_mode=None, batch_size=batch_size,
                                         target_size=(large_img_width, large_img_height)):
        t1 = time.time()

        # resize images
        x_temp = x.copy()
        x_temp = x_temp.transpose((0, 2, 3, 1))

        x_generator = np.empty((batch_size, img_width, img_height, 3))

        for j in range(batch_size):
            img = imresize(x_temp[j], (img_width, img_height))
            x_generator[j, :, :, :] = img

        x_generator = x_generator.transpose((0, 3, 1, 2))

        output_image_batch = model.predict_on_batch(x_generator)

        average_psnr = 0.0
        for x_i in range(batch_size):
            sample_psnr = psnr(x[x_i], output_image_batch[x_i] / 255.)
            average_psnr += sample_psnr
            total_psnr += sample_psnr

        average_psnr /= batch_size

        iteration += batch_size
        t2 = time.time()

        print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
              (t2 - t1, batch_size, average_psnr))

        for x_i in range(batch_size):
            real_path = base_test_images + prefix + "_iteration_%d_num_%d_real_.png" % (iteration, x_i + 1)
            generated_path = base_test_images + prefix + "_iteration_%d_num_%d_generated.png" % (iteration, x_i + 1)

            val_x = x[x_i].copy() * 255.
            val_x = val_x.transpose((1, 2, 0))
            val_x = np.clip(val_x, 0, 255).astype('uint8')

            output_image = output_image_batch[x_i]
            output_image = output_image.transpose((1, 2, 0))
            output_image = np.clip(output_image, 0, 255).astype('uint8')

            imsave(real_path, val_x)
            imsave(generated_path, output_image)

        if iteration >= nb_images:
            break
    return total_psnr
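Both loops call scipy.misc.imresize and imsave, which were removed in SciPy 1.3. A rough Pillow-based shim, assuming uint8-range inputs and SciPy's (height, width) size convention; this is a compatibility sketch, not a pixel-exact replacement:

import numpy as np
from PIL import Image

_INTERP = {'nearest': Image.NEAREST, 'bilinear': Image.BILINEAR,
           'bicubic': Image.BICUBIC}

def imresize(arr, size, interp='bilinear'):
    # SciPy accepted either a float scale factor or a (height, width) tuple.
    if isinstance(size, float):
        h, w = arr.shape[:2]
        size = (int(h * size), int(w * size))
    img = Image.fromarray(np.clip(arr, 0, 255).astype('uint8'))
    # PIL's resize takes (width, height), hence the reversal.
    return np.array(img.resize(size[::-1], _INTERP[interp]))

def imsave(path, arr):
    Image.fromarray(arr).save(path)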
Example #3
def validate(loader, model, epoch, d1, d2, blind, noise_level):
    val_psnr = 0
    val_ssim = 0
    val_l1 = 0
    model.train(False)

    k1 = model.weight[0].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    k2 = model.weight[1].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    d1 = d1.expand(loader.batch_size, -1, -1, -1)
    d2 = d2.expand(loader.batch_size, -1, -1, -1)

    # pre-create noise levels
    if blind:
        nls = np.linspace(0.5, noise_level, len(loader))
    else:
        nls = noise_level * np.ones(len(loader))

    with torch.no_grad():
        for i, data in tqdm.tqdm(enumerate(loader)):
            x, y, k, d = data
            x = x.to(device)
            y = y.to(device)
            k = k.to(device)
            d = d.to(device)

            nl = nls[i] / 255
            y += nl * torch.randn_like(y)
            y = y.clamp(0, 1)

            hat_x = model(y, k, d, k1, k2, d1, d2)[-1]
            hat_x.clamp_(0, 1)

            hat_x = utils.crop_valid(hat_x, k)
            x = utils.crop_valid(x, k)
            y = utils.crop_valid(y, k)

            val_psnr += loss.psnr(hat_x, x)
            val_ssim += loss.ssim(hat_x, x)
            val_l1 += F.l1_loss(hat_x, x).item()

    return val_psnr / len(loader), val_ssim / len(loader), val_l1 / len(loader)
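Here loss.psnr and loss.ssim are project helpers whose bodies aren't shown. A plausible sketch of the PSNR side, assuming NCHW tensors clamped to [0, 1] and a scalar returned per batch:

import torch

def psnr(hat_x, x, max_val=1.0):
    # Mean PSNR over the batch for tensors in [0, max_val].
    mse = torch.mean((hat_x - x) ** 2, dim=(1, 2, 3))
    return (10.0 * torch.log10(max_val ** 2 / mse)).mean().item()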
Example #4
def validate(loader, model, epoch, d1, d2, blind, noise_level):
    val_psnr = 0
    val_ssim = 0
    val_l1 = 0
    model.train(False)

    k1 = model.weight[0].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    k2 = model.weight[1].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    d1 = d1.expand(loader.batch_size, -1, -1, -1)
    d2 = d2.expand(loader.batch_size, -1, -1, -1)

    # pre-create noise levels
    if blind:
        nls = np.linspace(0.5, noise_level, len(loader))
    else:
        nls = noise_level * np.ones(len(loader))

    with torch.no_grad():
        for i, data in tqdm.tqdm(enumerate(loader)):
            x, y, mag, ori = data
            x = x.to(device)
            y = y.to(device)
            mag = mag.to(device)
            ori = ori.to(device)
            ori = (90 - ori).add(360).fmod(180)

            labels = utils.get_labels(mag, ori)
            ori = ori * np.pi / 180

            nl = nls[i] / 255
            y += nl * torch.randn_like(y)
            y = y.clamp(0, 1)

            hat_x = model(y, mag, ori, labels, k1, k2, d1, d2)[-1]
            hat_x.clamp_(0, 1)

            val_psnr += loss.psnr(hat_x, x)
            val_ssim += loss.ssim(hat_x, x)
            val_l1 += F.l1_loss(hat_x, x).item()

    return val_psnr / len(loader), val_ssim / len(loader), val_l1 / len(loader)
Example #5
    def eval(self):
        self.model.eval()

        psnr_losses = []
        ssim_losses = []

        with tqdm(total=len(self.valloader)) as t:
            # show per-batch evaluation progress
            t.set_description('evaluating...')

            for LR, HR in self.valloader:
                LR = LR.to(self.device)
                HR = HR.to(self.device)

                output = self.model(LR)

                # # only compute on the Y channel (or grayscale)
                # if self.r_mode == 'RGB' and self.img_channels == 3:
                #     output = rgb2y_tensor(output)
                #     HR = rgb2y_tensor(HR)

                ssim_loss = ssim(output, HR)
                psnr_loss = psnr(output, HR)

                # save losses
                ssim_losses.append(ssim_loss)
                psnr_losses.append(psnr_loss)

                t.update()

            avg_ssim = torch.stack(ssim_losses, dim=0).mean().item()
            avg_psnr = torch.stack(psnr_losses, dim=0).mean().item()

            t.set_postfix(avg_psnr=f'{avg_psnr:.010f}',
                          avg_ssim=f'{avg_ssim:.010f}')
            t.set_description('evaluate')

        self.model.train()
        return avg_ssim, avg_psnr
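The commented-out Y-channel path above refers to an rgb2y_tensor helper that isn't shown. A sketch of what it might look like, assuming NCHW RGB tensors in [0, 1] and the BT.601 luma coefficients:

import torch

def rgb2y_tensor(x):
    # BT.601 luma from an NCHW RGB tensor in [0, 1], returned in [0, 1].
    r, g, b = x[:, 0:1], x[:, 1:2], x[:, 2:3]
    return (16.0 + 65.481 * r + 128.553 * g + 24.966 * b) / 255.0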
Example #6
    def _train_model(self,
                     image_dir,
                     nb_images=80000,
                     nb_epochs=10,
                     pre_train_srgan=False,
                     pre_train_discriminator=False,
                     load_generative_weights=False,
                     load_discriminator_weights=False,
                     save_loss=True,
                     disc_train_flip=0.1):

        assert self.img_width >= 16, "Minimum image width must be at least 16"
        assert self.img_height >= 16, "Minimum image height must be at least 16"

        if load_generative_weights:
            try:
                self.generative_model_.load_weights(
                    self.generative_network.sr_weights_path)
                print("Generator weights loaded.")
            except:
                print("Could not load generator weights.")

        if load_discriminator_weights:
            try:
                self.discriminative_network.load_gan_weights(self.srgan_model_)
                print("Discriminator weights loaded.")
            except:
                print("Could not load discriminator weights.")

        datagen = ImageDataGenerator(rescale=1. / 255)
        img_width = self.img_width * 4
        img_height = self.img_height * 4

        early_stop = False
        iteration = 0
        prev_improvement = -1

        if save_loss:
            if pre_train_srgan:
                loss_history = {
                    'generator_loss': [],
                    'val_psnr': [],
                }
            elif pre_train_discriminator:
                loss_history = {
                    'discriminator_loss': [],
                    'discriminator_acc': [],
                }
            else:
                loss_history = {
                    'discriminator_loss': [],
                    'discriminator_acc': [],
                    'generator_loss': [],
                    'val_psnr': [],
                }

        y_vgg_dummy = np.zeros((self.batch_size * 2, 3, img_width // 32,
                                img_height // 32))  # 5 Max Pools = 2 ** 5 = 32

        print("Training SRGAN network")
        for i in range(nb_epochs):
            print()
            print("Epoch : %d" % (i + 1))

            for x in datagen.flow_from_directory(image_dir,
                                                 class_mode=None,
                                                 batch_size=self.batch_size,
                                                 target_size=(img_width,
                                                              img_height)):
                try:
                    t1 = time.time()

                    if not pre_train_srgan and not pre_train_discriminator:
                        x_vgg = x.copy() * 255  # VGG input [0 - 255 scale]

                    # resize images
                    x_temp = x.copy()
                    x_temp = x_temp.transpose((0, 2, 3, 1))

                    x_generator = np.empty(
                        (self.batch_size, self.img_width, self.img_height, 3))

                    for j in range(self.batch_size):
                        img = gaussian_filter(x_temp[j], sigma=0.1)
                        img = imresize(img, (self.img_width, self.img_height),
                                       interp='bicubic')
                        x_generator[j, :, :, :] = img

                    x_generator = x_generator.transpose((0, 3, 1, 2))

                    if iteration % 50 == 0 and iteration != 0 and not pre_train_discriminator:
                        print("Validation image..")
                        output_image_batch = self.generative_network.get_generator_output(
                            x_generator, self.srgan_model_)
                        if type(output_image_batch) == list:
                            output_image_batch = output_image_batch[0]

                        mean_axis = (
                            0, 2,
                            3) if K.image_dim_ordering() == 'th' else (0, 1, 2)

                        average_psnr = 0.0

                        print(
                            'gen img mean :',
                            np.mean(output_image_batch / 255., axis=mean_axis))
                        print('val img mean :', np.mean(x, axis=mean_axis))

                        for x_i in range(self.batch_size):
                            average_psnr += psnr(
                                x[x_i],
                                np.clip(output_image_batch[x_i], 0, 255) /
                                255.)

                        average_psnr /= self.batch_size

                        if save_loss:
                            loss_history['val_psnr'].append(average_psnr)

                        iteration += self.batch_size
                        t2 = time.time()

                        print(
                            "Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
                            % (t2 - t1, self.batch_size, average_psnr))

                        for x_i in range(self.batch_size):
                            real_path = "val_images/epoch_%d_iteration_%d_num_%d_real_.png" % (
                                i + 1, iteration, x_i + 1)
                            generated_path = "val_images/epoch_%d_iteration_%d_num_%d_generated.png" % (
                                i + 1, iteration, x_i + 1)

                            val_x = x[x_i].copy() * 255.
                            val_x = val_x.transpose((1, 2, 0))
                            val_x = np.clip(val_x, 0, 255).astype('uint8')

                            output_image = output_image_batch[x_i]
                            output_image = output_image.transpose((1, 2, 0))
                            output_image = np.clip(output_image, 0,
                                                   255).astype('uint8')

                            imsave(real_path, val_x)
                            imsave(generated_path, output_image)
                        '''
                        Don't train on validation images for now.

                        Note that if nb_epochs > 1, there is a chance that
                        validation images may be used for training purposes as well.

                        In that case, this isn't strictly a validation measure, but rather
                        just a check to see what the network has learned.
                        '''
                        continue

                    if pre_train_srgan:
                        # Train only generator + vgg network

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        hist = bypass_fit(self.srgan_model_,
                                          [x_generator, x * 255],
                                          y_vgg_dummy,
                                          batch_size=self.batch_size,
                                          nb_epoch=1,
                                          verbose=0)
                        sr_loss = hist.history['loss'][0]

                        if save_loss:
                            loss_history['generator_loss'].extend(
                                hist.history['loss'])

                        if prev_improvement == -1:
                            prev_improvement = sr_loss

                        improvement = (prev_improvement -
                                       sr_loss) / prev_improvement * 100
                        prev_improvement = sr_loss

                        iteration += self.batch_size
                        t2 = time.time()

                        print(
                            "Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                            "Generative Loss : %0.2f" %
                            (iteration, nb_images, improvement, t2 - t1,
                             sr_loss))
                    elif pre_train_discriminator:
                        # Train only discriminator
                        X_pred = self.generative_model_.predict(
                            x_generator, self.batch_size)

                        X = np.concatenate((X_pred, x * 255))

                        # Using soft and noisy labels
                        if np.random.uniform() > disc_train_flip:
                            # give correct classifications
                            y_gan = [0] * self.batch_size + [1] * self.batch_size
                        else:
                            # give wrong classifications (noisy labels)
                            y_gan = [1] * self.batch_size + [0] * self.batch_size

                        y_gan = np.asarray(y_gan, dtype=np.int).reshape(-1, 1)
                        y_gan = to_categorical(y_gan, nb_classes=2)
                        y_gan = smooth_gan_labels(y_gan)

                        hist = self.discriminative_model_.fit(
                            X,
                            y_gan,
                            batch_size=self.batch_size,
                            nb_epoch=1,
                            verbose=0)

                        discriminator_loss = hist.history['loss'][-1]
                        discriminator_acc = hist.history['acc'][-1]

                        if save_loss:
                            loss_history['discriminator_loss'].extend(
                                hist.history['loss'])
                            loss_history['discriminator_acc'].extend(
                                hist.history['acc'])

                        if prev_improvement == -1:
                            prev_improvement = discriminator_loss

                        improvement = (prev_improvement - discriminator_loss
                                       ) / prev_improvement * 100
                        prev_improvement = discriminator_loss

                        iteration += self.batch_size
                        t2 = time.time()

                        print(
                            "Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                            "Discriminator Loss / Acc : %0.4f / %0.2f" %
                            (iteration, nb_images, improvement, t2 - t1,
                             discriminator_loss, discriminator_acc))

                    else:
                        # Train only discriminator, disable training of srgan
                        self.discriminative_network.set_trainable(
                            self.srgan_model_, value=True)
                        self.generative_network.set_trainable(
                            self.srgan_model_, value=False)

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        # hist = bypass_fit(self.srgan_model_, [x_generator, x * 255, x_vgg],
                        #                          [y_gan, y_vgg_dummy],
                        #                          batch_size=self.batch_size, nb_epoch=1, verbose=0)

                        X_pred = self.generative_model_.predict(
                            x_generator, self.batch_size)

                        X = np.concatenate((X_pred, x * 255))

                        # Using soft and noisy labels
                        if np.random.uniform() > disc_train_flip:
                            # give correct classifications
                            y_gan = [0] * self.batch_size + [1] * self.batch_size
                        else:
                            # give wrong classifications (noisy labels)
                            y_gan = [1] * self.batch_size + [0] * self.batch_size

                        y_gan = np.asarray(y_gan, dtype=np.int).reshape(-1, 1)
                        y_gan = to_categorical(y_gan, nb_classes=2)
                        y_gan = smooth_gan_labels(y_gan)

                        hist1 = self.discriminative_model_.fit(
                            X,
                            y_gan,
                            verbose=0,
                            batch_size=self.batch_size,
                            nb_epoch=1)

                        discriminator_loss = hist1.history['loss'][-1]

                        # Train only generator, disable training of discriminator
                        self.discriminative_network.set_trainable(
                            self.srgan_model_, value=False)
                        self.generative_network.set_trainable(
                            self.srgan_model_, value=True)

                        # Using soft labels
                        y_model = [1] * self.batch_size
                        y_model = np.asarray(y_model,
                                             dtype=np.int).reshape(-1, 1)
                        y_model = to_categorical(y_model, nb_classes=2)
                        y_model = smooth_gan_labels(y_model)

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        hist2 = bypass_fit(self.srgan_model_,
                                           [x_generator, x, x_vgg],
                                           [y_model, y_vgg_dummy],
                                           batch_size=self.batch_size,
                                           nb_epoch=1,
                                           verbose=0)

                        generative_loss = hist2.history['loss'][0]

                        if save_loss:
                            loss_history['discriminator_loss'].extend(
                                hist1.history['loss'])
                            loss_history['discriminator_acc'].extend(
                                hist1.history['acc'])
                            loss_history['generator_loss'].extend(
                                hist2.history['loss'])

                        if prev_improvement == -1:
                            prev_improvement = discriminator_loss

                        improvement = (prev_improvement - discriminator_loss
                                       ) / prev_improvement * 100
                        prev_improvement = discriminator_loss

                        iteration += self.batch_size
                        t2 = time.time()
                        print(
                            "Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                            "Discriminator Loss : %0.3f | Generative Loss : %0.3f"
                            % (iteration, nb_images, improvement, t2 - t1,
                               discriminator_loss, generative_loss))

                    if iteration % 1000 == 0 and iteration != 0:
                        print("Saving model weights.")
                        # Save predictive (SR network) weights
                        self._save_model_weights(pre_train_srgan,
                                                 pre_train_discriminator)
                        self._save_loss_history(loss_history, pre_train_srgan,
                                                pre_train_discriminator,
                                                save_loss)

                    if iteration >= nb_images:
                        break

                except KeyboardInterrupt:
                    print("Keyboard interrupt detected. Stopping early.")
                    early_stop = True
                    break

            iteration = 0

            if early_stop:
                break

        print("Finished training SRGAN network. Saving model weights.")
        # Save predictive (SR network) weights
        self._save_model_weights(pre_train_srgan, pre_train_discriminator)
        self._save_loss_history(loss_history, pre_train_srgan,
                                pre_train_discriminator, save_loss)
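The training loop above leans on a smooth_gan_labels helper whose implementation isn't included here. A sketch consistent with the "soft and noisy labels" comments, assuming one-sided label smoothing in the style of Salimans et al. (2016):

import numpy as np

def smooth_gan_labels(y):
    # Hypothetical: soften hard one-hot targets into random soft ranges,
    # e.g. 1 -> [0.7, 1.2] and 0 -> [0.0, 0.3].
    y = np.asarray(y, dtype=np.float32)
    return np.where(y == 1,
                    np.random.uniform(0.7, 1.2, y.shape),
                    np.random.uniform(0.0, 0.3, y.shape)).astype(np.float32)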
Example #7
def _test_loop(path, batch_size, datagen, img_height, img_width, iteration,
               large_img_height, large_img_width, model, total_psnr, prefix,
               nb_images, normalized):
    """

    :param path: 数据集地址
    :param batch_size: 每个iteration生成图片数
    :param datagen: 图片增强迭代器
    :param img_height:
    :param img_width:
    :param iteration:
    :param large_img_height:
    :param large_img_width:
    :param model: 网络模型
    :param total_psnr: 总pnsr值
    :param prefix: 文件保存位置
    :param nb_images: 测试图片数量
    :param normalized: 预测模型在训练时是否将图片归一化
    :return:
    """
    for x in datagen.flow_from_directory(path,
                                         class_mode=None,
                                         batch_size=batch_size,
                                         target_size=(large_img_width,
                                                      large_img_height)):
        t1 = time.time()

        # resize images
        x_temp = x.copy()
        #x_temp = x_temp.transpose((0, 2, 3, 1))

        x_generator = np.empty(
            (batch_size, large_img_width, large_img_height, 3))

        for j in range(batch_size):
            # Downsample the image first, then upscale with bicubic interpolation
            img = imresize(x_temp[j], (img_width, img_height))
            img = imresize(img, (large_img_width, large_img_height),
                           interp='bicubic')
            # When the model expects normalized input, divide by 255
            if normalized: x_generator[j, :, :, :] = img / 255.
            else: x_generator[j, :, :, :] = img

        output_image_batch = model.predict_on_batch(x_generator)

        average_psnr = 0.0
        for x_i in range(batch_size):
            if normalized:
                sample_psnr = psnr(x[x_i], output_image_batch[x_i])
            else:
                sample_psnr = psnr(x[x_i], output_image_batch[x_i] / 255.)
            average_psnr += sample_psnr
            total_psnr += sample_psnr

        average_psnr /= batch_size

        iteration += batch_size
        t2 = time.time()

        print(
            "Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
            % (t2 - t1, batch_size, average_psnr))

        for x_i in range(batch_size):
            # Save the validation images
            real_path = base_test_images + prefix + "_iteration_%d_num_%d_real_.png" % (
                iteration, x_i + 1)
            bicubic_path = base_test_images + prefix + "_iteration_%d_num_%d_bicubic.png" % (
                iteration, x_i + 1)
            generated_path = base_test_images + prefix + "_iteration_%d_num_%d_generated.png" % (
                iteration, x_i + 1)

            val_x = x[x_i].copy() * 255.
            val_x = np.clip(val_x, 0, 255).astype('uint8')

            if normalized: input_img = x_generator[x_i].copy() * 255
            else: input_img = x_generator[x_i].copy()
            input_img = np.clip(input_img, 0, 255).astype('uint8')

            if normalized: output_image = output_image_batch[x_i] * 255
            else: output_image = output_image_batch[x_i]
            output_image = np.clip(output_image, 0, 255).astype('uint8')

            imsave(real_path, val_x)
            imsave(bicubic_path, input_img)
            imsave(generated_path, output_image)

        if iteration >= nb_images:
            break
    return total_psnr
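For context, a hypothetical driver for this _test_loop, assuming a trained Keras model and a 4x scale factor; the names, paths, and sizes here are illustrative, not from the original project:

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rescale=1. / 255)
# `model` is assumed to be a trained Keras super-resolution model loaded elsewhere
total_psnr = _test_loop(path='val_images/', batch_size=16, datagen=datagen,
                        img_height=24, img_width=24, iteration=0,
                        large_img_height=96, large_img_width=96,
                        model=model, total_psnr=0.0, prefix='set5',
                        nb_images=160, normalized=True)
print('Mean PSNR : %0.2f' % (total_psnr / 160))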
Example #8
def test_individul(prefix='Set5', scale=4, mode='rgb', normalized=True):

    pic_path = os.path.join(set5_path[0:len(set5_path) - 6], prefix)
    pic_path = os.path.join(pic_path, prefix)

    total_psnr_bicubic = 0
    total_psnr_generated = 0
    total_ssim_bicubic = 0
    total_ssim_generated = 0

    for file in os.listdir(pic_path):

        real_rgb = imread(os.path.join(pic_path, file))
        # Preprocess the image data
        img_shape = real_rgb.shape
        img_width = int(img_shape[0] / scale) * scale
        img_height = int(img_shape[1] / scale) * scale
        real_rgb = real_rgb[0:img_width, 0:img_height]
        # A grayscale image may occur, in which case there are only 2 dimensions
        if len(img_shape) == 2:
            real_temp = np.empty((img_width, img_height, 3))
            for l in range(3):
                real_temp[:, :, l] = real_rgb
            real_rgb = real_temp.copy()

        # Downsample the RGB image, then apply bicubic interpolation to the low-resolution result
        lr_rgb = imresize(real_rgb, 1 / scale, interp='bicubic')
        bicubic_rgb = imresize(lr_rgb, (img_width, img_height),
                               interp='bicubic')

        if mode == 'ycrcb':
            real_ycbcr = rgb2ycbcr(real_rgb)
            bicubic_ycbcr = rgb2ycbcr(bicubic_rgb)
            lr_test = np.empty(
                (1, bicubic_ycbcr.shape[0], bicubic_ycbcr.shape[1], 1))
            lr_test[0, :, :, 0] = bicubic_ycbcr[:, :, 0]
        elif mode == 'rgb':
            lr_test = np.empty((1, bicubic_rgb.shape[0], bicubic_rgb.shape[1],
                                bicubic_rgb.shape[2]))
            lr_test[0, :, :, :] = bicubic_rgb

        # Load the trained SR model
        SRDenseNet_Test = SRModel(img_width=lr_test.shape[1],
                                  img_height=lr_test.shape[2],
                                  mode=mode)
        SRDenseNet_Test.build_model(load_weights=True)

        if normalized: lr_test /= 255

        sr_test = SRDenseNet_Test.model.predict(lr_test)

        if normalized: sr_test = sr_test * 255

        if mode == 'ycrcb':
            SR_ycrcb = np.empty((sr_test.shape[1], sr_test.shape[2], 3))
            SR_ycrcb[:, :, 0] = sr_test[0, :, :, 0]
            SR_ycrcb[:, :, 1] = bicubic_ycbcr[:, :, 1]
            SR_ycrcb[:, :, 2] = bicubic_ycbcr[:, :, 2]
            SR_rgb = ycbcr2rgb(SR_ycrcb)
        elif mode == 'rgb':
            SR_rgb = sr_test[0]
        SR_rgb = np.clip(SR_rgb, 0, 255)

        # Save the validation images
        real_img = np.clip(real_rgb, 0, 255).astype('uint8')
        bicubic_img = np.clip(bicubic_rgb, 0, 255).astype('uint8')
        output_image = np.clip(SR_rgb, 0, 255).astype('uint8')

        PSNR_bicubic = psnr(real_img.astype('float'),
                            bicubic_img.astype('float'))
        PSNR_generated = psnr(real_img.astype('float'),
                              output_image.astype('float'))

        real_ycbcr = rgb2ycbcr(real_img)
        bicubic_ycbcr = rgb2ycbcr(bicubic_img)
        output_ycbcr = rgb2ycbcr(output_image)
        SSIM_bicubic = compute_ssim(real_ycbcr[:, :, 0], bicubic_ycbcr[:, :, 0])
        SSIM_generated = compute_ssim(real_ycbcr[:, :, 0], output_ycbcr[:, :, 0])

        print(
            '%s_PSNR/SSIM:  bicubic    %0.2f/%0.2f ; SRDenseNet   %0.2f/%0.2f'
            % (file[0:len(file) - 4], PSNR_bicubic, SSIM_bicubic,
               PSNR_generated, SSIM_generated))

        if not os.path.exists(base_test_images + prefix):
            os.makedirs(base_test_images + prefix)
        real_path = base_test_images + prefix + '/' + prefix + "_%s_real.png" % (
            file[0:len(file) - 4])
        bicubic_img_path = base_test_images + prefix + '/' + prefix + "_%s_bicubic_%0.2f|%0.2f.png" % (
            file[0:len(file) - 4], PSNR_bicubic, SSIM_bicubic)
        generated_path = base_test_images + prefix + '/' + prefix + "_%s_generated_%0.2f|%0.2f.png" % (
            file[0:len(file) - 4], PSNR_generated, SSIM_generated)

        imsave(real_path, real_img)
        imsave(bicubic_img_path, bicubic_img)
        imsave(generated_path, output_image)

        total_psnr_bicubic += PSNR_bicubic
        total_psnr_generated += PSNR_generated
        total_ssim_bicubic += SSIM_bicubic
        total_ssim_generated += SSIM_generated

    l = len(os.listdir(pic_path))
    print(
        'Average_PSNR/SSIM:  bicubic    %0.2f/%0.2f ; VDSR_new   %0.2f/%0.2f' %
        (total_psnr_bicubic / l, total_ssim_bicubic / l,
         total_psnr_generated / l, total_ssim_generated / l))
Example #9
    def run_epoch(self, epoch, dataloader, logimage=False, isTrain=True):
        # For details see training.
        psnr_value = 0
        ssim_value = 0
        loss_value = 0
        if not isTrain:
            valid_images = []
        for index, all_data in enumerate(dataloader, 0):
            self.optimizer.zero_grad()
            (
                Ft_p,
                I0,
                IFrame,
                I1,
                g_I0_F_t_0,
                g_I1_F_t_1,
                FlowBackWarp_I0_F_1_0,
                FlowBackWarp_I1_F_0_1,
                F_1_0,
                F_0_1,
            ) = self.slomo(all_data, pred_only=False, isTrain=isTrain)
            if (not isTrain) and logimage:
                if index % self.args.logimagefreq == 0:
                    valid_images.append(
                        255.0
                        * I0.cpu()[0]
                        .resize_(1, 1, self.args.data_h, self.args.data_w)
                        .repeat(1, 3, 1, 1)
                    )
                    valid_images.append(
                        255.0
                        * IFrame.cpu()[0]
                        .resize_(1, 1, self.args.data_h, self.args.data_w)
                        .repeat(1, 3, 1, 1)
                    )
                    valid_images.append(
                        255.0
                        * I1.cpu()[0]
                        .resize_(1, 1, self.args.data_h, self.args.data_w)
                        .repeat(1, 3, 1, 1)
                    )
                    valid_images.append(
                        255.0
                        * Ft_p.cpu()[0]
                        .resize_(1, 1, self.args.data_h, self.args.data_w)
                        .repeat(1, 3, 1, 1)
                    )
            # loss
            loss = self.supervisedloss(
                Ft_p,
                IFrame,
                I0,
                I1,
                g_I0_F_t_0,
                g_I1_F_t_1,
                FlowBackWarp_I0_F_1_0,
                FlowBackWarp_I1_F_0_1,
                F_1_0,
                F_0_1,
            )
            if isTrain:
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

            loss_value += loss.item()

            # metrics
            psnr_value += psnr(Ft_p, IFrame, outputTensor=False)
            ssim_value += ssim(Ft_p, IFrame, outputTensor=False)

        name_loss = "TrainLoss" if isTrain else "ValLoss"
        itr = int(index + epoch * (len(dataloader)))
        if self.comet_exp is not None:
            self.comet_exp.log_metric(
                "PSNR", psnr_value / len(dataloader), step=itr, epoch=epoch
            )
            self.comet_exp.log_metric(
                "SSIM", ssim_value / len(dataloader), step=itr, epoch=epoch
            )
            self.comet_exp.log_metric(
                name_loss, loss_value / len(dataloader), step=itr, epoch=epoch
            )
            if logimage:
                upload_images(
                    valid_images,
                    epoch,
                    exp=self.comet_exp,
                    im_per_row=4,
                    rows_per_log=int(len(valid_images) / 4),
                )
        print(
            " Loss: %0.6f  Iterations: %4d/%4d  ValPSNR: %0.4f  ValSSIM: %0.4f "
            % (
                loss_value / len(dataloader),
                index,
                len(dataloader),
                psnr_value / len(dataloader),
                ssim_value / len(dataloader),
            )
        )
        return (
            (psnr_value / len(dataloader)),
            (ssim_value / len(dataloader)),
            (loss_value / len(dataloader)),
        )
Example #10
    def _train_model(self,
                     image_dir,
                     nb_images=50000,
                     nb_epochs=20,
                     pre_train=False,
                     load_generative_weights=False,
                     load_discriminator_weights=False,
                     save_loss=True):

        assert self.img_width >= 16, "Minimum image width must be at least 16"
        assert self.img_height >= 16, "Minimum image height must be at least 16"

        if not pre_train:
            if load_generative_weights:
                self.generative_model_.load_weights(
                    self.generative_network.sr_weights_path)

            if load_discriminator_weights:
                self.discriminative_network.load_gan_weights(self.srgan_model_)

        datagen = ImageDataGenerator(rescale=1. / 255)
        img_width = self.img_width * 4
        img_height = self.img_height * 4

        early_stop = False
        iteration = 0
        prev_improvement = -1

        if save_loss:
            if pre_train:
                loss_history = {
                    'generator_loss': [],
                    'val_psnr': [],
                }
            else:
                loss_history = {
                    'discriminator_loss': [],
                    'generator_loss': [],
                    'val_psnr': [],
                }

        y_vgg_dummy = np.zeros((self.batch_size * 2, 3, img_width // 32,
                                img_height // 32))  # 5 Max Pools = 2 ** 5 = 32

        if not pre_train:
            y_gan = [0] * self.batch_size + [1] * self.batch_size
            y_gan = np.asarray(y_gan, dtype=np.float32).reshape(-1, 1)

        print("Training SRGAN network")
        for i in range(nb_epochs):
            print()
            print("Epoch : %d" % (i + 1))

            for x in datagen.flow_from_directory(image_dir,
                                                 class_mode=None,
                                                 batch_size=self.batch_size,
                                                 target_size=(img_width,
                                                              img_height)):
                try:
                    t1 = time.time()

                    if not pre_train:
                        x_vgg = x.copy() * 255  # VGG input [0 - 255 scale]

                    # resize images
                    x_temp = x.copy()
                    x_temp = x_temp.transpose((0, 2, 3, 1))

                    x_generator = np.empty(
                        (self.batch_size, self.img_width, self.img_height, 3))

                    for j in range(self.batch_size):
                        img = gaussian_filter(x_temp[j], sigma=0.5)
                        img = imresize(img, (self.img_width, self.img_height))
                        x_generator[j, :, :, :] = img

                    x_generator = x_generator.transpose((0, 3, 1, 2))

                    if iteration % 50 == 0 and iteration != 0:
                        print("Validation image..")
                        output_image_batch = self.generative_network.get_generator_output(
                            x_generator, self.srgan_model_)
                        #output_image_batch = output_image_batch[0]

                        average_psnr = 0.0
                        for x_i in range(self.batch_size):
                            average_psnr += psnr(
                                x[x_i],
                                np.clip(output_image_batch[x_i], 0, 255) /
                                255.)

                        average_psnr /= self.batch_size

                        if save_loss:
                            loss_history['val_psnr'].append(average_psnr)

                        iteration += self.batch_size
                        t2 = time.time()

                        print(
                            "Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
                            % (t2 - t1, self.batch_size, average_psnr))

                        for x_i in range(self.batch_size):
                            real_path = "val_images/epoch_%d_iteration_%d_num_%d_real_.png" % (
                                i + 1, iteration, x_i + 1)
                            generated_path = "val_images/epoch_%d_iteration_%d_num_%d_generated.png" % (
                                i + 1, iteration, x_i + 1)

                            val_x = x[x_i].copy() * 255.
                            val_x = val_x.transpose((1, 2, 0))
                            val_x = np.clip(val_x, 0, 255).astype('uint8')

                            # print('min = ', np.min(output_image_batch[x_i]))
                            # print('max = ', np.max(output_image_batch[x_i]))
                            # print('mean = ', np.mean(output_image_batch[x_i]))

                            output_image = output_image_batch[x_i]
                            output_image = output_image.transpose((1, 2, 0))
                            output_image = np.clip(output_image, 0,
                                                   255).astype('uint8')

                            imsave(real_path, val_x)
                            imsave(generated_path, output_image)
                        '''
                        Don't train on validation images for now.

                        Note that if nb_epochs > 1, there is a chance that
                        validation images may be used for training purposes as well.

                        In that case, this isn't strictly a validation measure, but rather
                        just a check to see what the network has learned.
                        '''
                        continue

                    if pre_train:
                        # Train only generator + vgg network

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        hist = bypass_fit(self.srgan_model_,
                                          [x_generator, x * 255],
                                          y_vgg_dummy,
                                          batch_size=self.batch_size,
                                          nb_epoch=1,
                                          verbose=0)
                        sr_loss = hist.history['loss'][0]

                        if save_loss:
                            loss_history['generator_loss'].append(sr_loss)

                        if prev_improvement == -1:
                            prev_improvement = sr_loss

                        improvement = (prev_improvement -
                                       sr_loss) / prev_improvement * 100
                        prev_improvement = sr_loss

                        iteration += self.batch_size
                        t2 = time.time()

                        print(
                            "Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                            "Generative Loss : %0.3f" %
                            (iteration, nb_images, improvement, t2 - t1,
                             sr_loss))

                    else:

                        # Train only discriminator, disable training of srgan
                        self.discriminative_network.set_trainable(
                            self.srgan_model_, value=True)
                        self.generative_network.set_trainable(
                            self.srgan_model_, value=False)

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        hist = bypass_fit(self.srgan_model_,
                                          [x_generator, x * 255, x_vgg],
                                          [y_gan, y_vgg_dummy],
                                          batch_size=self.batch_size,
                                          nb_epoch=1,
                                          verbose=0)

                        discriminator_loss = hist.history['loss'][0]

                        # Train only generator, disable training of discriminator
                        self.discriminative_network.set_trainable(
                            self.srgan_model_, value=False)
                        self.generative_network.set_trainable(
                            self.srgan_model_, value=True)

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        hist = bypass_fit(self.srgan_model_,
                                          [x_generator, x * 255, x_vgg],
                                          [y_gan, y_vgg_dummy],
                                          batch_size=self.batch_size,
                                          nb_epoch=1,
                                          verbose=0)

                        generative_loss = hist.history['loss'][0]

                        if save_loss:
                            loss_history['discriminator_loss'].append(
                                discriminator_loss)
                            loss_history['generator_loss'].append(
                                generative_loss)

                        if prev_improvement == -1:
                            prev_improvement = discriminator_loss

                        improvement = (prev_improvement - discriminator_loss
                                       ) / prev_improvement * 100
                        prev_improvement = discriminator_loss

                        iteration += self.batch_size
                        t2 = time.time()
                        print(
                            "Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                            "Discriminator Loss : %0.3f | Generative Loss : %0.3f"
                            % (iteration, nb_images, improvement, t2 - t1,
                               discriminator_loss, generative_loss))

                    if iteration % 1000 == 0 and iteration != 0:
                        print("Saving model weights.")
                        # Save predictive (SR network) weights
                        self.generative_model_.save_weights(
                            self.generative_network.sr_weights_path,
                            overwrite=True)

                        if not pre_train:
                            # Save GAN (discriminative network) weights
                            self.discriminative_network.save_gan_weights(
                                self.srgan_model_)

                        if save_loss:
                            print("Saving loss history")

                            if pre_train:
                                with open('pretrain losses.json', 'w') as f:
                                    json.dump(loss_history, f)

                            else:
                                with open('fulltrain losses.json', 'w') as f:
                                    json.dump(loss_history, f)

                            print("Saved loss history")

                    if iteration >= nb_images:
                        break

                except KeyboardInterrupt:
                    print("Keyboard interrupt detected. Stopping early.")
                    early_stop = True
                    break

            iteration = 0

            if early_stop:
                break

        print("Finished training SRGAN network. Saving model weights.")

        # Save predictive (SR network) weights
        self.generative_model_.save_weights(
            self.generative_network.sr_weights_path)

        if not pre_train:
            # Save GAN (discriminative network) weights
            self.discriminative_network.save_gan_weights(self.srgan_model_)

        print("Weights saved in 'weights' directory")

        if save_loss:
            print("Saving loss history")

            if pre_train:
                with open('pretrain losses.json', 'w') as f:
                    json.dump(loss_history, f)

            else:
                with open('fulltrain losses.json', 'w') as f:
                    json.dump(loss_history, f)

            print("Saved loss history")
Example #11
    def _train_model(self, image_dir, nb_images=80000, nb_epochs=10, pre_train_srgan=False,
                     pre_train_discriminator=False, load_generative_weights=False, load_discriminator_weights=False,
                     save_loss=True, disc_train_flip=0.1):

        assert self.img_width >= 16, "Minimum image width must be at least 16"
        assert self.img_height >= 16, "Minimum image height must be at least 16"

        if load_generative_weights:
            try:
                self.generative_model_.load_weights(self.generative_network.sr_weights_path)
                print("Generator weights loaded.")
            except:
                print("Could not load generator weights.")

        if load_discriminator_weights:
            try:
                self.discriminative_network.load_gan_weights(self.srgan_model_)
                print("Discriminator weights loaded.")
            except:
                print("Could not load discriminator weights.")

        datagen = ImageDataGenerator(rescale=1. / 255)
        img_width = self.img_width * 4
        img_height = self.img_height * 4

        early_stop = False
        iteration = 0
        prev_improvement = -1

        if save_loss:
            if pre_train_srgan:
                loss_history = {'generator_loss' : [],
                                'val_psnr' : [], }
            elif pre_train_discriminator:
                loss_history = {'discriminator_loss' : [],
                                'discriminator_acc' : [], }
            else:
                loss_history = {'discriminator_loss' : [],
                                'discriminator_acc' : [],
                                'generator_loss' : [],
                                'val_psnr': [], }

        y_vgg_dummy = np.zeros((self.batch_size * 2, 3, img_width // 32, img_height // 32)) # 5 Max Pools = 2 ** 5 = 32

        print("Training SRGAN network")
        for i in range(nb_epochs):
            print()
            print("Epoch : %d" % (i + 1))

            for x in datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size,
                                                 target_size=(img_width, img_height)):
                try:
                    t1 = time.time()

                    if not pre_train_srgan and not pre_train_discriminator:
                        x_vgg = x.copy() * 255 # VGG input [0 - 255 scale]

                    # resize images
                    x_temp = x.copy()
                    x_temp = x_temp.transpose((0, 2, 3, 1))

                    x_generator = np.empty((self.batch_size, self.img_width, self.img_height, 3))

                    for j in range(self.batch_size):
                        img = gaussian_filter(x_temp[j], sigma=0.1)
                        img = imresize(img, (self.img_width, self.img_height), interp='bicubic')
                        x_generator[j, :, :, :] = img

                    x_generator = x_generator.transpose((0, 3, 1, 2))

                    if iteration % 50 == 0 and iteration != 0 and not pre_train_discriminator:
                        print("Validation image..")
                        output_image_batch = self.generative_network.get_generator_output(x_generator,
                                                                                          self.srgan_model_)
                        if type(output_image_batch) == list:
                            output_image_batch = output_image_batch[0]

                        mean_axis = (0, 2, 3) if K.image_dim_ordering() == 'th' else (0, 1, 2)

                        average_psnr = 0.0

                        print('gen img mean :', np.mean(output_image_batch / 255., axis=mean_axis))
                        print('val img mean :', np.mean(x, axis=mean_axis))

                        for x_i in range(self.batch_size):
                            average_psnr += psnr(x[x_i], np.clip(output_image_batch[x_i], 0, 255) / 255.)

                        average_psnr /= self.batch_size

                        if save_loss:
                            loss_history['val_psnr'].append(average_psnr)

                        iteration += self.batch_size
                        t2 = time.time()

                        print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
                              (t2 - t1, self.batch_size, average_psnr))

                        for x_i in range(self.batch_size):
                            real_path = "val_images/epoch_%d_iteration_%d_num_%d_real_.png" % (i + 1, iteration, x_i + 1)
                            generated_path = "val_images/epoch_%d_iteration_%d_num_%d_generated.png" % (i + 1,
                                                                                                        iteration,
                                                                                                        x_i + 1)

                            val_x = x[x_i].copy() * 255.
                            val_x = val_x.transpose((1, 2, 0))
                            val_x = np.clip(val_x, 0, 255).astype('uint8')

                            output_image = output_image_batch[x_i]
                            output_image = output_image.transpose((1, 2, 0))
                            output_image = np.clip(output_image, 0, 255).astype('uint8')

                            imsave(real_path, val_x)
                            imsave(generated_path, output_image)

                        '''
                        Don't train on validation images for now.

                        Note that if nb_epochs > 1, there is a chance that
                        validation images may be used for training purposes as well.

                        In that case, this isn't strictly a validation measure, but rather
                        just a check to see what the network has learned.
                        '''
                        continue

                    if pre_train_srgan:
                        # Train only generator + vgg network

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        hist = bypass_fit(self.srgan_model_, [x_generator, x * 255], y_vgg_dummy,
                                                     batch_size=self.batch_size, nb_epoch=1, verbose=0)
                        sr_loss = hist.history['loss'][0]

                        if save_loss:
                            loss_history['generator_loss'].extend(hist.history['loss'])

                        if prev_improvement == -1:
                            prev_improvement = sr_loss

                        improvement = (prev_improvement - sr_loss) / prev_improvement * 100
                        prev_improvement = sr_loss

                        iteration += self.batch_size
                        t2 = time.time()

                        print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                              "Generative Loss : %0.2f" % (iteration, nb_images, improvement, t2 - t1, sr_loss))
                    elif pre_train_discriminator:
                        # Train only discriminator
                        X_pred = self.generative_model_.predict(x_generator, self.batch_size)

                        X = np.concatenate((X_pred, x * 255))

                        # Using soft and noisy labels
                        if np.random.uniform() > disc_train_flip:
                            # give correct classifications
                            y_gan = [0] * self.batch_size + [1] * self.batch_size
                        else:
                            # give wrong classifications (noisy labels)
                            y_gan = [1] * self.batch_size + [0] * self.batch_size

                        y_gan = np.asarray(y_gan, dtype=np.int).reshape(-1, 1)
                        y_gan = to_categorical(y_gan, nb_classes=2)
                        y_gan = smooth_gan_labels(y_gan)

                        hist = self.discriminative_model_.fit(X, y_gan, batch_size=self.batch_size,
                                                              nb_epoch=1, verbose=0)

                        discriminator_loss = hist.history['loss'][-1]
                        discriminator_acc = hist.history['acc'][-1]

                        if save_loss:
                            loss_history['discriminator_loss'].extend(hist.history['loss'])
                            loss_history['discriminator_acc'].extend(hist.history['acc'])

                        if prev_improvement == -1:
                            prev_improvement = discriminator_loss

                        improvement = (prev_improvement - discriminator_loss) / prev_improvement * 100
                        prev_improvement = discriminator_loss

                        iteration += self.batch_size
                        t2 = time.time()

                        print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                            "Discriminator Loss / Acc : %0.4f / %0.2f" % (iteration, nb_images,
                                                            improvement, t2 - t1,
                                                            discriminator_loss, discriminator_acc))

                    else:
                        # Train only the discriminator, disable training of the generator
                        self.discriminative_network.set_trainable(self.srgan_model_, value=True)
                        self.generative_network.set_trainable(self.srgan_model_, value=False)
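                        # Both networks live inside the combined srgan_model_, so these
                        # trainable flags select which half of the model learns this step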

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        # hist = bypass_fit(self.srgan_model_, [x_generator, x * 255, x_vgg],
                        #                          [y_gan, y_vgg_dummy],
                        #                          batch_size=self.batch_size, nb_epoch=1, verbose=0)

                        X_pred = self.generative_model_.predict(x_generator, self.batch_size)

                        X = np.concatenate((X_pred, x * 255))

                        # Using soft and noisy labels
                        if np.random.uniform() > disc_train_flip:
                            # give correct classifications
                            y_gan = [0] * self.batch_size + [1] * self.batch_size
                        else:
                            # give wrong classifications (noisy labels)
                            y_gan = [1] * self.batch_size + [0] * self.batch_size

                        y_gan = np.asarray(y_gan, dtype=np.int).reshape(-1, 1)
                        y_gan = to_categorical(y_gan, nb_classes=2)
                        y_gan = smooth_gan_labels(y_gan)

                        hist1 = self.discriminative_model_.fit(X, y_gan, batch_size=self.batch_size,
                                                               nb_epoch=1, verbose=0)

                        discriminator_loss = hist1.history['loss'][-1]

                        # Train only generator, disable training of discriminator
                        self.discriminative_network.set_trainable(self.srgan_model_, value=False)
                        self.generative_network.set_trainable(self.srgan_model_, value=True)

                        # Soft "real" labels: the generator is trained to make the
                        # discriminator classify its outputs as real (class 1)
                        y_model = [1] * self.batch_size
                        y_model = np.asarray(y_model, dtype=np.int).reshape(-1, 1)
                        y_model = to_categorical(y_model, nb_classes=2)
                        y_model = smooth_gan_labels(y_model)

                        # Use custom bypass_fit to bypass the check for same input and output batch size
                        hist2 = bypass_fit(self.srgan_model_, [x_generator, x, x_vgg], [y_model, y_vgg_dummy],
                                           batch_size=self.batch_size, nb_epoch=1, verbose=0)

                        generative_loss = hist2.history['loss'][0]

                        if save_loss:
                            loss_history['discriminator_loss'].extend(hist1.history['loss'])
                            loss_history['discriminator_acc'].extend(hist1.history['acc'])
                            loss_history['generator_loss'].extend(hist2.history['loss'])

                        if prev_improvement == -1:
                            prev_improvement = discriminator_loss

                        improvement = (prev_improvement - discriminator_loss) / prev_improvement * 100
                        prev_improvement = discriminator_loss

                        iteration += self.batch_size
                        t2 = time.time()
                        print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                              "Discriminator Loss : %0.3f | Generative Loss : %0.3f" %
                              (iteration, nb_images, improvement, t2 - t1, discriminator_loss, generative_loss))

                    if iteration % 1000 == 0 and iteration != 0:
                        print("Saving model weights.")
                        # Save predictive (SR network) weights
                        self._save_model_weights(pre_train_srgan, pre_train_discriminator)
                        self._save_loss_history(loss_history, pre_train_srgan, pre_train_discriminator, save_loss)

                    if iteration >= nb_images:
                        break

                except KeyboardInterrupt:
                    print("Keyboard interrupt detected. Stopping early.")
                    early_stop = True
                    break

            iteration = 0

            if early_stop:
                break

        print("Finished training SRGAN network. Saving model weights.")
        # Save predictive (SR network) weights
        self._save_model_weights(pre_train_srgan, pre_train_discriminator)
        self._save_loss_history(loss_history, pre_train_srgan, pre_train_discriminator, save_loss)
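smooth_gan_labels is called throughout the snippet but not defined in it. A minimal sketch consistent with how it is used (one-hot targets of shape (n, 2) in, softened targets out) follows; the exact smoothing ranges are an assumption, not the author's values:

import numpy as np

def smooth_gan_labels(y):
    # y: one-hot GAN targets of shape (n, 2); column 1 marks real, column 0 fake.
    # Replacing hard 0/1 targets with values near them is a common GAN
    # stabilization trick; the 0.7-1.2 and 0.0-0.3 ranges here are assumptions.
    y = np.asarray(y)
    smoothed = np.zeros(y.shape, dtype='float32')
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            if y[i, j] == 1:
                smoothed[i, j] = np.random.uniform(0.7, 1.2)
            else:
                smoothed[i, j] = np.random.uniform(0.0, 0.3)
    return smoothed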
Example #12
0
    def train_model(self, image_dir, nb_images=50000, nb_epochs=1):
        datagen = ImageDataGenerator(rescale=1. / 255)
        img_width = self.img_width * 4
        img_height = self.img_height * 4

        early_stop = False
        iteration = 0
        prev_improvement = -1

        print("Training SR ResNet network")
        for i in range(nb_epochs):
            print()
            print("Epoch : %d" % (i + 1))

            for x in datagen.flow_from_directory(image_dir,
                                                 class_mode=None,
                                                 batch_size=self.batch_size,
                                                 target_size=(img_width,
                                                              img_height)):

                try:
                    t1 = time.time()

                    # resize images
                    x_temp = x.copy()
                    # x_temp = x_temp.transpose((0, 2, 3, 1))

                    x_generator = np.empty(
                        (self.batch_size, self.img_width, self.img_height, 3))

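                    # Degrade each HR patch with a light Gaussian blur before downsampling;
                    # blurring first suppresses aliasing in the low-resolution input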
                    for j in range(self.batch_size):
                        img = gaussian_filter(x_temp[j], sigma=0.5)
                        img = imresize(img, (self.img_width, self.img_height))
                        x_generator[j, :, :, :] = img

                    # x_generator = x_generator.transpose((0, 3, 1, 2))

                    if iteration % 50 == 0 and iteration != 0:
                        print("Random Validation image..")
                        output_image_batch = self.model.predict_on_batch(
                            x_generator)

                        print("Pred Max / Min: %0.2f / %0.2f" %
                              (output_image_batch.max(),
                               output_image_batch.min()))

                        average_psnr = 0.0
                        for x_i in range(self.batch_size):
                            average_psnr += psnr(
                                x[x_i], output_image_batch[x_i] / 255.)

                        average_psnr /= self.batch_size

                        iteration += self.batch_size
                        t2 = time.time()

                        print(
                            "Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
                            % (t2 - t1, self.batch_size, average_psnr))

                        for x_i in range(self.batch_size):
                            real_path = base_val_images_path + \
                                "epoch_%d_iteration_%d_num_%d_real_.png" % \
                                (i + 1, iteration, x_i + 1)

                            generated_path = base_val_images_path + \
                                "epoch_%d_iteration_%d_num_%d_generated.png" % \
                                (i + 1, iteration, x_i + 1)

                            val_x = x[x_i].copy() * 255.
                            #val_x = val_x.transpose((1, 2, 0))
                            val_x = np.clip(val_x, 0, 255).astype('uint8')

                            output_image = output_image_batch[x_i]
                            #output_image = output_image.transpose((1, 2, 0))
                            output_image = np.clip(output_image, 0,
                                                   255).astype('uint8')

                            imsave(real_path, val_x)
                            imsave(generated_path, output_image)

                        '''
                        Don't train on validation images for now.

                        Note that if nb_epochs > 1, there is a chance that
                        validation images may be used for training purposes as well.

                        In that case, this isn't strictly a validation measure, but rather
                        just a check to see what the network has learned.
                        '''
                        continue

                    hist = self.model.fit(x_generator,
                                          x * 255,
                                          batch_size=self.batch_size,
                                          epochs=1,
                                          verbose=0)
                    psnr_loss_val = hist.history['PSNRLoss'][0]

                    if prev_improvement == -1:
                        prev_improvement = psnr_loss_val

                    improvement = (prev_improvement -
                                   psnr_loss_val) / prev_improvement * 100
                    prev_improvement = psnr_loss_val

                    iteration += self.batch_size
                    t2 = time.time()

                    print(
                        "Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                        "PSNR : %0.3f" % (iteration, nb_images, improvement,
                                          t2 - t1, psnr_loss_val))

                    if iteration % 1000 == 0 and iteration != 0:
                        print("Saving weights")
                        self.model.save_weights(self.weights_path,
                                                overwrite=True)

                    if iteration >= nb_images:
                        break

                except KeyboardInterrupt:
                    print("Keyboard interrupt detected. Stopping early.")
                    early_stop = True
                    break

            iteration = 0

            if early_stop:
                break

        print("Finished training SRGAN network. Saving model weights.")
    def train_model(self, image_dir, nb_images=50000, nb_epochs=1):
        datagen = ImageDataGenerator(rescale=1. / 255)
        img_width = self.img_width * 4
        img_height = self.img_height * 4

        early_stop = False
        iteration = 0
        prev_improvement = -1

        print("Training SR ResNet network")
        for i in range(nb_epochs):
            print()
            print("Epoch : %d" % (i + 1))

            for x in datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size,
                                                 target_size=(img_width, img_height)):

                try:
                    t1 = time.time()

                    # resize images
                    x_temp = x.copy()
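                    # flow_from_directory yields channels-first batches in this variant;
                    # move channels last so imresize can work on (H, W, 3) images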
                    x_temp = x_temp.transpose((0, 2, 3, 1))

                    x_generator = np.empty((self.batch_size, self.img_width, self.img_height, 3))

                    for j in range(self.batch_size):
                        img = gaussian_filter(x_temp[j], sigma=0.5)
                        img = imresize(img, (self.img_width, self.img_height))
                        x_generator[j, :, :, :] = img

                    x_generator = x_generator.transpose((0, 3, 1, 2))

                    if iteration % 50 == 0 and iteration != 0:
                        print("Random Validation image..")
                        output_image_batch = self.model.predict_on_batch(x_generator)

                        print("Pred Max / Min: %0.2f / %0.2f" % (output_image_batch.max(),
                                                                 output_image_batch.min()))

                        average_psnr = 0.0
                        for x_i in range(self.batch_size):
                            average_psnr += psnr(x[x_i], output_image_batch[x_i] / 255.)

                        average_psnr /= self.batch_size

                        iteration += self.batch_size
                        t2 = time.time()

                        print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
                              (t2 - t1, self.batch_size, average_psnr))

                        for x_i in range(self.batch_size):
                            real_path = base_val_images_path + \
                                        "epoch_%d_iteration_%d_num_%d_real_.png" % (i + 1, iteration, x_i + 1)

                            generated_path = base_val_images_path + \
                                             "epoch_%d_iteration_%d_num_%d_generated.png" % (i + 1, iteration, x_i + 1)

                            val_x = x[x_i].copy() * 255.
                            val_x = val_x.transpose((1, 2, 0))
                            val_x = np.clip(val_x, 0, 255).astype('uint8')

                            output_image = output_image_batch[x_i]
                            output_image = output_image.transpose((1, 2, 0))
                            output_image = np.clip(output_image, 0, 255).astype('uint8')

                            imsave(real_path, val_x)
                            imsave(generated_path, output_image)

                        '''
                        Don't train on validation images for now.

                        Note that if nb_epochs > 1, there is a chance that
                        validation images may be used for training purposes as well.

                        In that case, this isn't strictly a validation measure, but rather
                        just a check to see what the network has learned.
                        '''
                        continue

                    hist = self.model.fit(x_generator, x * 255, batch_size=self.batch_size, nb_epoch=1, verbose=0)
                    psnr_loss_val = hist.history['PSNRLoss'][0]

                    if prev_improvement == -1:
                        prev_improvement = psnr_loss_val

                    improvement = (prev_improvement - psnr_loss_val) / prev_improvement * 100
                    prev_improvement = psnr_loss_val

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "PSNR : %0.3f" % (iteration, nb_images, improvement, t2 - t1, psnr_loss_val))

                    if iteration % 1000 == 0 and iteration != 0:
                        print("Saving weights")
                        self.model.save_weights(self.weights_path, overwrite=True)

                    if iteration >= nb_images:
                        break

                except KeyboardInterrupt:
                    print("Keyboard interrupt detected. Stopping early.")
                    early_stop = True
                    break

            iteration = 0

            if early_stop:
                break

        print("Finished training SRGAN network. Saving model weights.")