def init_net_module():
    """
    Initialize the network module: build the generator and load its trained weights.

    The generator is stored in the module-level global ``g`` (as before) and is
    also returned, as this docstring has always promised.

    :return: the generator model with its saved weights loaded
    :raises FileNotFoundError: if a required weight file does not exist
    """
    global g

    # Directory holding the saved model weights
    model_save_dir = './model_save_res_block_256_vgg16'

    def load_saved_weight(g, d=None):
        """
        Load previously trained weights.

        :param g: generator
        :param d: discriminator (optional); its weights are loaded only when given
        :return: None
        :raises FileNotFoundError: if a weight file is missing
        """
        # TODO: when several weight files are present, pick the newest one
        # instead of these hard-coded file names.
        g_path = os.path.join(model_save_dir, 'generator_49_33.h5')
        if not os.path.isfile(g_path):
            raise FileNotFoundError('generator weight file not found: %s' % g_path)
        g.load_weights(g_path)
        if d is None:
            return
        d_path = os.path.join(model_save_dir, 'discriminator_49.h5')
        if not os.path.isfile(d_path):
            raise FileNotFoundError('discriminator weight file not found: %s' % d_path)
        d.load_weights(d_path)

    # Build the network model
    g = generator_model()
    # Load the model weights
    load_saved_weight(g)
    return g
def _save_test_figure(g, data_loader, batch_size, epoch, index):
    """
    Dehaze one batch from the test generator and save a comparison figure.

    Each row shows the hazy input, the clear original and the dehazed output.
    The figure is saved to ``./dehazed_result/image/dehazed/<epoch>-<index>.jpg``
    and then closed, so repeated calls do not accumulate open figures.

    :param g: generator
    :param data_loader: object whose ``test_generator`` yields (haze, clear)
        batches with pixel values in 0 - 255 (assumed from the scaling used here)
    :param batch_size: batch size passed to ``g.predict``; also the maximum
        number of rows plotted
    :param epoch: current epoch (used in the file name)
    :param index: current batch index (used in the file name)
    :return: None
    """
    img_haze_test, img_clear_test = next(data_loader.test_generator)
    generated_images = g.predict(x=img_haze_test / 127.5 - 1, batch_size=batch_size)
    # Scale back to 0 - 255; clip so the uint8 cast below cannot wrap around
    generated_images = np.clip((generated_images + 1) * 127.5, 0, 255)

    rows = min(batch_size, len(img_haze_test))
    # squeeze=False keeps axs 2-D, so axs[idx, col] also works for a single row
    fig, axs = plt.subplots(rows, 3, squeeze=False)
    try:
        for idx in range(rows):
            panels = (
                (img_haze_test[idx], 'haze'),
                (img_clear_test[idx], 'origin'),
                (generated_images[idx], 'dehazed'),
            )
            for col, (img, title) in enumerate(panels):
                axs[idx, col].imshow(img.astype('uint8'))
                axs[idx, col].axis('off')
                axs[idx, col].set_title(title)
        os.makedirs("./dehazed_result/image/dehazed", exist_ok=True)
        fig.savefig("./dehazed_result/image/dehazed/%d-%d.jpg" % (epoch, index))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)


def train(batch_size, epochs, critic_updates=5):
    """
    Train the generator / discriminator pair adversarially.

    For every batch the discriminator is trained ``critic_updates`` times on
    clear (real, target 1) and generated (fake, target -1) images. Then the
    combined model ``d_on_g`` is trained with the discriminator frozen. Every
    50 batches a comparison figure is saved. Once per epoch the mean losses are
    printed and all weights are saved via ``save_all_weights``.

    NOTE(review): ``model_save_dir`` is expected at module level; the variable
    of that name inside ``init_net_module`` is local to it. Confirm.

    :param batch_size: number of images per batch
    :param epochs: number of training epochs
    :param critic_updates: how many times the discriminator is trained per batch
    :return: None
    :raises ValueError: if the data set holds fewer images than one batch
    """
    # Load the data
    data_loader = DataLoader(batch_size)
    batches_per_epoch = data_loader.file_nums // batch_size
    if batches_per_epoch == 0:
        # Otherwise every epoch is empty and int(np.mean([])) fails on NaN
        raise ValueError('need at least batch_size=%d images, found %d'
                         % (batch_size, data_loader.file_nums))

    # Build the network models
    g = generator_model()
    d = discriminator_model()
    d.summary()
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    # Save the model structures (used for visualization)
    g.save(os.path.join(model_save_dir, "generator.h5"))
    d.save(os.path.join(model_save_dir, "discriminator.h5"))
    d_on_g.save(os.path.join(model_save_dir, "d_on_g.h5"))

    # Compile the models. d is compiled trainable on its own, but frozen
    # inside d_on_g so generator updates do not change the discriminator.
    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    # Discriminator targets for real and fake images
    output_true_batch = np.ones((batch_size, 1))
    output_false_batch = -np.ones((batch_size, 1))

    # TODO: weights could be restored here to resume training

    start = datetime.datetime.now()
    for epoch in tqdm.tqdm(range(epochs)):
        d_losses = []
        d_on_g_losses = []
        for index in range(batches_per_epoch):
            img_haze_batch, img_clear_batch = next(data_loader.train_generator)
            # Scale pixel values to [-1, 1]
            img_haze_batch = img_haze_batch / 127.5 - 1
            img_clear_batch = img_clear_batch / 127.5 - 1

            generated_images = g.predict(x=img_haze_batch, batch_size=batch_size)

            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(img_clear_batch, output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)

            # Freeze D while G is trained through the combined model
            d.trainable = False
            d_on_g_loss = d_on_g.train_on_batch(
                img_haze_batch, [img_clear_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)
            d.trainable = True

            print('d loss %f d_on_g loss %f' % (d_loss,
                                                d_on_g_loss[1] + d_on_g_loss[2]))

            if index % 50 == 0:
                _save_test_figure(g, data_loader, batch_size, epoch, index)

        now = datetime.datetime.now()
        print(np.mean(d_losses), np.mean(d_on_g_losses),
              'spend time %s' % (now - start))
        # Save all weights
        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
def test():
    """
    Test function: dehaze .jpg images read straight from disk and compute metrics.

    Images are read from jpg files directly instead of going through npy.

    * mode "synthesis": dehaze synthetic hazy images, save the results and
      print the mean PSNR / SSIM against the clear images (paired by sorted
      file name).
    * mode "real": dehaze real hazy images and save the results.

    NOTE(review): another ``def test()`` is defined later in this file and
    replaces this one at import time. Confirm which version is wanted.
    NOTE(review): relies on a module-level ``load_saved_weight``; the one in
    ``init_net_module`` is nested there. Confirm it exists.

    :return: None
    :raises ValueError: in "synthesis" mode when no hazy images are found
    """
    # Build the network model
    g = generator_model('test')
    # Load the model weights
    load_saved_weight(g)

    def load_img_files(img_dir):
        """
        Load every .jpg file in ``img_dir`` as an RGB array.

        Files are loaded in sorted name order so images from two directories
        pair up by index. A list (not one stacked array) is returned because
        the images may have different sizes.

        :param img_dir: directory to scan
        :return: list of H x W x 3 arrays
        """
        imgs = []
        for file_path in sorted(glob.glob(os.path.join(img_dir, '*.jpg'))):
            with Image.open(file_path) as img:
                imgs.append(np.array(img.convert('RGB')))
        return imgs

    def predict(g, haze_imgs):
        """
        Dehaze each image in ``haze_imgs`` separately with ``g``.

        One g.predict call per image, rather than one call for all of them,
        so that images of different sizes are supported.

        :param g: generator
        :param haze_imgs: hazy images, pixel range 0 - 255
        :return: list of dehazed images (shapes may differ), pixel range 0 - 255
        """
        clear_imgs = []
        for haze_img in haze_imgs:
            batch = np.expand_dims(haze_img, axis=0)
            clear_img = g.predict(batch / 127.5 - 1)[0]
            # Clip so the later uint8 cast cannot wrap out-of-range values
            clear_imgs.append(np.clip((clear_img + 1) * 127.5, 0, 255))
        return clear_imgs

    mode = "real"  # synthesis or real
    # Directory of clear (ground-truth) images
    clear_imgs_dir = ''
    # Directory of hazy images
    haze_imgs_dir = '../test_imgs'
    # Directory for the dehazed results.
    # NOTE(review): identical to haze_imgs_dir, so results are written among
    # the inputs (possibly overwriting files named like 001.jpg) and are
    # re-read as inputs on the next run. Confirm this is intended.
    dehaze_imgs_dir = '../test_imgs'

    if mode == "synthesis":
        clear_imgs = load_img_files(clear_imgs_dir)
        haze_imgs = load_img_files(haze_imgs_dir)
        if not haze_imgs:
            raise ValueError('no .jpg images found in %s' % haze_imgs_dir)
        # Dehaze
        generated_imgs = predict(g, haze_imgs)
        os.makedirs(dehaze_imgs_dir, exist_ok=True)
        # Accumulate the metrics
        psnr_total = 0
        ssim_total = 0
        for idx, generated_img in enumerate(generated_imgs):
            generated_uint8 = generated_img.astype('uint8')
            clear_uint8 = clear_imgs[idx].astype('uint8')
            Image.fromarray(generated_uint8).save(
                os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))
            psnr_total += compare_psnr(clear_uint8, generated_uint8)
            ssim_total += ssim(clear_uint8, generated_uint8, multichannel=True)
        # Averages
        print('PSNR', psnr_total / len(generated_imgs))
        print('SSIM', ssim_total / len(generated_imgs))
    elif mode == 'real':
        haze_imgs = load_img_files(haze_imgs_dir)
        # Dehaze
        generated_imgs = predict(g, haze_imgs)
        os.makedirs(dehaze_imgs_dir, exist_ok=True)
        for idx, generated_img in enumerate(generated_imgs):
            Image.fromarray(generated_img.astype('uint8')).save(
                os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))
def test():
    """
    Test function: dehaze .jpg images read straight from disk and compute metrics.

    Images are read from jpg files directly instead of going through npy.

    * mode "synthesis": dehaze synthetic hazy images, save the results and
      print the mean PSNR / SSIM against the clear images (paired by sorted
      file name).
    * mode "real": dehaze real hazy images and save the results.

    Every image must already be ``img_height`` x ``img_width`` (module-level
    names read by ``load_img_files``), because all images are stacked into
    one array and predicted in a single call.

    NOTE(review): this definition replaces the earlier ``test()`` in this file.
    NOTE(review): relies on a module-level ``load_saved_weight``; the one in
    ``init_net_module`` is nested there. Confirm it exists.

    :return: None
    :raises ValueError: when no hazy images are found
    """
    # Build the network model
    g = generator_model()
    # Load the model weights
    load_saved_weight(g)

    def load_img_files(img_dir):
        """
        Load every .jpg file in ``img_dir`` into one float array.

        Files are loaded in sorted name order so images from two directories
        pair up by index.

        :param img_dir: directory to scan
        :return: array of shape (num_files, img_height, img_width, 3)
        """
        file_paths = sorted(glob.glob(os.path.join(img_dir, '*.jpg')))
        imgs = np.zeros((len(file_paths), img_height, img_width, 3))
        for idx, file_path in enumerate(file_paths):
            with Image.open(file_path) as img:
                imgs[idx] = np.array(img.convert('RGB'))
        return imgs

    def dehaze(haze_imgs):
        """Run ``g`` on a 0 - 255 image stack and return results in 0 - 255."""
        if len(haze_imgs) == 0:
            raise ValueError('no .jpg images found in %s' % haze_imgs_dir)
        generated = g.predict(haze_imgs / 127.5 - 1)
        # Clip so the later uint8 cast cannot wrap out-of-range values
        return np.clip((generated + 1) * 127.5, 0, 255)

    mode = "synthesis"  # synthesis or real
    # Directory of clear (ground-truth) images
    clear_imgs_dir = 'D:/Projects/Dehaze/其他论文去雾代码/HazeRD合成测试集/clear'
    # Directory of hazy images
    haze_imgs_dir = 'D:/Projects/Dehaze/其他论文去雾代码/HazeRD合成测试集/haze'
    # Directory for the dehazed results
    dehaze_imgs_dir = 'D:/Projects/Dehaze/自己论文去雾代码/DeBulrGanToDeHaze/script/HazeRD合成雾图去雾结果'

    if mode == "synthesis":
        clear_imgs = load_img_files(clear_imgs_dir)
        haze_imgs = load_img_files(haze_imgs_dir)
        generated_imgs = dehaze(haze_imgs)
        os.makedirs(dehaze_imgs_dir, exist_ok=True)
        # Accumulate the metrics
        psnr_total = 0
        ssim_total = 0
        for idx, generated_img in enumerate(generated_imgs):
            generated_uint8 = generated_img.astype('uint8')
            clear_uint8 = clear_imgs[idx].astype('uint8')
            Image.fromarray(generated_uint8).save(
                os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))
            psnr_total += compare_psnr(clear_uint8, generated_uint8)
            ssim_total += ssim(clear_uint8, generated_uint8, multichannel=True)
        # Averages
        print('PSNR', psnr_total / len(generated_imgs))
        print('SSIM', ssim_total / len(generated_imgs))
    elif mode == 'real':
        haze_imgs = load_img_files(haze_imgs_dir)
        generated_imgs = dehaze(haze_imgs)
        os.makedirs(dehaze_imgs_dir, exist_ok=True)
        for idx, generated_img in enumerate(generated_imgs):
            Image.fromarray(generated_img.astype('uint8')).save(
                os.path.join(dehaze_imgs_dir, "%03d.jpg" % (idx + 1)))