def train_main(args):
    """Build data loaders, model, loss and optimiser, then delegate to ``train``.

    Side effect: publishes the two DataLoaders as the module-level globals
    ``loader_train`` and ``loader_val``.

    Args:
        args: namespace providing ``original_pic_root`` (clean images) and
            ``haze_pic_root`` (hazy images).

    Module-level settings read here: BATCH_SIZE, RANDOM_SEED, FC_LR, NET_LR,
    OPTIMIZER, MOMENTUM, WEIGHT_DECAY, Load_model and EPOCH.
    """
    global loader_train, loader_val

    # Seed before anything that may draw random numbers (dataset setup,
    # weight initialisation), so the whole run is covered by one seed.
    setup_seed(RANDOM_SEED)
    device = set_device()

    train_dataset = dataloader.dehazing_loader(args.original_pic_root, args.haze_pic_root)
    val_dataset = dataloader.dehazing_loader(args.original_pic_root, args.haze_pic_root, mode="val")
    loader_train = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    loader_val = torch.utils.data.DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Move the model to its device *before* building the optimiser (PyTorch's
    # recommended order). Previously the model was only moved when resuming,
    # so a fresh run left it on its default device.
    model = AOD().to(device=device)
    # model = nn.DataParallel(model)  # multi-GPU
    criterion = nn.MSELoss()

    params = net_lr(model, FC_LR, NET_LR)
    if OPTIMIZER == 'adam':
        # No global lr: presumably each parameter group from net_lr carries its own.
        optimizer = torch.optim.Adam(params, betas=(0.9, 0.999), weight_decay=0, eps=1e-08)
    else:
        optimizer = torch.optim.SGD(params, momentum=MOMENTUM, nesterov=True,
                                    weight_decay=WEIGHT_DECAY)
    print(model)

    start_epoch = 0
    if Load_model:
        # NOTE(review): resume epoch and checkpoint path are hard-coded placeholders.
        start_epoch = 25
        filepath = 'load_model_path'
        model = load_model(model, filepath, device=device)
        model = model.to(device=device)
        optimizer = load_optimizer(optimizer, filepath, device=device)

    train(model, optimizer, criterion, device=device, epochs=EPOCH, start=start_epoch)
def train(config):
    """Train ``net.dehaze_net`` on (clean, hazy) image pairs with an MSE loss.

    Each epoch does three things:
      * one pass over the training set, with gradient-norm clipping, periodic
        loss printing and periodic snapshots;
      * a pass over the validation set that writes ``hazy | dehazed | clean``
        grids to ``config.sample_output_folder``;
      * two checkpoint saves.

    Args:
        config: namespace providing ``orig_images_path``, ``hazy_images_path``,
            ``train_batch_size``, ``val_batch_size``, ``num_workers``, ``lr``,
            ``weight_decay``, ``grad_clip_norm``, ``num_epochs``,
            ``display_iter``, ``snapshot_iter``, ``snapshots_folder`` and
            ``sample_output_folder``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dehaze_net = net.dehaze_net().to(device)
    dehaze_net.apply(weights_init)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path, config.hazy_images_path)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path, config.hazy_images_path, mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.val_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)

    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr,
                                 weight_decay=config.weight_decay)

    # Defined up front so the end-of-epoch snapshot name stays valid even if
    # the training loader yields no batches.
    iteration = -1
    for epoch in range(config.num_epochs):
        dehaze_net.train()
        for iteration, (img_orig, img_haze) in enumerate(train_loader):
            img_orig = img_orig.to(device)
            img_haze = img_haze.to(device)

            clean_image = dehaze_net(img_haze)
            loss = criterion(clean_image, img_orig)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(), config.grad_clip_norm)
            optimizer.step()

            if ((iteration + 1) % config.display_iter) == 0:
                print("epoch", epoch, " Loss at iteration", iteration + 1, ":", loss.item())
            if ((iteration + 1) % config.snapshot_iter) == 0:
                torch.save(dehaze_net.state_dict(),
                           config.snapshots_folder + "Epoch" + str(epoch) + 'iteration' + str(iteration + 1) + '.pth')

        # Validation Stage: outputs are only written to disk, so skip autograd
        # bookkeeping and switch dropout/batch-norm style layers to eval mode.
        dehaze_net.eval()
        with torch.no_grad():
            for iter_val, (img_orig, img_haze) in enumerate(val_loader):
                img_orig = img_orig.to(device)
                img_haze = img_haze.to(device)
                clean_image = dehaze_net(img_haze)
                torchvision.utils.save_image(
                    torch.cat((img_haze, clean_image, img_orig), 0),
                    config.sample_output_folder + str(iter_val + 1) + ".jpg")

        torch.save(dehaze_net.state_dict(),
                   config.snapshots_folder + "Epoch" + str(epoch) + 'iteration' + str(iteration + 1) + '.pth')
        torch.save(dehaze_net.state_dict(), config.snapshots_folder + "dehazer.pth")
def train(args):
    """Train ``net.unfog_net`` on (clean, foggy) image pairs with an MSE loss.

    Each epoch trains with gradient-norm clipping, then writes
    ``foggy | restored | clean`` grids for the validation set to
    ``args.sample_output_folder`` and saves ``net.pth``.

    Args:
        args: namespace providing ``orig_images_path``, ``hazy_images_path``,
            ``train_batch_size``, ``val_batch_size``, ``num_workers``, ``lr``,
            ``weight_decay``, ``grad_clip_norm``, ``num_epochs``,
            ``display_iter``, ``snapshot_iter``, ``snapshots_folder`` and
            ``sample_output_folder``.
    """
    # Fall back to CPU instead of crashing when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    unfog_net = net.unfog_net().to(device)
    unfog_net.apply(init)

    train_dataset = dataloader.dehazing_loader(args.orig_images_path, args.hazy_images_path)
    val_dataset = dataloader.dehazing_loader(args.orig_images_path, args.hazy_images_path, mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.val_batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(unfog_net.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    for epoch in range(args.num_epochs):
        unfog_net.train()
        for iteration, (img_orig, img_fog) in enumerate(train_loader):
            img_orig = img_orig.to(device)
            img_fog = img_fog.to(device)

            clean_image = unfog_net(img_fog)
            loss = criterion(clean_image, img_orig)

            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm (no underscore) is deprecated; the in-place
            # variant is the supported API.
            torch.nn.utils.clip_grad_norm_(unfog_net.parameters(), args.grad_clip_norm)
            optimizer.step()

            if ((iteration + 1) % args.display_iter) == 0:
                print("Loss at iteration", iteration + 1, ":", loss.item())
            if ((iteration + 1) % args.snapshot_iter) == 0:
                torch.save(unfog_net.state_dict(), args.snapshots_folder + "Epoch" + str(epoch) + '.pth')

        # Validation Stage: inference only, so no autograd graph is needed.
        unfog_net.eval()
        with torch.no_grad():
            for iter_val, (img_orig, img_fog) in enumerate(val_loader):
                img_orig = img_orig.to(device)
                img_fog = img_fog.to(device)
                clean_image = unfog_net(img_fog)
                torchvision.utils.save_image(
                    torch.cat((img_fog, clean_image, img_orig), 0),
                    args.sample_output_folder + str(iter_val + 1) + ".jpg")

        torch.save(unfog_net.state_dict(), args.snapshots_folder + "net.pth")
def train(config):
    """Train the recurrent IRDN dehazing network and log validation SSIM per epoch.

    Supports three losses selected by ``config.lossfunc``:
      * ``"MSE"``;
      * ``"SSIM"`` (negated SSIM);
      * anything else, which uses MSE minus SSIM.

    Training can resume from epoch ``config.epoched``. For any
    ``config.in_or_out`` other than ``"outdoor"``, the image paths are
    overwritten with the indoor dataset paths.

    Each epoch:
      * trains;
      * writes ``hazy | dehazed | clean`` grids under
        ``<sample_output_folder>/epoch<N>/``;
      * appends the mean validation SSIM to a log file and re-plots the SSIM
        curve;
      * saves ``IRDN.pth``.

    Args:
        config: namespace with ``recurrent_iter``, ``epoched``, ``in_or_out``,
            data paths, loader settings, ``lossfunc``, ``lr``,
            ``grad_clip_norm``, ``num_epochs``, ``display_iter``,
            ``snapshot_iter``, ``snapshots_folder`` and
            ``sample_output_folder``.
    """
    import os

    dehaze_net = networks.IRDN(config.recurrent_iter).cuda()
    if config.epoched != 0:
        # Resume from the snapshot of the given epoch.
        # NOTE(review): the checkpoint directory is hard-coded to one run.
        dehaze_net.load_state_dict(
            torch.load('trained_model/i6-outdoor-MSE+SSIM/Epoch%i.pth' % config.epoched))

    if config.in_or_out != "outdoor":
        config.orig_images_path = "dataset/train_data/indoor/clear/"
        config.hazy_images_path = "dataset/train_data/indoor/hazy/"
    train_dataset = dataloader.dehazing_loader(config.orig_images_path, config.hazy_images_path)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path, config.hazy_images_path, mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.val_batch_size, shuffle=False,
        num_workers=config.num_workers, pin_memory=True)

    if config.lossfunc == "MSE":
        criterion = nn.MSELoss().cuda()
    elif config.lossfunc == "SSIM":
        criterion = SSIM()
    else:  # MSE + SSIM loss
        criterion1 = nn.MSELoss().cuda()
        criterion2 = SSIM()
    comput_ssim = SSIM()
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr)

    # Log and plot files are written under trainlog/ and trainlog/indoor/.
    os.makedirs("trainlog/indoor", exist_ok=True)

    indexX = []
    indexY = []
    for epoch in range(config.epoched, config.num_epochs):
        print("*" * 80 + "第%i轮" % epoch + "*" * 80)
        dehaze_net.train()
        for iteration, (img_orig, img_haze) in enumerate(train_loader):
            img_orig = img_orig.cuda()
            img_haze = img_haze.cuda()
            try:
                clean_image, _ = dehaze_net(img_haze)
                if config.lossfunc == "MSE":
                    loss = criterion(clean_image, img_orig)
                elif config.lossfunc == "SSIM":
                    # Maximise SSIM by minimising its negation.
                    loss = -criterion(img_orig, clean_image)
                else:
                    ssim = criterion2(img_orig, clean_image)
                    mse = criterion1(clean_image, img_orig)
                    loss = mse - ssim
                del clean_image, img_orig
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(), config.grad_clip_norm)
                optimizer.step()
                if ((iteration + 1) % config.display_iter) == 0:
                    print("Loss at iteration", iteration + 1, ":", loss.item())
                if ((iteration + 1) % config.snapshot_iter) == 0:
                    torch.save(dehaze_net.state_dict(),
                               config.snapshots_folder + "Epoch" + str(epoch) + '.pth')
            except RuntimeError as e:
                # Skip a batch that runs out of GPU memory; re-raise anything else.
                if 'out of memory' in str(e):
                    print(e)
                    torch.cuda.empty_cache()
                else:
                    raise

        # Validation Stage
        _ssim = []
        sample_dir = config.sample_output_folder + "/epoch%s" % epoch
        os.makedirs(sample_dir, exist_ok=True)
        dehaze_net.eval()
        with torch.no_grad():
            for iteration, (clean, haze) in enumerate(val_loader):
                clean = clean.cuda()
                haze = haze.cuda()
                clean_, _ = dehaze_net(haze)
                _ssim.append(comput_ssim(clean, clean_).item())  # SSIM of this batch
                torchvision.utils.save_image(
                    torch.cat((haze, clean_, clean), 0),
                    sample_dir + "/" + str(iteration + 1) + ".jpg")

        mean_ssim = np.mean(np.array(_ssim))
        print("-----The %i Epoch mean-ssim is :%f-----" % (epoch, mean_ssim))
        with open("trainlog/indoor/i%i_%s.log" % (config.recurrent_iter, config.lossfunc),
                  "a+", encoding="utf-8") as f:
            f.write("The %i Epoch mean-ssim is :%f" % (epoch, mean_ssim) + "\n")

        indexX.append(epoch + 1)
        indexY.append(mean_ssim)
        print(indexX, indexY)
        # The full history is re-plotted each epoch; clear first so earlier
        # copies of the curve do not pile up on the same axes.
        plt.clf()
        plt.plot(indexX, indexY, linewidth=2)
        plt.pause(0.1)
        plt.savefig("trainlog/i%i_%s.png" % (config.recurrent_iter, config.lossfunc))

        torch.save(dehaze_net.state_dict(), config.snapshots_folder + "IRDN.pth")
def _msdfn_lr_for_epoch(epoch):
    """Return the step-decay learning rate used by the MSDFN schedule for *epoch*.

    1e-4 at epoch 0 and 9e-5 at epoch 1, then one step down every two
    epochs, reaching 1e-6 from epoch 14 on.
    """
    if epoch == 0:
        return 0.0001
    if epoch == 1:
        return 0.00009
    if epoch <= 3:
        return 0.00006
    if epoch <= 5:
        return 0.00003
    if epoch <= 7:
        return 0.00001
    if epoch <= 9:
        return 0.000009
    if epoch <= 11:
        return 0.000006
    if epoch <= 13:
        return 0.000003
    return 0.000001


def train(config):
    """Train MSDFN on (clean, hazy, depth) triples and track validation SSIM.

    ``config.lossfunc == "MSE"`` selects an MSE loss; any other value trains
    on negated SSIM. The learning rate follows ``_msdfn_lr_for_epoch`` and is
    written back to ``config.lr`` each epoch.

    Each epoch:
      * writes comparison grids and dehazed outputs under
        ``<sample_output_folder>/epoch<N>/``;
      * appends the mean validation SSIM to ``trainlog/<lossfunc><actfuntion>.log``;
      * re-plots the SSIM curve;
      * saves ``MSDFN.pth``.

    Args:
        config: namespace with train/val image, haze and label paths, loader
            settings, ``lossfunc``, ``actfuntion``, ``grad_clip_norm``,
            ``num_epochs``, ``display_iter``, ``snapshot_iter``,
            ``snapshots_folder`` and ``sample_output_folder``.
    """
    import os

    dehaze_net = model.MSDFN().cuda()
    dehaze_net.apply(weights_init)

    train_dataset = dataloader.dehazing_loader(
        config.orig_images_path, config.hazy_images_path, config.label_images_path)
    val_dataset = dataloader.dehazing_loader(
        config.orig_images_path_val, config.hazy_images_path_val,
        config.label_images_path_val, mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.val_batch_size, shuffle=False,
        num_workers=config.num_workers, pin_memory=True)

    # Bug fix: the "MSE" branch previously called the SSIM module, so it was
    # minimising SSIM (pushing outputs away from the target).
    if config.lossfunc == "MSE":
        criterion = torch.nn.MSELoss().cuda()
    else:
        criterion = SSIM()
    comput_ssim = SSIM()

    # Build Adam once so its moment estimates persist across epochs; only the
    # learning rate is changed per epoch.
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=_msdfn_lr_for_epoch(1))

    os.makedirs("trainlog", exist_ok=True)
    indexX = []
    indexY = []
    # NOTE(review): epochs start at 1, so the epoch-0 rate is never used and
    # only num_epochs - 1 epochs run.
    for epoch in range(1, config.num_epochs):
        config.lr = _msdfn_lr_for_epoch(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = config.lr
        print("now lr == %f" % config.lr)
        print("*" * 80 + "第%i轮" % epoch + "*" * 80)

        dehaze_net.train()
        for iteration, (img_clean, img_haze, img_depth) in enumerate(train_loader):
            img_clean = img_clean.cuda()
            img_haze = img_haze.cuda()
            img_depth = img_depth.cuda()
            try:
                clean_image = dehaze_net(img_haze, img_depth)
                if config.lossfunc == "MSE":
                    loss = criterion(clean_image, img_clean)  # MSE loss
                else:
                    loss = -criterion(img_clean, clean_image)  # negative-SSIM loss
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(), config.grad_clip_norm)
                optimizer.step()
                if ((iteration + 1) % config.display_iter) == 0:
                    print("Loss at iteration", iteration + 1, ":", loss.item())
                if ((iteration + 1) % config.snapshot_iter) == 0:
                    torch.save(dehaze_net.state_dict(),
                               config.snapshots_folder + "Epoch" + str(epoch) + '.pth')
            except RuntimeError as e:
                # Skip a batch that runs out of GPU memory; re-raise anything else.
                if 'out of memory' in str(e):
                    print(e)
                    torch.cuda.empty_cache()
                else:
                    raise

        _ssim = []
        print("start Val!")
        # Validation Stage
        sample_dir = config.sample_output_folder + "/epoch%s" % epoch
        os.makedirs(sample_dir, exist_ok=True)
        dehaze_net.eval()
        with torch.no_grad():
            for iteration1, (img_clean, img_haze, img_depth) in enumerate(val_loader):
                print("va1 : %s" % str(iteration1))
                img_clean = img_clean.cuda()
                img_haze = img_haze.cuda()
                img_depth = img_depth.cuda()
                clean_image = dehaze_net(img_haze, img_depth)
                _ssim.append(comput_ssim(img_clean, clean_image).item())
                # Both images used to be written to the same path, so the
                # comparison grid was overwritten at once; give it its own name.
                torchvision.utils.save_image(
                    torch.cat((img_haze, img_clean, clean_image), 0),
                    sample_dir + "/" + str(iteration1 + 1) + "_compare.jpg")
                torchvision.utils.save_image(
                    clean_image, sample_dir + "/" + str(iteration1 + 1) + ".jpg")

        mean_ssim = np.mean(np.array(_ssim))
        print("-----The %i Epoch mean-ssim is :%f-----" % (epoch, mean_ssim))
        with open("trainlog/%s%s.log" % (config.lossfunc, config.actfuntion), "a+",
                  encoding="utf-8") as f:
            f.write("[%i,%f]" % (epoch, mean_ssim) + "\n")

        indexX.append(epoch + 1)
        indexY.append(mean_ssim)
        print(indexX, indexY)
        # Redraw the whole history on a cleared figure instead of stacking copies.
        plt.clf()
        plt.plot(indexX, indexY, linewidth=2)
        plt.pause(0.1)
        plt.savefig("trainlog/%s%s.png" % (config.lossfunc, config.actfuntion))

        torch.save(dehaze_net.state_dict(), config.snapshots_folder + "MSDFN.pth")
def _stitch_blocks(blocks, num_width, full_bk_num):
    """Reassemble a row-major list of image blocks into a single tensor.

    Within a row, ``num_width`` consecutive blocks are concatenated along
    dim 3; rows are then concatenated along dim 2. ``full_bk_num`` is the
    total block count (``num_width * num_height``). Assumes N, C, H, W block
    tensors — TODO confirm against the dataloader.
    """
    rows = [torch.cat(blocks[start:start + num_width], 3)
            for start in range(0, full_bk_num, num_width)]
    return torch.cat(rows, 2)


def _save_stitched(blocks, num_width, full_bk_num, path):
    """Stitch *blocks*, print and save the result to *path*, and return it."""
    image = _stitch_blocks(blocks, num_width, full_bk_num)
    print(path)
    torchvision.utils.save_image(image, path)
    return image


def train(config):
    """Train ``net.dehaze_net`` block-wise on YUV420 inputs against YUV444 targets.

    Each dataset sample is a list of image blocks plus the block-grid size and
    the source path. When ``config.do_valid == 0``, every epoch trains on each
    block individually. Every epoch then runs validation:
      * re-stitches the blocks into full images (network output, network input,
        YUV444 reference, RGB reference, and RGB converted from the output and
        from the input) and saves each as a ``.bmp``;
      * prints PSNR/SSIM of both RGB reconstructions against the RGB reference.

    Args:
        config: namespace with ``use_gpu``, ``block_width``, ``block_height``,
            ``resize``, ``bTest``, ``snap_train_data``,
            ``snapshots_train_folder``, ``orig_images_path``, loader settings,
            ``lr``, ``weight_decay``, ``grad_clip_norm``, ``num_epochs``,
            ``do_valid``, ``display_block_iter``, ``snapshot_iter``,
            ``snapshots_folder`` and ``sample_output_folder``.
    """
    use_gpu = config.use_gpu
    bk_width = config.block_width
    bk_height = config.block_height
    resize = config.resize
    bTest = config.bTest

    dehaze_net = net.dehaze_net()
    if use_gpu:
        dehaze_net = dehaze_net.cuda()

    if config.snap_train_data:
        # Map onto the CPU when not using the GPU so GPU-saved checkpoints still load.
        dehaze_net.load_state_dict(
            torch.load(config.snapshots_train_folder + config.snap_train_data,
                       map_location=None if use_gpu else 'cpu'))
    else:
        dehaze_net.apply(weights_init)
    print(dehaze_net)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path, 'train', resize,
                                               bk_width, bk_height, bTest)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path, "val", resize,
                                             bk_width, bk_height, bTest)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.train_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.val_batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=True)

    criterion = nn.MSELoss()
    if use_gpu:
        criterion = criterion.cuda()
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr,
                                 weight_decay=config.weight_decay)

    # Counts trained blocks across all epochs; drives snapshot frequency and
    # the one-off debug prints.
    save_counter = 0
    for epoch in range(config.num_epochs):
        if config.do_valid == 0:
            dehaze_net.train()
            # img_orig / img_haze are sequences of blocks; each block is a
            # batch that is trained on separately.
            for iteration, (img_orig, img_haze, rgb, bl_num_width, bl_num_height,
                            data_path) in enumerate(train_loader):
                num_blocks = len(img_orig)
                if save_counter == 0:
                    print("img_orig.size:")
                    print(num_blocks)
                    print("bl_num_width.type:")
                    print(bl_num_width.type())  # was printing the bound method
                    print("shape:")
                    print(bl_num_width.shape)

                num_width = int(bl_num_width[0].item())
                num_height = int(bl_num_height[0].item())
                full_bk_num = num_width * num_height
                # Integer interval, at least 1. A float interval from "/"
                # made `(index + 1) % interval == 0` practically never true
                # when the division was inexact.
                display_block_iter = max(1, full_bk_num // config.display_block_iter)

                for index, (unit_img_orig, unit_img_haze) in enumerate(zip(img_orig, img_haze)):
                    if save_counter == 0:
                        print("unit_img_orig type:")
                        print(unit_img_orig.type())
                        print("size:")
                        print(unit_img_orig.size())
                        print("shape:")
                        print(unit_img_orig.shape)

                    if use_gpu:
                        unit_img_orig = unit_img_orig.cuda()
                        unit_img_haze = unit_img_haze.cuda()

                    clean_image = dehaze_net(unit_img_haze)
                    loss = criterion(clean_image, unit_img_orig)

                    # Diagnostic dump for bad inputs/outputs.
                    if torch.isnan(unit_img_haze).any() or torch.isinf(clean_image).any():
                        print("unit_img_haze:")
                        print(unit_img_haze.shape)
                        print(unit_img_haze)
                        print("clean_image:")
                        print(clean_image.shape)
                        print(clean_image)

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(), config.grad_clip_norm)
                    optimizer.step()

                    if ((index + 1) % display_block_iter) == 0:
                        print("Loss at Epoch:" + str(epoch) + "_index:" + str(index + 1) + "/"
                              + str(num_blocks) + "_iter:" + str(iteration + 1)
                              + "_Loss value:" + str(loss.item()))

                    if ((save_counter + 1) % config.snapshot_iter) == 0:
                        save_name = ("Epoch:" + str(epoch) + "_TrainTimes:"
                                     + str(save_counter + 1) + ".pth")
                        torch.save(dehaze_net.state_dict(), config.snapshots_folder + save_name)
                    save_counter = save_counter + 1

        # Validation Stage
        # img_orig -> yuv444, img_haze -> yuv420.
        # no_grad: every block output is kept in a list below, and without it
        # each one would also keep its whole autograd graph alive.
        dehaze_net.eval()
        with torch.no_grad():
            for iter_val, (img_orig, img_haze, rgb, bl_num_width, bl_num_height,
                           data_path) in enumerate(val_loader):
                sub_image_list = []          # network output blocks (yuv420)
                sub_image_list_no_deep = []  # network input blocks (yuv420)
                ori_sub_image_list = []      # reference blocks (yuv444)
                rgb_image_list = []          # reference blocks (rgb)
                rgb_list_from_sub = []       # rgb converted from the output blocks
                rgb_list_from_ori = []       # rgb converted from the input blocks
                for unit_img_orig, unit_img_haze, unit_img_rgb in zip(img_orig, img_haze, rgb):
                    # TODO: yuv444 colours look wrong - investigate.
                    if use_gpu:
                        unit_img_orig = unit_img_orig.cuda()
                        unit_img_haze = unit_img_haze.cuda()
                        unit_img_rgb = unit_img_rgb.cuda()
                    clean_image = dehaze_net(unit_img_haze)
                    sub_image_list.append(clean_image)
                    sub_image_list_no_deep.append(unit_img_haze)
                    ori_sub_image_list.append(unit_img_orig)
                    rgb_image_list.append(unit_img_rgb)
                    rgb_list_from_sub.append(yuv2rgb(clean_image))
                    rgb_list_from_ori.append(yuv2rgb(unit_img_haze))

                print(data_path)
                temp_data_path = data_path[0]
                print('temp_data_path:')
                print(temp_data_path)
                orimage_name = temp_data_path.split("/")[-1]
                print(orimage_name)
                orimage_name = orimage_name.split(".")[0]
                print(orimage_name)

                num_width = int(bl_num_width[0].item())
                num_height = int(bl_num_height[0].item())
                full_bk_num = num_width * num_height
                prefix = (config.sample_output_folder + "Epoch:" + str(epoch) + "_Index:"
                          + str(iter_val + 1) + "_" + orimage_name)

                _save_stitched(sub_image_list, num_width, full_bk_num,
                               prefix + "_yuv420_deep.bmp")
                _save_stitched(sub_image_list_no_deep, num_width, full_bk_num,
                               prefix + "_yuv420_ori.bmp")
                _save_stitched(ori_sub_image_list, num_width, full_bk_num,
                               prefix + "_yuv444.bmp")
                rgb_image_all = _save_stitched(rgb_image_list, num_width, full_bk_num,
                                               prefix + "_rgb.bmp")
                rgb_from_420_image_all_clear = _save_stitched(
                    rgb_list_from_sub, num_width, full_bk_num,
                    prefix + "_rgb_from_clean_420.bmp")
                rgb_from_420_image_all_haze = _save_stitched(
                    rgb_list_from_ori, num_width, full_bk_num,
                    prefix + "__rgb_from_haze_420.bmp")

                # PSNR of each RGB reconstruction against the RGB reference.
                psnr_index = piq.psnr(rgb_from_420_image_all_haze, rgb_image_all,
                                      data_range=1., reduction='none')
                print(f"PSNR haze: {psnr_index.item():0.4f}")
                psnr_index = piq.psnr(rgb_from_420_image_all_clear, rgb_image_all,
                                      data_range=1., reduction='none')
                print(f"PSNR clear: {psnr_index.item():0.4f}")

                # SSIM index and SSIM loss of each reconstruction.
                ssim_index = piq.ssim(rgb_from_420_image_all_haze, rgb_image_all, data_range=1.)
                ssim_loss: torch.Tensor = piq.SSIMLoss(data_range=1.)(
                    rgb_from_420_image_all_haze, rgb_image_all)
                print(f"SSIM haze index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}")
                ssim_index = piq.ssim(rgb_from_420_image_all_clear, rgb_image_all, data_range=1.)
                ssim_loss = piq.SSIMLoss(data_range=1.)(
                    rgb_from_420_image_all_clear, rgb_image_all)
                print(f"SSIM clear index: {ssim_index.item():0.4f}, loss: {ssim_loss.item():0.4f}")

        torch.save(dehaze_net.state_dict(), config.snapshots_folder + "dehazer.pth")