def test(): dirpath = r"/media/qjc/D/data/testimgs_stereo/" # img1 = load_image(dirpath + "im0.png").transpose(2, 0, 1)[None]/255.0 # img2 = load_image(dirpath + "im1.png").transpose(2, 0, 1)[None]/255.0 # disp1 = load_disp(dirpath + "disp0.pfm")[None, None]/1.0 # disp2 = load_disp(dirpath + "disp1.pfm")[None, None]/1.0 img1 = load_image(dirpath + "im2.ppm").transpose(2, 0, 1)[None] / 255.0 img2 = load_image(dirpath + "im6.ppm").transpose(2, 0, 1)[None] / 255.0 disp1 = load_disp(dirpath + "disp2.pgm")[None, None] / 8.0 disp2 = load_disp(dirpath + "disp6.pgm")[None, None] / 8.0 # img1 = load_image(dirpath + "10L.png").transpose(2, 0, 1)[None]/255.0 # img2 = load_image(dirpath + "10R.png").transpose(2, 0, 1)[None]/255.0 # disp1 = load_disp(dirpath + "disp10L.png")[None, None] # disp2 = load_disp(dirpath + "disp10R.png")[None, None] im1 = to_tensor(img1) im2 = to_tensor(img2) d1 = to_tensor(disp1) d2 = to_tensor(disp2) im1t = imwrap_BCHW(im2, -d1) im2t = imwrap_BCHW(im1, d2) ssim = SSIM(window_size=11) ssim1 = ssim(im1, im1t) ssim2 = ssim(im2, im2t) ssim3 = ssim(im1, im2) abs1 = torch.abs(im1 - im1t).sum(dim=1, keepdim=True) abs2 = torch.abs(im2 - im2t).sum(dim=1, keepdim=True) print ssim1.shape, ssim2.shape print ssim1.mean().data[0], ssim2.mean().data[0], ssim3.mean().data[0] imsplot_tensor(im1, im2, im1t, im2t, 1 - ssim1, 1 - ssim2, abs1, abs2) plt.show()
def __init__(self):
    super(loss_stereo, self).__init__()
    self.w_ap = 1.0      # appearance (photometric) term weight
    self.w_ds = 0.001    # likely a disparity smoothness term weight
    self.w_lr = 0.001    # likely a left-right consistency term weight
    self.w_m = 0.0001    # weight of an additional regularization term
    self.ssim = SSIM()
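# A hedged sketch of how weights like these are typically combined in a
# Monodepth-style stereo loss; loss_stereo's actual forward() is not shown
# above, so the term names below are assumptions, not the original code.
def combine_stereo_terms_sketch(w_ap, w_ds, w_lr, w_m, l_ap, l_ds, l_lr, l_m):
    # l_ap: appearance/photometric term (often SSIM + L1)
    # l_ds: disparity smoothness term
    # l_lr: left-right consistency term
    # l_m:  an extra regularization term
    return w_ap * l_ap + w_ds * l_ds + w_lr * l_lr + w_m * l_m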
def __init__(self, generator, criterion, optimizer):
    self.generator = generator
    self.criterion = criterion
    self.optimizer = optimizer
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.accuracy = SSIM()
    self.iteration = 0
    self.tensoration = torchvision.transforms.ToTensor()

    # Build a run-configuration string from the class names involved.
    config = str(generator.__class__.__name__) + '_' + str(
        generator.deconv1.__class__.__name__) + '_' + str(
        generator.activation.__class__.__name__)
    config += '_' + str(criterion.__class__.__name__)
    config += '_' + str(optimizer.__class__.__name__)

    directory = './RESULTS/'
    reportPath = os.path.join(directory, config + "/report/")
    if not os.path.exists(reportPath):
        os.makedirs(reportPath)
        print('os.makedirs("reportPath")')
    self.modelPath = os.path.join(directory, config + "/model/")
    if not os.path.exists(self.modelPath):
        os.makedirs(self.modelPath)
        print('os.makedirs("/modelPath/")')
    self.images = os.path.join(directory, config + "/images/")
    if not os.path.exists(self.images):
        os.makedirs(self.images)
        print('os.makedirs("/images/")')
    else:
        shutil.rmtree(self.images)

    # Redirect stdout briefly to dump the configuration into the report file.
    self.report = open(reportPath + '/' + config + "_Report.txt", "w")
    _stdout = sys.stdout
    sys.stdout = self.report
    print(config)
    print(generator)
    print(criterion)
    self.report.flush()
    sys.stdout = _stdout

    self.generator.to(self.device)
    self.cudas = list(range(torch.cuda.device_count()))
    print(self.device)
    print(torch.cuda.device_count())
net = net.type(dtype)
'''
k_net:
n_k = 200
net_input_kernel = get_noise(n_k, INPUT, (1, 1)).type(dtype)
net_input_kernel.squeeze_()
'''
net_kernel = Predictor(1, 128, opt.kernel_size[0] * opt.kernel_size[1])
net_kernel = net_kernel.type(dtype)

# Losses
mse = torch.nn.MSELoss().type(dtype)
ssim = SSIM().type(dtype)

# Optimizer: the kernel predictor gets its own (smaller) learning rate.
optimizer = torch.optim.Adam(
    [{'params': net.parameters()},
     {'params': net_kernel.parameters(), 'lr': 1e-4}],
    lr=LR)
scheduler = MultiStepLR(optimizer, milestones=[800, 1400, 2000],
                        gamma=0.5)  # learning rates

# initialization inputs
net_input_saved = net_input.detach().clone()
# net_input_kernel_saved = net_input_kernel.detach().clone()
def train(config):
    dehaze_net = networks.IRDN(config.recurrent_iter).cuda()
    if config.epoched != 0:
        # Resume from a saved checkpoint.
        dehaze_net.load_state_dict(
            torch.load('trained_model/i6-outdoor-MSE+SSIM/Epoch%i.pth' %
                       config.epoched))

    if config.in_or_out == "outdoor":
        train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                                   config.hazy_images_path)
    else:
        config.orig_images_path = "dataset/train_data/indoor/clear/"
        config.hazy_images_path = "dataset/train_data/indoor/hazy/"
        train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                                   config.hazy_images_path)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                             config.hazy_images_path,
                                             mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.val_batch_size,
                                             shuffle=False,
                                             num_workers=config.num_workers,
                                             pin_memory=True)

    if config.lossfunc == "MSE":
        criterion = nn.MSELoss().cuda()
    elif config.lossfunc == "SSIM":
        criterion = SSIM()
    else:  # MSE+SSIM loss
        criterion1 = nn.MSELoss().cuda()
        criterion2 = SSIM()
    comput_ssim = SSIM()

    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr)
    dehaze_net.train()

    zt = 1
    Iters = 0
    indexX = []
    indexY = []
    for epoch in range(config.epoched, config.num_epochs):
        print("*" * 80 + "Epoch %i" % epoch + "*" * 80)
        for iteration, (img_orig, img_haze) in enumerate(train_loader):
            img_orig = img_orig.cuda()
            img_haze = img_haze.cuda()
            try:
                clean_image, _ = dehaze_net(img_haze)
                if config.lossfunc == "MSE":
                    loss = criterion(clean_image, img_orig)
                elif config.lossfunc == "SSIM":
                    # SSIM is a similarity, so negate it to get a loss.
                    loss = -criterion(img_orig, clean_image)
                else:
                    ssim = criterion2(img_orig, clean_image)
                    mse = criterion1(clean_image, img_orig)
                    loss = mse - ssim
                del clean_image, img_orig

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                               config.grad_clip_norm)
                optimizer.step()
                Iters += 1
                if ((iteration + 1) % config.display_iter) == 0:
                    print("Loss at iteration", iteration + 1, ":", loss.item())
                if ((iteration + 1) % config.snapshot_iter) == 0:
                    torch.save(
                        dehaze_net.state_dict(),
                        config.snapshots_folder + "Epoch" + str(epoch) + '.pth')
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(e)
                    torch.cuda.empty_cache()
                else:
                    raise e
        # if zt == 0 and Iters >= 700:  # early stop
        #     break

        # Validation stage
        _ssim = []
        with torch.no_grad():
            for iteration, (clean, haze) in enumerate(val_loader):
                clean = clean.cuda()
                haze = haze.cuda()
                clean_, _ = dehaze_net(haze)
                _s = comput_ssim(clean, clean_)  # compute the SSIM score
                _ssim.append(_s.item())
                torchvision.utils.save_image(
                    torch.cat((haze, clean_, clean), 0),
                    config.sample_output_folder + "/epoch%s" % epoch + "/" +
                    str(iteration + 1) + ".jpg")
        _ssim = np.array(_ssim)
        print("-----The %i Epoch mean-ssim is :%f-----" %
              (epoch, np.mean(_ssim)))
        with open("trainlog/indoor/i%i_%s.log" %
                  (config.recurrent_iter, config.lossfunc),
                  "a+", encoding="utf-8") as f:
            f.write("The %i Epoch mean-ssim is :%f" %
                    (epoch, np.mean(_ssim)) + "\n")
        indexX.append(epoch + 1)
        indexY.append(np.mean(_ssim))
        print(indexX, indexY)
        plt.plot(indexX, indexY, linewidth=2)
        plt.pause(0.1)
        plt.savefig("trainlog/i%i_%s.png" %
                    (config.recurrent_iter, config.lossfunc))
    torch.save(dehaze_net.state_dict(), config.snapshots_folder + "IRDN.pth")
def main():
    print('Loading dataset ...\n')
    dataset_train = TrainValDataset("train")
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              drop_last=True)
    dataset_val = TrainValDataset("val")
    loader_val = DataLoader(dataset=dataset_val,
                            num_workers=4,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            drop_last=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))
    print("# of validation samples: %d\n" % int(len(dataset_val)))

    # Build model
    model = Net()
    print_network(model)

    # Loss functions (size_average is deprecated; reduction='mean' is equivalent)
    criterion1 = nn.MSELoss(reduction='mean')
    criterion = SSIM()

    # Move to GPU
    if opt.use_gpu:
        model = model.cuda()
        criterion.cuda()
        criterion1 = criterion1.cuda()
        l1_loss = torch.nn.SmoothL1Loss().cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone,
                            gamma=0.5)  # learning rates

    # Record training
    writer = SummaryWriter(opt.save_path)

    # Load the latest model
    # initial_epoch = findLastCheckpoint(save_dir=opt.save_path)
    # if initial_epoch > 0:
    #     print('resuming by loading epoch %d' % initial_epoch)
    #     model.load_state_dict(torch.load(os.path.join(opt.save_path, 'net_epoch%d.pth' % initial_epoch)))
    initial_step = findLastCheckpoint_step(save_dir=opt.save_path)
    if initial_step > 0:
        print('resuming by loading step %d' % initial_step)
        model.load_state_dict(
            torch.load(
                os.path.join(opt.save_path, 'net_step%d.pth' % initial_step)))

    # Start training
    step = initial_step
    p = 0
    train_loss_sum = 0
    for epoch in range(opt.epochs):
        scheduler.step(step)
        for param_group in optimizer.param_groups:
            print('learning rate %f' % param_group["lr"])
        ## epoch training start
        for i, (input, target) in enumerate(loader_train, 0):
            # Training step (Variable wrapping is obsolete; tensors suffice)
            model.train()
            model.zero_grad()
            optimizer.zero_grad()
            if opt.use_gpu:
                input_train, target_train = input.cuda(), target.cuda()
            else:
                input_train, target_train = input, target
            out_train = model(input_train)
            loss_mse = criterion1(out_train, target_train)
            loss_ssim = criterion(out_train, target_train)
            # SSIM is a similarity in [0, 1], so (1 - SSIM) acts as a loss.
            loss = loss_mse + 0.2 * (1 - loss_ssim)
            loss.backward()
            optimizer.step()
            if i % 2 == 0:
                p = p + 1
                train_loss_sum = train_loss_sum + loss.item()

            # Training curve
            out_train = torch.clamp(out_train, 0., 1.)
            psnr_train = batch_PSNR(out_train, target_train, 1.)
            print("[epoch %d][%d/%d] loss: %.4f, PSNR_train: %.4f" %
                  (epoch + 1, i + 1, len(loader_train), loss.item(),
                   psnr_train))
            model.eval()

            # Log the scalar values
            if step % 10 == 0:
                writer.add_scalar('loss', loss.item(), step)
                writer.add_scalar('PSNR on training data', psnr_train, step)
            if step % opt.iter_epoch == 0:
                epoch_loss = train_loss_sum / p
                writer.add_scalar('epoch_loss', epoch_loss, step)
                p = 0
                train_loss_sum = 0
            if step % opt.save_freq == 0:
                torch.save(model.state_dict(),
                           os.path.join(opt.save_path, 'net_latest.pth'))
                torch.save(
                    model.state_dict(),
                    os.path.join(opt.save_path, 'net_step%d.pth' % step))
            step += 1
import loaddata
from PSNR import PSNR
from SSIM import SSIM

if __name__ == '__main__':
    tfilename = input('--target_dir:')
    ofilename = input('--output_dir:')
    target_data = loaddata.Dataset(tfilename)
    output_data = loaddata.Dataset(ofilename)
    result = open('result.xls', 'w')
    result.write('Name\tPSNR\tSSIM\n')
    for i in range(1):  # replace 1 with target_data.getlen() to process all images
        output_img = output_data.getitem(i)
        target_img = target_data.getitem(i)
        name = target_data.getname(i)
        print("NAME:", name)
        psnr = PSNR(output_img, target_img)
        print("PSNR:", psnr)
        ssim = SSIM(output_img, target_img)
        print("SSIM:", ssim)
        result.write(name + '\t' + str(psnr) + '\t' + str(ssim) + '\n')
    result.close()
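# For reference, a minimal NumPy sketch of the PSNR metric tabulated above,
# assuming 8-bit images; the imported PSNR module may differ in details.
import numpy as np

def psnr_sketch(output_img, target_img, max_val=255.0):
    mse = np.mean((output_img.astype(np.float64) -
                   target_img.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)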
def train(config):
    dehaze_net = model.MSDFN().cuda()
    dehaze_net.apply(weights_init)

    train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                               config.hazy_images_path,
                                               config.label_images_path)
    val_dataset = dataloader.dehazing_loader(config.orig_images_path_val,
                                             config.hazy_images_path_val,
                                             config.label_images_path_val,
                                             mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.val_batch_size,
                                             shuffle=False,
                                             num_workers=config.num_workers,
                                             pin_memory=True)
    criterion = SSIM()
    criterion_mse = nn.MSELoss().cuda()  # the MSE branch below needs an actual MSE criterion
    comput_ssim = SSIM()
    dehaze_net.train()

    zt = 1
    Iters = 0
    indexX = []
    indexY = []
    # NOTE: the loop starts at epoch 1, so the `epoch == 0` branch below never fires.
    for epoch in range(1, config.num_epochs):
        # Hand-tuned stepwise learning-rate schedule.
        if epoch == 0:
            config.lr = 0.0001
        elif epoch == 1:
            config.lr = 0.00009
        elif 1 < epoch <= 3:
            config.lr = 0.00006
        elif 3 < epoch <= 5:
            config.lr = 0.00003
        elif 5 < epoch <= 7:
            config.lr = 0.00001
        elif 7 < epoch <= 9:
            config.lr = 0.000009
        elif 9 < epoch <= 11:
            config.lr = 0.000006
        elif 11 < epoch <= 13:
            config.lr = 0.000003
        else:  # epoch > 13
            config.lr = 0.000001
        optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr)
        print("now lr == %f" % config.lr)
        print("*" * 80 + "Epoch %i" % epoch + "*" * 80)

        for iteration, (img_clean, img_haze,
                        img_depth) in enumerate(train_loader):
            img_clean = img_clean.cuda()
            img_haze = img_haze.cuda()
            img_depth = img_depth.cuda()
            try:
                clean_image = dehaze_net(img_haze, img_depth)
                if config.lossfunc == "MSE":
                    loss = criterion_mse(clean_image, img_clean)  # MSE loss
                else:
                    loss = -criterion(img_clean, clean_image)  # negative SSIM loss
                # indexX.append(loss.item())
                # indexY.append(iteration)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(dehaze_net.parameters(),
                                               config.grad_clip_norm)
                optimizer.step()
                Iters += 1
                if ((iteration + 1) % config.display_iter) == 0:
                    print("Loss at iteration", iteration + 1, ":", loss.item())
                if ((iteration + 1) % config.snapshot_iter) == 0:
                    torch.save(
                        dehaze_net.state_dict(),
                        config.snapshots_folder + "Epoch" + str(epoch) + '.pth')
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(e)
                    torch.cuda.empty_cache()
                else:
                    raise e

        # Validation stage
        _ssim = []
        print("start Val!")
        with torch.no_grad():
            for iteration1, (img_clean, img_haze,
                             img_depth) in enumerate(val_loader):
                print("va1 : %s" % str(iteration1))
                img_clean = img_clean.cuda()
                img_haze = img_haze.cuda()
                img_depth = img_depth.cuda()
                clean_image = dehaze_net(img_haze, img_depth)
                _s = comput_ssim(img_clean, clean_image)
                _ssim.append(_s.item())
                # NOTE: both calls write to the same path, so the second image
                # (dehazed output only) overwrites the first (the triptych).
                torchvision.utils.save_image(
                    torch.cat((img_haze, img_clean, clean_image), 0),
                    config.sample_output_folder + "/epoch%s" % epoch + "/" +
                    str(iteration1 + 1) + ".jpg")
                torchvision.utils.save_image(
                    clean_image,
                    config.sample_output_folder + "/epoch%s" % epoch + "/" +
                    str(iteration1 + 1) + ".jpg")
        _ssim = np.array(_ssim)
        print("-----The %i Epoch mean-ssim is :%f-----" %
              (epoch, np.mean(_ssim)))
        with open("trainlog/%s%s.log" % (config.lossfunc, config.actfuntion),
                  "a+", encoding="utf-8") as f:
            f.write("[%i,%f]" % (epoch, np.mean(_ssim)) + "\n")
        indexX.append(epoch + 1)
        indexY.append(np.mean(_ssim))
        print(indexX, indexY)
        plt.plot(indexX, indexY, linewidth=2)
        plt.pause(0.1)
        plt.savefig("trainlog/%s%s.png" % (config.lossfunc, config.actfuntion))
    torch.save(dehaze_net.state_dict(), config.snapshots_folder + "MSDFN.pth")
def main():
    print('Loading dataset ...\n')
    dataset_train = Dataset(data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batch_size,
                              shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model = PRN_r(recurrent_iter=opt.recurrent_iter, use_GPU=opt.use_gpu)
    print_network(model)

    # Loss function
    # criterion = nn.MSELoss(size_average=False)
    criterion = SSIM()

    # Move to GPU
    if opt.use_gpu:
        model = model.cuda()
        criterion.cuda()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer, milestones=opt.milestone,
                            gamma=0.2)  # learning rates

    # Record training
    writer = SummaryWriter(opt.save_path)

    # Load the latest model
    initial_epoch = findLastCheckpoint(save_dir=opt.save_path)
    if initial_epoch > 0:
        print('resuming by loading epoch %d' % initial_epoch)
        model.load_state_dict(
            torch.load(
                os.path.join(opt.save_path,
                             'net_epoch%d.pth' % initial_epoch)))

    # Start training
    step = 0
    for epoch in range(initial_epoch, opt.epochs):
        scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            print('learning rate %f' % param_group["lr"])
        ## epoch training start
        for i, (input_train, target_train) in enumerate(loader_train, 0):
            model.train()
            model.zero_grad()
            optimizer.zero_grad()
            if opt.use_gpu:
                input_train, target_train = input_train.cuda(), target_train.cuda()
            out_train, _ = model(input_train)
            # SSIM is a similarity, so maximize it by minimizing its negative.
            pixel_metric = criterion(target_train, out_train)
            loss = -pixel_metric
            loss.backward()
            optimizer.step()

            # Training curve
            model.eval()
            out_train, _ = model(input_train)
            out_train = torch.clamp(out_train, 0., 1.)
            psnr_train = batch_PSNR(out_train, target_train, 1.)
            print("[epoch %d][%d/%d] loss: %.4f, pixel_metric: %.4f, PSNR: %.4f" %
                  (epoch + 1, i + 1, len(loader_train), loss.item(),
                   pixel_metric.item(), psnr_train))
            if step % 10 == 0:
                # Log the scalar values
                writer.add_scalar('loss', loss.item(), step)
                writer.add_scalar('PSNR on training data', psnr_train, step)
            step += 1
        ## epoch training end

        # Log the images
        model.eval()
        out_train, _ = model(input_train)
        out_train = torch.clamp(out_train, 0., 1.)
        im_target = utils.make_grid(target_train.data, nrow=8,
                                    normalize=True, scale_each=True)
        im_input = utils.make_grid(input_train.data, nrow=8,
                                   normalize=True, scale_each=True)
        im_derain = utils.make_grid(out_train.data, nrow=8,
                                    normalize=True, scale_each=True)
        writer.add_image('clean image', im_target, epoch + 1)
        writer.add_image('rainy image', im_input, epoch + 1)
        writer.add_image('deraining image', im_derain, epoch + 1)

        # Save model
        torch.save(model.state_dict(),
                   os.path.join(opt.save_path, 'net_latest.pth'))
        if epoch % opt.save_freq == 0:
            torch.save(
                model.state_dict(),
                os.path.join(opt.save_path, 'net_epoch%d.pth' % (epoch + 1)))
if __name__ == "__main__": # Test on real images. imgFn_0 = "/home/yaoyu/Transient/CTBridgeGirder02Aug27Side01_05/000122/000122_L_VC.png" imgFn_1 = "/home/yaoyu/Transient/CTBridgeGirder02Aug27Side01_05/000122/000122_R_VC.png" # Open the two images. img_0 = cv2.imread(imgFn_0, cv2.IMREAD_UNCHANGED) img_1 = cv2.imread(imgFn_1, cv2.IMREAD_UNCHANGED) print("Read \n%s and\n%s" % ( imgFn_0, imgFn_1 )) # Convert the image into torch Tensor. img_0 = cv_2_tensor(img_0, dtype=np.float32, flagCuda=True) img_1 = cv_2_tensor(img_1, dtype=np.float32, flagCuda=True) with TorchTracemalloc() as tt: # SSIM. ssim3 = SSIM(channel=3) res = ssim3(img_0, img_1) print("res.size() = {}.".format(res.size())) res = res.squeeze(0).squeeze(0).cpu().numpy() # Statistics. print("res: mean = {}, max = {}, min = {}.".format( res.mean(), res.max(), res.min() )) res2 = ssim3(img_0, img_1) print("GPU memory used = %fMB, peaked = %fMB. " % ( tt.used, tt.peaked ))
if opt.use_gpu:
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id

try:
    os.makedirs(opt.log_dir)
except OSError:
    pass
try:
    os.makedirs(opt.model_dir)
except OSError:
    pass

cudnn.benchmark = True
criterion = SSIM()


def sample_generator(netED, gt):
    # Draw a latent code, decode it into a rain layer, and add it to the
    # clean image to synthesize a rainy input.
    random_z = torch.randn(gt.shape[0], opt.nz).cuda()
    rain_make = netED.sample(random_z)  # extract G
    input_make = rain_make + gt
    return input_make


def train_model(net, netED, datasets, optimizer, lr_scheduler):
    NReal = ceil(opt.batchSize / (1 + opt.fake_ratio))
    data_loader = DataLoader(datasets,
                             batch_size=opt.batchSize,
                             shuffle=True,
                             num_workers=int(opt.workers),
def derain_loss(derain_output_tensor, clean_tensor):
    # Blend a structural (SSIM) term and a pixel-wise (L1) term, weighted 0.85 / 0.15.
    criterion = SSIM().cuda()
    criterion1 = nn.L1Loss().cuda()
    loss = 0.85 * (1 - criterion(derain_output_tensor, clean_tensor)) + \
           (1 - 0.85) * criterion1(derain_output_tensor, clean_tensor)
    return loss
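# A hypothetical smoke test of derain_loss with random tensors; the shapes are
# assumptions, and a CUDA device is required to match the .cuda() calls above.
import torch

if __name__ == "__main__":
    derained = torch.rand(1, 3, 64, 64).cuda()
    clean = torch.rand(1, 3, 64, 64).cuda()
    print(derain_loss(derained, clean).item())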
activation_types = {
    'ReLU': nn.ReLU(),
    'Leaky': nn.LeakyReLU(),
    'PReLU': nn.PReLU(),
    'ELU': nn.ELU(),
    'SELU': nn.SELU(),
    'SILU': SILU()
}
optimizer_types = {
    'Adam': optim.Adam,
    'RMSprop': optim.RMSprop,
    'SGD': optim.SGD
}

np.random.seed(42)
accuracy = SSIM(args.dimension)
model = generator_types[args.generator]
# Fall back to sensible defaults when an unknown key is given.
deconvLayer = (deconv_types[args.deconv]
               if args.deconv in deconv_types else deconv_types['upsample'])
function = (activation_types[args.activation]
            if args.activation in activation_types
            else activation_types['Leaky'])
generator = model(dimension=args.dimension,
                  deconv=deconvLayer,
                  activation=function,
                  drop_out=args.drop_out)
criterion = criterion_types[args.criterion](dimension=args.dimension)
deconvolution_dataset = operation_types[args.operation]
augmentations = {'train': True, 'val': False}
shufles = {'train': True, 'val': False}
net = net.cuda()

# Define the kernel-predicting network (FCN).
n_k = 200
# pdb.set_trace()
net_input_kernel = get_noise(n_k, INPUT, (1, 1))
net_input_kernel = net_input_kernel.squeeze().detach()

net_kernel = fcn(n_k, opt.kernel_size[0] * opt.kernel_size[1])
net_kernel = net_kernel.cuda()

# Losses
mse = torch.nn.MSELoss().cuda()
ssim = SSIM().cuda()

# Optimizer: the kernel network uses its own, smaller learning rate.
optimizer = torch.optim.Adam(
    [{'params': net.parameters()},
     {'params': net_kernel.parameters(), 'lr': 1e-4}],
    lr=LR)
scheduler = MultiStepLR(optimizer, milestones=[2000, 3000],
                        gamma=0.5)  # learning rates

# initialization inputs
save = 200
scales = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device = %s" % device)
reg_noise_std = 0.001
LR = 0.02
l = channels == 1
patchsize = 35
kernel_size = [initial_kernel_size, initial_kernel_size]
img_size = [initial_image_size, initial_image_size]
# Pad the input so a full (valid) convolution yields the target image size.
input_size = (img_size[0] + kernel_size[0] - 1,
              img_size[1] + kernel_size[1] - 1)
scale = 1
pixelloss = torch.nn.MSELoss()
ssim = SSIM()

# 3x3 Laplacian-style gradient operator, replicated per channel and frozen.
gradop = [[-1.0, -1.0, -1.0],
          [-1.0, 8.0, -1.0],
          [-1.0, -1.0, -1.0]]
lap = torch.from_numpy(np.array(gradop).astype(np.float32)).to(device)
lap = lap.unsqueeze(0).unsqueeze(0)
if channels > 1:
    lap = torch.cat([lap] * channels, 0)
lap.requires_grad = False
lap.trainable = False

ratio = 1.1
max_g_weight = 16
latest = 3000
dc_weight = 0.1
while scale <= scales:
    l0_weight = 1
    g_weight = ratio * l0_weight
    print("current_scale = %s, ratio = %s" % (scale, ratio))
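# A minimal sketch (an assumption, not part of the original script) of how the
# per-channel Laplacian kernel `lap` built above is usually applied, namely as
# a depthwise convolution:
import torch.nn.functional as F

def laplacian_response_sketch(img, lap, channels):
    # img: (B, C, H, W); lap: (C, 1, 3, 3) after the torch.cat above.
    return F.conv2d(img, lap, padding=1, groups=channels)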
np.save('a.npy', train_dataset_array_a)
np.save('b.npy', train_dataset_array_b)
save = True

itera = iter(Iterator(train_dataset_array_a, args.batch_size))
iterb = iter(Iterator(train_dataset_array_b, args.batch_size))

model = AutoEncoder(image_channels=3).to(device)
discriminator = Discriminator().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer_b = torch.optim.Adam(model.parameters(), lr=1e-3)
mse = nn.L1Loss()  # NOTE: named `mse` but actually an L1 loss
ssim_loss = SSIM()


def dis_loss(prob_real_is_real, prob_fake_is_real):
    # Standard discriminator loss with a small epsilon for numerical stability.
    EPS = 1e-12
    return torch.mean(-(torch.log(prob_real_is_real + EPS) +
                        torch.log(1 - prob_fake_is_real + EPS)))


def gen_loss(original, recon_structed, validity=None):
    # Negative SSIM as the reconstruction term; add an adversarial term when
    # a discriminator output is supplied.
    ssim_l = -ssim_loss(recon_structed, original)
    if validity is not None:  # `if validity:` is ambiguous for tensors
        gen_loss_GAN = torch.mean(-torch.log(validity + 1e-12))
        # gen_loss_L1 = torch.mean(torch.abs(original - recon_structed))
        return 5 * ssim_l + gen_loss_GAN
    else:
        return ssim_l
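# A hedged sketch of how dis_loss / gen_loss above might be wired into one
# adversarial step; the batch variables and the discriminator call pattern
# are assumptions, not the original training loop.
# real_a = next(itera).to(device)
# recon = model(real_a)
# prob_real = discriminator(real_a)
# prob_fake = discriminator(recon.detach())
# d_loss = dis_loss(prob_real, prob_fake)           # update the discriminator
# g_loss = gen_loss(real_a, recon,
#                   validity=discriminator(recon))  # update the generator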