Example #1
def init_net_module():
    """
    Initialize the network module.
    :return: the generator
    """
    # Directory where the model weights are saved
    model_save_dir = './model_save_res_block_256_vgg16'

    def load_saved_weight(g, d=None):
        """
        Load the previously trained weights.
        :param g: generator
        :param d: discriminator
        :return:
        """
        # TODO: refine this: check that the file exists and, if several
        #  weight files are present, pick the most recent one
        g.load_weights(os.path.join(model_save_dir, 'generator_49_33.h5'))
        if d is None:
            return
        d.load_weights(os.path.join(model_save_dir, 'discriminator_49.h5'))

    # Build the network model
    global g
    g = generator_model()
    # Load the model weights
    load_saved_weight(g)
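
The TODO in load_saved_weight asks for picking the newest of several weight files automatically. A minimal sketch, assuming the weight files follow the generator_*.h5 naming used above (find_latest_weight is a hypothetical helper, not part of the original code):

import glob
import os

def find_latest_weight(model_save_dir, prefix='generator'):
    """Return the most recently modified '<prefix>_*.h5' file, or None."""
    candidates = glob.glob(os.path.join(model_save_dir, '%s_*.h5' % prefix))
    if not candidates:
        return None
    # Pick the file with the newest modification time
    return max(candidates, key=os.path.getmtime)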
Example #2
def train(batch_size, epochs, critic_updates=5):
    """
    Train the network.
    :param batch_size:
    :param epochs:
    :param critic_updates: number of times the discriminator is trained per batch
    :return:
    """
    # Load the data
    data_loader = DataLoader(batch_size)

    # Build the network models
    g = generator_model()
    # g.summary()
    d = discriminator_model()
    d.summary()
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    # Save the model architectures (for visualization)
    g.save(os.path.join(model_save_dir, "generator.h5"))
    d.save(os.path.join(model_save_dir, "discriminator.h5"))
    d_on_g.save(os.path.join(model_save_dir, "d_on_g.h5"))

    # Compile the models
    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    # Freeze the discriminator inside the combined model so that only the
    # generator is updated when d_on_g is trained
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    # Discriminator targets: +1 for real samples, -1 for generated ones
    output_true_batch, output_false_batch = np.ones((batch_size, 1)), -np.ones(
        (batch_size, 1))
    # tensorboard_callback = TensorBoard(log_dir)

    # TODO: weight restoring could be added here to resume training from a checkpoint

    # Train
    start = datetime.datetime.now()
    for epoch in tqdm.tqdm(range(epochs)):
        d_losses = []
        d_on_g_losses = []
        for index in range(data_loader.file_nums // batch_size):
            img_haze_batch, img_clear_batch = next(data_loader.train_generator)
            # Rescale to [-1, 1]
            img_haze_batch = img_haze_batch / 127.5 - 1
            img_clear_batch = img_clear_batch / 127.5 - 1

            generated_images = g.predict(x=img_haze_batch,
                                         batch_size=batch_size)

            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(img_clear_batch,
                                               output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images,
                                               output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)

            d.trainable = False

            d_on_g_loss = d_on_g.train_on_batch(
                img_haze_batch, [img_clear_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)

            d.trainable = True

            # Print the batch losses
            print('d loss %f d_on_g loss %f' %
                  (d_loss, d_on_g_loss[1] + d_on_g_loss[2]))

            if index % 50 == 0:
                # Test
                img_haze_test, img_clear_test = next(
                    data_loader.test_generator)
                generated_images = g.predict(x=img_haze_test / 127.5 - 1,
                                             batch_size=batch_size)
                # Rescale back to [0, 255]
                generated_images = (generated_images + 1) * 127.5

                fig, axs = plt.subplots(batch_size, 3)
                for idx in range(batch_size):
                    axs[idx, 0].imshow((img_haze_test[idx].astype('uint8')))
                    axs[idx, 0].axis('off')
                    axs[idx, 0].set_title('haze')

                    axs[idx, 1].imshow((img_clear_test[idx].astype('uint8')))
                    axs[idx, 1].axis('off')
                    axs[idx, 1].set_title('origin')

                    axs[idx, 2].imshow(generated_images[idx].astype('uint8'))
                    axs[idx, 2].axis('off')
                    axs[idx, 2].set_title('dehazed')
                fig.savefig("./dehazed_result/image/dehazed/%d-%d.jpg" %
                            (epoch, index))
                # Close the figure to avoid accumulating open figures in memory
                plt.close(fig)

        now = datetime.datetime.now()
        print(np.mean(d_losses), np.mean(d_on_g_losses),
              'spend time %s' % (now - start))
        # Save all weights
        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
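
The helpers wasserstein_loss and perceptual_loss are referenced but not defined in this excerpt. A minimal sketch of the common definitions, following DeblurGAN (which, judging by the DeBulrGanToDeHaze paths in Example #4, this code adapts); the 256x256 VGG16 input shape is an assumption taken from the model_save_res_block_256_vgg16 directory name:

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

def wasserstein_loss(y_true, y_pred):
    # WGAN critic loss: labels are +1 (real) / -1 (fake)
    return K.mean(y_true * y_pred)

def perceptual_loss(y_true, y_pred):
    # Content loss in VGG16 feature space (input shape is an assumption)
    vgg = VGG16(include_top=False, weights='imagenet',
                input_shape=(256, 256, 3))
    loss_model = Model(inputs=vgg.input,
                       outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))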
Example #3
def test():
    """
    Test function: computes the evaluation metrics.
    :return:
    """
    # Build the network model
    g = generator_model('test')
    # Load the model weights
    load_saved_weight(g)

    ##########################################
    # New test-set code: read directly from the .jpg files to avoid the
    # npy conversion step.
    #  case 1: dehaze synthetic haze images, save the results, compute PSNR/SSIM
    #  case 2: dehaze real haze images, save the results
    ##########################################
    def load_img_files(dir):
        """
        Load all files with a .jpg suffix under dir.
        :param dir:
        :return: numpy array of images
        """
        file_paths = glob.glob(os.path.join(dir, '*.jpg'))

        imgs = []
        for idx, file_path in enumerate(file_paths):
            imgs.append(np.array(Image.open(file_path).convert('RGB')))
        return np.array(imgs)

    def predict(g, haze_imgs):
        """
        Feed haze_imgs through g to predict clear_imgs.
        This wrapper is used instead of calling g.predict directly so that
        the images in haze_imgs may have different sizes.
        :param g:
        :param haze_imgs: hazy images with values in [0, 255]
        :return: clear_imgs (each may have a different shape) with values in [0, 255]
        """
        clear_imgs = []
        for haze_img in haze_imgs:
            haze_img = np.expand_dims(haze_img, axis=0)
            clear_img = g.predict(haze_img / 127.5 - 1)[0]
            clear_imgs.append((clear_img + 1) * 127.5)
        return np.array(clear_imgs)

    mode = "real"  # synthesis or real
    # Directory of the clear (ground-truth) images
    clear_imgs_dir = ''
    # Directory of the hazy images
    haze_imgs_dir = '../test_imgs'
    # Directory where the dehazed results are saved
    dehaze_imgs_dir = '../test_imgs'
    if mode == "synthesis":
        clear_imgs = load_img_files(clear_imgs_dir)
        haze_imgs = load_img_files(haze_imgs_dir)

        # Dehaze
        generated_imgs = predict(g, haze_imgs)

        # Initialize the metrics
        PSNR = 0
        SSIM = 0

        for idx, generated_img in enumerate(generated_imgs):
            dehazed_img = Image.fromarray(generated_img.astype('uint8'))
            dehazed_img.save(
                os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))
            PSNR = PSNR + compare_psnr(clear_imgs[idx].astype('uint8'),
                                       generated_img.astype('uint8'))
            SSIM = SSIM + ssim(clear_imgs[idx].astype('uint8'),
                               generated_img.astype('uint8'),
                               multichannel=True)
        # Average over the test set
        PSNR = PSNR / len(generated_imgs)
        SSIM = SSIM / len(generated_imgs)
        print('PSNR', PSNR)
        print('SSIM', SSIM)
    elif mode == 'real':
        haze_imgs = load_img_files(haze_imgs_dir)
        # Dehaze
        generated_imgs = predict(g, haze_imgs)

        for idx, generated_img in enumerate(generated_imgs):
            dehazed_img = Image.fromarray(generated_img.astype('uint8'))
            dehazed_img.save(
                os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))
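
compare_psnr and ssim are assumed to come from scikit-image; the imports are not shown in this excerpt. In the scikit-image versions contemporary with this code they lived in skimage.measure (an assumption about this project's imports):

from skimage.measure import compare_psnr
from skimage.measure import compare_ssim as ssim

In scikit-image 0.18 and later these were renamed to peak_signal_noise_ratio and structural_similarity under skimage.metrics.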
Example #4
def test():
    """
    Test function: computes the evaluation metrics.
    :return:
    """
    # Build the network model
    g = generator_model()
    # Load the model weights
    load_saved_weight(g)

    ##########################################
    # New test-set code: read directly from the .jpg files to avoid the
    # npy conversion step.
    #  case 1: dehaze synthetic haze images, save the results, compute PSNR/SSIM
    #  case 2: dehaze real haze images, save the results
    ##########################################
    def load_img_files(dir):
        """
        Load all files with a .jpg suffix under dir.
        :param dir:
        :return: numpy array of images
        """
        file_paths = glob.glob(os.path.join(dir, '*.jpg'))
        file_num = len(file_paths)

        imgs = np.zeros((file_num, img_height, img_width, 3))
        for idx, file_path in enumerate(file_paths):
            imgs[idx] = np.array(Image.open(file_path).convert('RGB'))
        return imgs

    mode = "synthesis"  # synthesis or real
    # Directory of the clear (ground-truth) images
    clear_imgs_dir = 'D:/Projects/Dehaze/其他论文去雾代码/HazeRD合成测试集/clear'
    # Directory of the hazy images
    haze_imgs_dir = 'D:/Projects/Dehaze/其他论文去雾代码/HazeRD合成测试集/haze'
    # Directory where the dehazed results are saved
    dehaze_imgs_dir = 'D:/Projects/Dehaze/自己论文去雾代码/DeBulrGanToDeHaze/script/HazeRD合成雾图去雾结果'
    if mode == "synthesis":
        clear_imgs = load_img_files(clear_imgs_dir)
        haze_imgs = load_img_files(haze_imgs_dir)

        # Dehaze
        generated_imgs = g.predict(haze_imgs / 127.5 - 1)
        generated_imgs = (generated_imgs + 1) * 127.5

        # Initialize the metrics
        PSNR = 0
        SSIM = 0

        for idx, generated_img in enumerate(generated_imgs):
            dehazed_img = Image.fromarray(generated_img.astype('uint8'))
            dehazed_img.save(os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))
            PSNR = PSNR + compare_psnr(clear_imgs[idx].astype('uint8'), generated_img.astype('uint8'))
            SSIM = SSIM + ssim(clear_imgs[idx].astype('uint8'), generated_img.astype('uint8'), multichannel=True)
        # Average over the test set
        PSNR = PSNR / len(generated_imgs)
        SSIM = SSIM / len(generated_imgs)
        print('PSNR', PSNR)
        print('SSIM', SSIM)
    elif mode == 'real':
        haze_imgs = load_img_files(haze_imgs_dir)
        # Dehaze
        generated_imgs = g.predict(haze_imgs / 127.5 - 1)
        generated_imgs = (generated_imgs + 1) * 127.5

        for idx, generated_img in enumerate(generated_imgs):
            dehazed_img = Image.fromarray(generated_img.astype('uint8'))
            dehazed_img.save(os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))
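
load_img_files in this variant assumes every .jpg already matches the global img_height x img_width; otherwise the imgs[idx] assignment fails. If the test images can vary in size, resizing on load keeps the shapes consistent (a sketch; load_img_files_resized is a hypothetical variant, and img_width/img_height mirror the globals assumed by the original):

import glob
import os

import numpy as np
from PIL import Image

def load_img_files_resized(img_dir, img_width, img_height):
    """Variant of load_img_files that resizes every image to a fixed size."""
    file_paths = glob.glob(os.path.join(img_dir, '*.jpg'))
    imgs = np.zeros((len(file_paths), img_height, img_width, 3))
    for idx, file_path in enumerate(file_paths):
        img = Image.open(file_path).convert('RGB')
        # PIL's resize takes (width, height)
        imgs[idx] = np.array(img.resize((img_width, img_height)))
    return imgs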