def _test_loop(path, batch_size, datagen, img_height, img_width, iteration,
               large_img_height, large_img_width, model, total_psnr, prefix,
               nb_images):
    """Evaluate ``model`` on images read from ``path`` and accumulate PSNR.

    Images are loaded at ``(large_img_width, large_img_height)``, shrunk with
    ``imresize`` to ``(img_width, img_height)``, passed through
    ``model.predict_on_batch`` and compared with the originals via ``psnr``.
    For every sample, the original and the prediction are written as PNG files
    whose names start with ``base_test_images + prefix``.

    Batches are treated as channels-first (they are transposed with
    ``(0, 2, 3, 1)`` before resizing). Pixel values are treated as [0, 1]
    (they are multiplied by 255 before saving), presumably because ``datagen``
    rescales by 1/255 -- confirm with the caller. The model output is clamped
    to [0, 1] before scoring and multiplied by 255 before saving.

    :param path: directory handed to ``datagen.flow_from_directory``.
    :param batch_size: requested number of images per batch.
    :param datagen: generator object providing ``flow_from_directory``.
    :param img_height: height of the low-resolution model input.
    :param img_width: width of the low-resolution model input.
    :param iteration: number of images already processed (start value).
    :param large_img_height: height the images are loaded at.
    :param large_img_width: width the images are loaded at.
    :param model: object exposing ``predict_on_batch``.
    :param total_psnr: running PSNR sum (start value).
    :param prefix: string appended to ``base_test_images`` for output files.
    :param nb_images: stop once at least this many images were processed.
    :return: ``total_psnr`` plus the per-image PSNR values summed here.
    """
    for x in datagen.flow_from_directory(path, class_mode=None, batch_size=batch_size,
                                         target_size=(large_img_width, large_img_height)):
        t1 = time.time()
        # The directory iterator can yield a final batch smaller than
        # batch_size; size every buffer and loop from the actual batch so we
        # never index past the end of x.
        n = x.shape[0]

        # resize images (imresize works on channels-last arrays)
        x_temp = x.transpose((0, 2, 3, 1))
        x_generator = np.empty((n, img_width, img_height, 3))
        for j in range(n):
            # NOTE(review): if imresize is scipy.misc.imresize, float input is
            # byte-scaled by its own min/max before resizing -- confirm this is
            # the same preprocessing the model was trained with.
            x_generator[j, :, :, :] = imresize(x_temp[j], (img_width, img_height))
        x_generator = x_generator.transpose((0, 3, 1, 2))

        output_image_batch = model.predict_on_batch(x_generator)

        batch_psnr = 0.0
        for x_i in range(n):
            # clip(out * 255, 0, 255) / 255 clamps the prediction to [0, 1].
            batch_psnr += psnr(
                x[x_i], np.clip(output_image_batch[x_i] * 255, 0, 255) / 255.)
        total_psnr += batch_psnr
        average_psnr = batch_psnr / n

        iteration += n
        t2 = time.time()
        print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
              % (t2 - t1, n, average_psnr))

        for x_i in range(n):
            real_path = base_test_images + prefix + "_iteration_%d_num_%d_real_.png" % (
                iteration, x_i + 1)
            generated_path = base_test_images + prefix + "_iteration_%d_num_%d_generated.png" % (
                iteration, x_i + 1)

            # Convert both images back to channels-last uint8 for saving.
            val_x = x[x_i].copy() * 255.
            val_x = val_x.transpose((1, 2, 0))
            val_x = np.clip(val_x, 0, 255).astype('uint8')

            output_image = output_image_batch[x_i] * 255
            output_image = output_image.transpose((1, 2, 0))
            output_image = np.clip(output_image, 0, 255).astype('uint8')

            imsave(real_path, val_x)
            imsave(generated_path, output_image)

        if iteration >= nb_images:
            break

    return total_psnr
def _test_loop(path, batch_size, datagen, img_height, img_width, iteration,
               large_img_height, large_img_width, model, total_psnr, prefix,
               nb_images):
    """Evaluate ``model`` on images read from ``path`` and accumulate PSNR.

    Images are loaded at ``(large_img_width, large_img_height)``, shrunk with
    ``imresize`` to ``(img_width, img_height)``, passed through
    ``model.predict_on_batch`` and compared with the originals via ``psnr``.
    For every sample, the original and the prediction are written as PNG files
    whose names start with ``base_test_images + prefix``.

    Batches are treated as channels-first (transposed with ``(0, 2, 3, 1)``
    before resizing) with values in [0, 1] (multiplied by 255 before saving).
    In this variant, the model output is treated as being on a 0-255 scale: it
    is divided by 255 for scoring and saved without rescaling.

    :param path: directory handed to ``datagen.flow_from_directory``.
    :param batch_size: requested number of images per batch.
    :param datagen: generator object providing ``flow_from_directory``.
    :param img_height: height of the low-resolution model input.
    :param img_width: width of the low-resolution model input.
    :param iteration: number of images already processed (start value).
    :param large_img_height: height the images are loaded at.
    :param large_img_width: width the images are loaded at.
    :param model: object exposing ``predict_on_batch``.
    :param total_psnr: running PSNR sum (start value).
    :param prefix: string appended to ``base_test_images`` for output files.
    :param nb_images: stop once at least this many images were processed.
    :return: ``total_psnr`` plus the per-image PSNR values summed here.
    """
    for x in datagen.flow_from_directory(path, class_mode=None, batch_size=batch_size,
                                         target_size=(large_img_width, large_img_height)):
        t1 = time.time()
        # The directory iterator can yield a final batch smaller than
        # batch_size; size every buffer and loop from the actual batch.
        n = x.shape[0]

        # resize images (imresize works on channels-last arrays)
        x_temp = x.transpose((0, 2, 3, 1))
        x_generator = np.empty((n, img_width, img_height, 3))
        for j in range(n):
            # NOTE(review): if imresize is scipy.misc.imresize, float input is
            # byte-scaled by its own min/max before resizing -- confirm.
            x_generator[j, :, :, :] = imresize(x_temp[j], (img_width, img_height))
        x_generator = x_generator.transpose((0, 3, 1, 2))

        output_image_batch = model.predict_on_batch(x_generator)

        batch_psnr = 0.0
        for x_i in range(n):
            # Bring the 0-255 prediction onto the [0, 1] scale of x.
            batch_psnr += psnr(x[x_i], output_image_batch[x_i] / 255.)
        total_psnr += batch_psnr
        average_psnr = batch_psnr / n

        iteration += n
        t2 = time.time()
        print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
              % (t2 - t1, n, average_psnr))

        for x_i in range(n):
            real_path = base_test_images + prefix + "_iteration_%d_num_%d_real_.png" % (
                iteration, x_i + 1)
            generated_path = base_test_images + prefix + "_iteration_%d_num_%d_generated.png" % (
                iteration, x_i + 1)

            # Convert both images back to channels-last uint8 for saving.
            val_x = x[x_i].copy() * 255.
            val_x = val_x.transpose((1, 2, 0))
            val_x = np.clip(val_x, 0, 255).astype('uint8')

            output_image = output_image_batch[x_i]
            output_image = output_image.transpose((1, 2, 0))
            output_image = np.clip(output_image, 0, 255).astype('uint8')

            imsave(real_path, val_x)
            imsave(generated_path, output_image)

        if iteration >= nb_images:
            break

    return total_psnr
def validate(loader, model, epoch, d1, d2, blind, noise_level):
    """Evaluate ``model`` on ``loader`` and return mean PSNR, SSIM and L1 loss.

    Gaussian noise with standard deviation ``nls[i] / 255`` is added to each
    input ``y`` (then clamped to [0, 1]) before it is passed to the model. With
    ``blind`` the noise level ramps linearly from 0.5 to ``noise_level`` across
    the batches; otherwise every batch uses ``noise_level``. The prediction and
    the target are cropped with ``utils.crop_valid`` before scoring.

    The model is switched to eval mode (``model.train(False)``) and is not
    switched back here.

    :param loader: DataLoader yielding ``(x, y, k, d)`` tuples.
    :param model: network exposing ``weight`` (entries 0 and 1) and called as
        ``model(y, k, d, k1, k2, d1, d2)``; the last element of its output is used.
    :param epoch: unused; kept for interface compatibility.
    :param d1: tensor expanded along a leading batch dimension.
    :param d2: tensor expanded along a leading batch dimension.
    :param blind: ramp the noise level across batches instead of fixing it.
    :param noise_level: noise standard deviation on the 0-255 scale.
    :return: ``(mean_psnr, mean_ssim, mean_l1)`` averaged over batches.
    """
    val_psnr = 0
    val_ssim = 0
    val_l1 = 0
    model.train(False)

    # Expand once to the loader's nominal batch size and slice per batch
    # below: the last batch is smaller when the dataset size is not a multiple
    # of batch_size, and the full-size views would not match it.
    k1 = model.weight[0].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    k2 = model.weight[1].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    d1 = d1.expand(loader.batch_size, -1, -1, -1)
    d2 = d2.expand(loader.batch_size, -1, -1, -1)

    # pre-create noise levels (one per batch, 0-255 scale)
    if blind:
        nls = np.linspace(0.5, noise_level, len(loader))
    else:
        nls = noise_level * np.ones(len(loader))

    with torch.no_grad():
        for i, data in tqdm.tqdm(enumerate(loader), total=len(loader)):
            x, y, k, d = data
            x = x.to(device)
            y = y.to(device)
            k = k.to(device)
            d = d.to(device)
            b = y.size(0)

            nl = nls[i] / 255
            # Out-of-place, so the tensor produced by the loader is not mutated.
            y = (y + nl * torch.randn_like(y)).clamp(0, 1)

            hat_x = model(y, k, d, k1[:b], k2[:b], d1[:b], d2[:b])[-1]
            hat_x.clamp_(0, 1)
            hat_x = utils.crop_valid(hat_x, k)
            x = utils.crop_valid(x, k)

            val_psnr += loss.psnr(hat_x, x)
            val_ssim += loss.ssim(hat_x, x)
            val_l1 += F.l1_loss(hat_x, x).item()

    n_batches = len(loader)
    return val_psnr / n_batches, val_ssim / n_batches, val_l1 / n_batches
def validate(loader, model, epoch, d1, d2, blind, noise_level):
    """Evaluate ``model`` on ``loader`` and return mean PSNR, SSIM and L1 loss.

    Each batch provides ``(x, y, mag, ori)``. ``ori`` is remapped with
    ``(90 - ori + 360) mod 180``, used to compute ``labels`` via
    ``utils.get_labels`` and then converted from degrees to radians. Gaussian
    noise with standard deviation ``nls[i] / 255`` is added to ``y`` (then
    clamped to [0, 1]). With ``blind`` the noise level ramps linearly from 0.5
    to ``noise_level`` across batches; otherwise it is fixed.

    The model is switched to eval mode (``model.train(False)``) and is not
    switched back here.

    :param loader: DataLoader yielding ``(x, y, mag, ori)`` tuples.
    :param model: network exposing ``weight`` (entries 0 and 1) and called as
        ``model(y, mag, ori, labels, k1, k2, d1, d2)``; the last output is used.
    :param epoch: unused; kept for interface compatibility.
    :param d1: tensor expanded along a leading batch dimension.
    :param d2: tensor expanded along a leading batch dimension.
    :param blind: ramp the noise level across batches instead of fixing it.
    :param noise_level: noise standard deviation on the 0-255 scale.
    :return: ``(mean_psnr, mean_ssim, mean_l1)`` averaged over batches.
    """
    val_psnr = 0
    val_ssim = 0
    val_l1 = 0
    model.train(False)

    # Expand once to the loader's nominal batch size and slice per batch
    # below, so a smaller final batch still gets matching tensors.
    k1 = model.weight[0].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    k2 = model.weight[1].unsqueeze(0).expand(loader.batch_size, -1, -1, -1)
    d1 = d1.expand(loader.batch_size, -1, -1, -1)
    d2 = d2.expand(loader.batch_size, -1, -1, -1)

    # pre-create noise levels (one per batch, 0-255 scale)
    if blind:
        nls = np.linspace(0.5, noise_level, len(loader))
    else:
        nls = noise_level * np.ones(len(loader))

    with torch.no_grad():
        for i, data in tqdm.tqdm(enumerate(loader), total=len(loader)):
            x, y, mag, ori = data
            x = x.to(device)
            y = y.to(device)
            mag = mag.to(device)
            ori = ori.to(device)
            b = y.size(0)

            # Remap the orientation into [0, 180) (for ori > -270), derive the
            # labels from it, then convert degrees to radians for the model.
            ori = (90 - ori).add(360).fmod(180)
            labels = utils.get_labels(mag, ori)
            ori = ori * np.pi / 180

            nl = nls[i] / 255
            # Out-of-place, so the tensor produced by the loader is not mutated.
            y = (y + nl * torch.randn_like(y)).clamp(0, 1)

            hat_x = model(y, mag, ori, labels, k1[:b], k2[:b], d1[:b], d2[:b])[-1]
            hat_x.clamp_(0, 1)

            val_psnr += loss.psnr(hat_x, x)
            val_ssim += loss.ssim(hat_x, x)
            val_l1 += F.l1_loss(hat_x, x).item()

    n_batches = len(loader)
    return val_psnr / n_batches, val_ssim / n_batches, val_l1 / n_batches
def eval(self):
    """Evaluate ``self.model`` on ``self.valloader`` and return mean SSIM and PSNR.

    The model is put in eval mode for the pass and switched back to training
    mode afterwards, also when evaluation raises. A ``tqdm`` bar shows progress
    and finally displays the averages.

    :return: ``(avg_ssim, avg_psnr)`` as Python floats, averaged over batches.
    """
    self.model.eval()
    psnr_losses = []
    ssim_losses = []
    try:
        # No gradients are needed for evaluation. Without no_grad, each metric
        # tensor stored in the lists below keeps its autograd graph alive until
        # the whole pass is finished.
        with torch.no_grad(), tqdm(total=len(self.valloader)) as t:
            # show batch evaluate process
            t.set_description('evaluating...')
            for LR, HR in self.valloader:
                LR = LR.to(self.device)
                HR = HR.to(self.device)
                output = self.model(LR)
                # To score only the Y channel of RGB (3-channel) images, convert
                # both output and HR with rgb2y_tensor before the metrics.
                ssim_losses.append(ssim(output, HR))
                psnr_losses.append(psnr(output, HR))
                t.update()

            avg_ssim = torch.stack(ssim_losses, dim=0).mean().item()
            avg_psnr = torch.stack(psnr_losses, dim=0).mean().item()
            t.set_postfix(avg_psnr=f'{avg_psnr:.010f}', avg_ssim=f'{avg_ssim:.010f}')
            t.set_description('evaluate')
    finally:
        self.model.train()
    return avg_ssim, avg_psnr
def _train_model(self, image_dir, nb_images=80000, nb_epochs=10, pre_train_srgan=False,
                 pre_train_discriminator=False, load_generative_weights=False,
                 load_discriminator_weights=False, save_loss=True, disc_train_flip=0.1):
    """Train the SRGAN generator and/or discriminator on images from ``image_dir``.

    Modes:

    * ``pre_train_srgan``: train only the generator (with the VGG branch) via
      ``bypass_fit`` on ``self.srgan_model_``.
    * ``pre_train_discriminator``: train only ``self.discriminative_model_``
      on generated vs. real images with noisy, smoothed labels.
    * neither: alternate one discriminator step and one generator step per batch.

    Whenever ``iteration`` is a non-zero multiple of 50 (and the discriminator
    is not being pre-trained), the batch is used only for validation. PSNR is
    computed and sample images are written to ``val_images/`` instead of
    training on it. Weights and loss history are saved whenever ``iteration``
    is a non-zero multiple of 1000, and again at the end.

    :param image_dir: directory for ``ImageDataGenerator.flow_from_directory``.
    :param nb_images: number of images per epoch before starting the next one.
    :param nb_epochs: number of epochs.
    :param pre_train_srgan: train only the generator.
    :param pre_train_discriminator: train only the discriminator.
    :param load_generative_weights: try to load generator weights first.
    :param load_discriminator_weights: try to load discriminator weights first.
    :param save_loss: record losses / PSNR in the history passed to
        ``_save_loss_history``.
    :param disc_train_flip: probability of swapping the real/fake labels for a
        discriminator step (noisy labels).
    """
    assert self.img_width >= 16, "Minimum image width must be at least 16"
    assert self.img_height >= 16, "Minimum image height must be at least 16"

    # Loading is best-effort. Catch Exception rather than using a bare except,
    # so Ctrl-C / SystemExit during loading are not swallowed.
    if load_generative_weights:
        try:
            self.generative_model_.load_weights(self.generative_network.sr_weights_path)
            print("Generator weights loaded.")
        except Exception:
            print("Could not load generator weights.")

    if load_discriminator_weights:
        try:
            self.discriminative_network.load_gan_weights(self.srgan_model_)
            print("Discriminator weights loaded.")
        except Exception:
            print("Could not load discriminator weights.")

    datagen = ImageDataGenerator(rescale=1. / 255)
    # High-resolution (target) size is 4x the generator input size.
    img_width = self.img_width * 4
    img_height = self.img_height * 4

    early_stop = False
    iteration = 0
    prev_improvement = -1

    # Always bind loss_history: _save_loss_history() is called unconditionally
    # below and previously raised UnboundLocalError when save_loss=False.
    loss_history = {}
    if save_loss:
        if pre_train_srgan:
            loss_history = {'generator_loss': [],
                            'val_psnr': [], }
        elif pre_train_discriminator:
            loss_history = {'discriminator_loss': [],
                            'discriminator_acc': [], }
        else:
            loss_history = {'discriminator_loss': [],
                            'discriminator_acc': [],
                            'generator_loss': [],
                            'val_psnr': [], }

    y_vgg_dummy = np.zeros((self.batch_size * 2, 3, img_width // 32, img_height // 32))  # 5 Max Pools = 2 ** 5 = 32

    print("Training SRGAN network")
    for i in range(nb_epochs):
        print()
        print("Epoch : %d" % (i + 1))

        # NOTE(review): flow_from_directory can yield a final batch smaller than
        # self.batch_size; the per-sample loops below assume full batches.
        for x in datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size,
                                             target_size=(img_width, img_height)):
            try:
                t1 = time.time()

                if not pre_train_srgan and not pre_train_discriminator:
                    x_vgg = x.copy() * 255  # VGG input [0 - 255 scale]

                # resize images: light Gaussian blur, then bicubic downscale
                x_temp = x.copy()
                x_temp = x_temp.transpose((0, 2, 3, 1))

                x_generator = np.empty((self.batch_size, self.img_width, self.img_height, 3))

                for j in range(self.batch_size):
                    img = gaussian_filter(x_temp[j], sigma=0.1)
                    img = imresize(img, (self.img_width, self.img_height), interp='bicubic')
                    x_generator[j, :, :, :] = img

                x_generator = x_generator.transpose((0, 3, 1, 2))

                if iteration % 50 == 0 and iteration != 0 and not pre_train_discriminator:
                    print("Validation image..")
                    output_image_batch = self.generative_network.get_generator_output(
                        x_generator, self.srgan_model_)
                    if isinstance(output_image_batch, list):
                        output_image_batch = output_image_batch[0]

                    mean_axis = (0, 2, 3) if K.image_dim_ordering() == 'th' else (0, 1, 2)

                    average_psnr = 0.0

                    print('gen img mean :', np.mean(output_image_batch / 255., axis=mean_axis))
                    print('val img mean :', np.mean(x, axis=mean_axis))

                    for x_i in range(self.batch_size):
                        average_psnr += psnr(x[x_i], np.clip(output_image_batch[x_i], 0, 255) / 255.)

                    average_psnr /= self.batch_size

                    if save_loss:
                        loss_history['val_psnr'].append(average_psnr)

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
                          (t2 - t1, self.batch_size, average_psnr))

                    for x_i in range(self.batch_size):
                        real_path = "val_images/epoch_%d_iteration_%d_num_%d_real_.png" % (
                            i + 1, iteration, x_i + 1)
                        generated_path = "val_images/epoch_%d_iteration_%d_num_%d_generated.png" % (
                            i + 1, iteration, x_i + 1)

                        val_x = x[x_i].copy() * 255.
                        val_x = val_x.transpose((1, 2, 0))
                        val_x = np.clip(val_x, 0, 255).astype('uint8')

                        output_image = output_image_batch[x_i]
                        output_image = output_image.transpose((1, 2, 0))
                        output_image = np.clip(output_image, 0, 255).astype('uint8')

                        imsave(real_path, val_x)
                        imsave(generated_path, output_image)

                    # Don't train on validation images for now.
                    # If nb_epochs > 1, validation images may also be used for
                    # training, so this is not strictly a validation measure;
                    # it is a check of what the network has learned.
                    continue

                if pre_train_srgan:
                    # Train only generator + vgg network
                    # Use custom bypass_fit to bypass the check for same input and output batch size
                    hist = bypass_fit(self.srgan_model_, [x_generator, x * 255], y_vgg_dummy,
                                      batch_size=self.batch_size, nb_epoch=1, verbose=0)
                    sr_loss = hist.history['loss'][0]

                    if save_loss:
                        loss_history['generator_loss'].extend(hist.history['loss'])

                    if prev_improvement == -1:
                        prev_improvement = sr_loss

                    improvement = (prev_improvement - sr_loss) / prev_improvement * 100
                    prev_improvement = sr_loss

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Generative Loss : %0.2f" % (iteration, nb_images, improvement, t2 - t1, sr_loss))
                elif pre_train_discriminator:
                    # Train only discriminator: first half of X is generated, second half real.
                    X_pred = self.generative_model_.predict(x_generator, self.batch_size)
                    X = np.concatenate((X_pred, x * 255))

                    # Using soft and noisy labels
                    if np.random.uniform() > disc_train_flip:
                        # give correct classifications
                        y_gan = [0] * self.batch_size + [1] * self.batch_size
                    else:
                        # give wrong classifications (noisy labels)
                        y_gan = [1] * self.batch_size + [0] * self.batch_size

                    # np.int was only an alias of int and no longer exists in NumPy >= 1.24.
                    y_gan = np.asarray(y_gan, dtype=int).reshape(-1, 1)
                    y_gan = to_categorical(y_gan, nb_classes=2)
                    y_gan = smooth_gan_labels(y_gan)

                    hist = self.discriminative_model_.fit(X, y_gan, batch_size=self.batch_size,
                                                          nb_epoch=1, verbose=0)

                    discriminator_loss = hist.history['loss'][-1]
                    discriminator_acc = hist.history['acc'][-1]

                    if save_loss:
                        loss_history['discriminator_loss'].extend(hist.history['loss'])
                        loss_history['discriminator_acc'].extend(hist.history['acc'])

                    if prev_improvement == -1:
                        prev_improvement = discriminator_loss

                    improvement = (prev_improvement - discriminator_loss) / prev_improvement * 100
                    prev_improvement = discriminator_loss

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Discriminator Loss / Acc : %0.4f / %0.2f" % (iteration, nb_images,
                                                                        improvement, t2 - t1,
                                                                        discriminator_loss,
                                                                        discriminator_acc))
                else:
                    # Train only discriminator, disable training of srgan
                    self.discriminative_network.set_trainable(self.srgan_model_, value=True)
                    self.generative_network.set_trainable(self.srgan_model_, value=False)

                    # First half of X is generated, second half real.
                    X_pred = self.generative_model_.predict(x_generator, self.batch_size)
                    X = np.concatenate((X_pred, x * 255))

                    # Using soft and noisy labels
                    if np.random.uniform() > disc_train_flip:
                        # give correct classifications
                        y_gan = [0] * self.batch_size + [1] * self.batch_size
                    else:
                        # give wrong classifications (noisy labels)
                        y_gan = [1] * self.batch_size + [0] * self.batch_size

                    y_gan = np.asarray(y_gan, dtype=int).reshape(-1, 1)
                    y_gan = to_categorical(y_gan, nb_classes=2)
                    y_gan = smooth_gan_labels(y_gan)

                    hist1 = self.discriminative_model_.fit(X, y_gan, verbose=0,
                                                           batch_size=self.batch_size, nb_epoch=1)

                    discriminator_loss = hist1.history['loss'][-1]

                    # Train only generator, disable training of discriminator
                    self.discriminative_network.set_trainable(self.srgan_model_, value=False)
                    self.generative_network.set_trainable(self.srgan_model_, value=True)

                    # Using soft labels: the generator wants its outputs classified as real (1).
                    y_model = [1] * self.batch_size
                    y_model = np.asarray(y_model, dtype=int).reshape(-1, 1)
                    y_model = to_categorical(y_model, nb_classes=2)
                    y_model = smooth_gan_labels(y_model)

                    # Use custom bypass_fit to bypass the check for same input and output batch size
                    hist2 = bypass_fit(self.srgan_model_, [x_generator, x, x_vgg], [y_model, y_vgg_dummy],
                                       batch_size=self.batch_size, nb_epoch=1, verbose=0)

                    generative_loss = hist2.history['loss'][0]

                    if save_loss:
                        loss_history['discriminator_loss'].extend(hist1.history['loss'])
                        loss_history['discriminator_acc'].extend(hist1.history['acc'])
                        loss_history['generator_loss'].extend(hist2.history['loss'])

                    if prev_improvement == -1:
                        prev_improvement = discriminator_loss

                    improvement = (prev_improvement - discriminator_loss) / prev_improvement * 100
                    prev_improvement = discriminator_loss

                    iteration += self.batch_size
                    t2 = time.time()
                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Discriminator Loss : %0.3f | Generative Loss : %0.3f" %
                          (iteration, nb_images, improvement, t2 - t1, discriminator_loss, generative_loss))

                if iteration % 1000 == 0 and iteration != 0:
                    print("Saving model weights.")
                    # Save predictive (SR network) weights
                    self._save_model_weights(pre_train_srgan, pre_train_discriminator)
                    self._save_loss_history(loss_history, pre_train_srgan, pre_train_discriminator, save_loss)

                if iteration >= nb_images:
                    break

            except KeyboardInterrupt:
                print("Keyboard interrupt detected. Stopping early.")
                early_stop = True
                break

        iteration = 0

        if early_stop:
            break

    print("Finished training SRGAN network. Saving model weights.")
    # Save predictive (SR network) weights
    self._save_model_weights(pre_train_srgan, pre_train_discriminator)
    self._save_loss_history(loss_history, pre_train_srgan, pre_train_discriminator, save_loss)
def _test_loop(path, batch_size, datagen, img_height, img_width, iteration,
               large_img_height, large_img_width, model, total_psnr, prefix,
               nb_images, normalized):
    """Evaluate ``model`` against a bicubic baseline and accumulate PSNR.

    Each image is downsampled to ``(img_width, img_height)`` and upsampled back
    to ``(large_img_width, large_img_height)`` with bicubic interpolation. The
    result is fed to ``model.predict_on_batch``, and the prediction is scored
    against the original with ``psnr``. For every sample, the original, the
    bicubic input and the prediction are saved as PNG files whose names start
    with ``base_test_images + prefix``. Arrays are handled channels-last.

    :param path: dataset directory.
    :param batch_size: requested number of images per iteration.
    :param datagen: image (augmentation) generator providing ``flow_from_directory``.
    :param img_height: height of the downsampled image.
    :param img_width: width of the downsampled image.
    :param iteration: number of images already processed (start value).
    :param large_img_height: height the images are loaded at and upsampled to.
    :param large_img_width: width the images are loaded at and upsampled to.
    :param model: network model exposing ``predict_on_batch``.
    :param total_psnr: running PSNR sum (start value).
    :param prefix: file-name prefix (save location) for the output images.
    :param nb_images: number of test images; stop once at least this many
        were processed.
    :param normalized: whether the model was trained on images normalized to
        [0, 1]. If so, inputs are divided by 255 and outputs are on [0, 1];
        otherwise both are on a 0-255 scale.
    :return: ``total_psnr`` plus the per-image PSNR values summed here.
    """
    # Scale of the model's input/output relative to 0-255 pixel values.
    value_scale = 255. if normalized else 1.

    for x in datagen.flow_from_directory(path, class_mode=None, batch_size=batch_size,
                                         target_size=(large_img_width, large_img_height)):
        t1 = time.time()
        # The directory iterator can yield a final batch smaller than
        # batch_size; size every buffer and loop from the actual batch.
        n = x.shape[0]

        x_generator = np.empty((n, large_img_width, large_img_height, 3))
        for j in range(n):
            # Downsample first, then bicubic-interpolate back to full size.
            # NOTE(review): if imresize is scipy.misc.imresize, float input is
            # byte-scaled by its own min/max -- confirm this matches training.
            img = imresize(x[j], (img_width, img_height))
            img = imresize(img, (large_img_width, large_img_height), interp='bicubic')
            x_generator[j, :, :, :] = img / value_scale

        output_image_batch = model.predict_on_batch(x_generator)

        batch_psnr = 0.0
        for x_i in range(n):
            # Bring the prediction onto the [0, 1] scale of x.
            batch_psnr += psnr(x[x_i], output_image_batch[x_i] / value_scale)
        total_psnr += batch_psnr
        average_psnr = batch_psnr / n

        iteration += n
        t2 = time.time()
        print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f"
              % (t2 - t1, n, average_psnr))

        for x_i in range(n):
            # Save the ground-truth, bicubic and generated images.
            real_path = base_test_images + prefix + "_iteration_%d_num_%d_real_.png" % (
                iteration, x_i + 1)
            bicubic_path = base_test_images + prefix + "_iteration_%d_num_%d_bicubic.png" % (
                iteration, x_i + 1)
            generated_path = base_test_images + prefix + "_iteration_%d_num_%d_generated.png" % (
                iteration, x_i + 1)

            val_x = np.clip(x[x_i] * 255., 0, 255).astype('uint8')
            input_img = np.clip(x_generator[x_i] * value_scale, 0, 255).astype('uint8')
            output_image = np.clip(output_image_batch[x_i] * value_scale, 0, 255).astype('uint8')

            imsave(real_path, val_x)
            imsave(bicubic_path, input_img)
            imsave(generated_path, output_image)

        if iteration >= nb_images:
            break

    return total_psnr
def test_individul(prefix='Set5', scale=4, mode='rgb', normalized=True):
    """Super-resolve every image of a benchmark set and report PSNR/SSIM.

    The images are read from ``set5_path`` with its last 6 characters removed,
    joined with ``prefix`` twice. Each image is:

    1. cropped so both sides are multiples of ``scale``;
    2. downscaled by ``scale`` and upscaled back with bicubic interpolation;
    3. fed (the bicubic result) to a freshly built ``SRModel`` with weights
       loaded.

    Bicubic and model outputs are scored against the original: PSNR on the
    full RGB arrays, SSIM on channel 0 of ``rgb2ycbcr``. Original, bicubic and
    generated images are saved under ``base_test_images + prefix``. Per-image
    and average scores are printed.

    :param prefix: benchmark set name (sub-directory name).
    :param scale: down/up-scaling factor.
    :param mode: ``'rgb'`` feeds all three channels; ``'ycrcb'`` feeds only
        channel 0 of ``rgb2ycbcr`` and reuses the bicubic channels 1 and 2.
    :param normalized: whether the model works on [0, 1] values (inputs are
        divided by 255, outputs multiplied by 255).
    :raises ValueError: if ``mode`` is neither ``'rgb'`` nor ``'ycrcb'``.
    """
    if mode not in ('rgb', 'ycrcb'):
        # Previously this surfaced as an UnboundLocalError on lr_test.
        raise ValueError("mode must be 'rgb' or 'ycrcb', got %r" % (mode,))

    pic_path = os.path.join(set5_path[0:len(set5_path) - 6], prefix)
    pic_path = os.path.join(pic_path, prefix)
    out_dir = base_test_images + prefix

    total_psnr_bicubic = 0
    total_psnr_generated = 0
    total_ssim_bicubic = 0
    total_ssim_generated = 0
    n_processed = 0

    for file_name in os.listdir(pic_path):
        stem = os.path.splitext(file_name)[0]
        real_rgb = imread(os.path.join(pic_path, file_name))

        # Image preprocessing: crop so both sides are multiples of scale.
        img_shape = real_rgb.shape
        img_width = int(img_shape[0] / scale) * scale
        img_height = int(img_shape[1] / scale) * scale
        real_rgb = real_rgb[0:img_width, 0:img_height]

        # Grayscale images come back 2-D; replicate them into three channels.
        if len(img_shape) == 2:
            real_rgb = np.repeat(real_rgb[:, :, np.newaxis], 3, axis=2).astype(np.float64)

        # Downsample the RGB image, then bicubic-interpolate the low-resolution
        # image back to full size (baseline and model input).
        lr_rgb = imresize(real_rgb, 1 / scale, interp='bicubic')
        bicubic_rgb = imresize(lr_rgb, (img_width, img_height), interp='bicubic')

        if mode == 'ycrcb':
            bicubic_ycbcr = rgb2ycbcr(bicubic_rgb)
            lr_test = np.empty((1, bicubic_ycbcr.shape[0], bicubic_ycbcr.shape[1], 1))
            lr_test[0, :, :, 0] = bicubic_ycbcr[:, :, 0]
        else:
            lr_test = np.empty((1, bicubic_rgb.shape[0], bicubic_rgb.shape[1], bicubic_rgb.shape[2]))
            lr_test[0, :, :, :] = bicubic_rgb

        # Load the trained SR model; its input size is taken from this image.
        SRDenseNet_Test = SRModel(img_width=lr_test.shape[1], img_height=lr_test.shape[2], mode=mode)
        SRDenseNet_Test.build_model(load_weights=True)
        if normalized:
            lr_test /= 255
        sr_test = SRDenseNet_Test.model.predict(lr_test)
        if normalized:
            sr_test = sr_test * 255

        if mode == 'ycrcb':
            # The model predicts channel 0 only; channels 1-2 come from bicubic.
            SR_ycrcb = np.empty((sr_test.shape[1], sr_test.shape[2], 3))
            SR_ycrcb[:, :, 0] = sr_test[0, :, :, 0]
            SR_ycrcb[:, :, 1] = bicubic_ycbcr[:, :, 1]
            SR_ycrcb[:, :, 2] = bicubic_ycbcr[:, :, 2]
            SR_rgb = ycbcr2rgb(SR_ycrcb)
        else:
            SR_rgb = sr_test[0]

        real_img = np.clip(real_rgb, 0, 255).astype('uint8')
        bicubic_img = np.clip(bicubic_rgb, 0, 255).astype('uint8')
        output_image = np.clip(SR_rgb, 0, 255).astype('uint8')

        PSNR_bicubic = psnr(real_img.astype('float'), bicubic_img.astype('float'))
        PSNR_generated = psnr(real_img.astype('float'), output_image.astype('float'))

        real_ycbcr = rgb2ycbcr(real_img)
        bicubic_ycbcr = rgb2ycbcr(bicubic_img)
        output_ycbcr = rgb2ycbcr(output_image)

        SSIM_bicubic = compute_ssim(real_ycbcr[:, :, 0], bicubic_ycbcr[:, :, 0])
        SSIM_generated = compute_ssim(real_ycbcr[:, :, 0], output_ycbcr[:, :, 0])

        print('%s_PSNR/SSIM: bicubic %0.2f/%0.2f ; SRDenseNet %0.2f/%0.2f' % (
            stem, PSNR_bicubic, SSIM_bicubic, PSNR_generated, SSIM_generated))

        # Save the ground-truth, bicubic and generated images.
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        real_path = out_dir + '/' + prefix + "_%s_real.png" % (stem)
        bicubic_img_path = out_dir + '/' + prefix + "_%s_bicubic_%0.2f|%0.2f.png" % (
            stem, PSNR_bicubic, SSIM_bicubic)
        generated_path = out_dir + '/' + prefix + "_%s_generated_%0.2f|%0.2f.png" % (
            stem, PSNR_generated, SSIM_generated)
        imsave(real_path, real_img)
        imsave(bicubic_img_path, bicubic_img)
        imsave(generated_path, output_image)

        total_psnr_bicubic += PSNR_bicubic
        total_psnr_generated += PSNR_generated
        total_ssim_bicubic += SSIM_bicubic
        total_ssim_generated += SSIM_generated
        n_processed += 1

    if n_processed == 0:
        # Avoid a ZeroDivisionError on an empty directory.
        print('No images found in %s' % pic_path)
        return

    print('Average_PSNR/SSIM: bicubic %0.2f/%0.2f ; VDSR_new %0.2f/%0.2f' % (
        total_psnr_bicubic / n_processed, total_ssim_bicubic / n_processed,
        total_psnr_generated / n_processed, total_ssim_generated / n_processed))
def run_epoch(self, epoch, dataloader, logimage=False, isTrain=True):
    """Run one training or validation pass over ``dataloader``.

    For each batch, ``self.slomo`` produces the prediction ``Ft_p`` and the
    intermediate tensors that ``self.supervisedloss`` consumes. When
    ``isTrain`` is set, the loss is back-propagated and optimizer and scheduler
    are stepped. PSNR/SSIM of ``Ft_p`` against ``IFrame`` are accumulated. With
    ``logimage`` during validation, every ``self.args.logimagefreq``-th batch
    contributes four preview images (I0, IFrame, I1, Ft_p). These are uploaded
    to ``self.comet_exp`` together with the averaged metrics.

    :param epoch: epoch index, used for the logging step and epoch.
    :param dataloader: iterable of batches accepted by ``self.slomo``.
    :param logimage: collect and upload preview images (validation only).
    :param isTrain: train (backward + optimizer/scheduler step) or only evaluate.
    :return: ``(mean_psnr, mean_ssim, mean_loss)`` averaged over batches.
    """
    psnr_value = 0
    ssim_value = 0
    loss_value = 0
    n_batches = len(dataloader)
    # Always bound so the upload below cannot hit an undefined name; it is only
    # filled during validation with logimage enabled.
    valid_images = []

    for index, all_data in enumerate(dataloader, 0):
        self.optimizer.zero_grad()
        (
            Ft_p,
            I0,
            IFrame,
            I1,
            g_I0_F_t_0,
            g_I1_F_t_1,
            FlowBackWarp_I0_F_1_0,
            FlowBackWarp_I1_F_0_1,
            F_1_0,
            F_0_1,
        ) = self.slomo(all_data, pred_only=False, isTrain=isTrain)

        if (not isTrain) and logimage and index % self.args.logimagefreq == 0:
            # First sample of each frame, resized to (1, 1, H, W), repeated to
            # three channels and scaled by 255. detach() keeps resize_ from
            # failing on tensors that require grad.
            for frame in (I0, IFrame, I1, Ft_p):
                thumb = frame.detach().cpu()[0].resize_(1, 1, self.args.data_h, self.args.data_w)
                valid_images.append(255.0 * thumb.repeat(1, 3, 1, 1))

        # loss
        loss = self.supervisedloss(
            Ft_p,
            IFrame,
            I0,
            I1,
            g_I0_F_t_0,
            g_I1_F_t_1,
            FlowBackWarp_I0_F_1_0,
            FlowBackWarp_I1_F_0_1,
            F_1_0,
            F_0_1,
        )
        if isTrain:
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
        loss_value += loss.item()

        # metrics
        psnr_value += psnr(Ft_p, IFrame, outputTensor=False)
        ssim_value += ssim(Ft_p, IFrame, outputTensor=False)

    name_loss = "TrainLoss" if isTrain else "ValLoss"
    itr = int(index + epoch * n_batches)
    if self.comet_exp is not None:
        self.comet_exp.log_metric("PSNR", psnr_value / n_batches, step=itr, epoch=epoch)
        self.comet_exp.log_metric("SSIM", ssim_value / n_batches, step=itr, epoch=epoch)
        self.comet_exp.log_metric(name_loss, loss_value / n_batches, step=itr, epoch=epoch)
        if logimage and valid_images:
            upload_images(
                valid_images,
                epoch,
                exp=self.comet_exp,
                im_per_row=4,
                rows_per_log=int(len(valid_images) / 4),
            )

    print(
        " Loss: %0.6f Iterations: %4d/%4d ValPSNR: %0.4f ValSSIM: %0.4f "
        % (
            loss_value / n_batches,
            index,
            n_batches,
            psnr_value / n_batches,
            ssim_value / n_batches,
        )
    )

    return (
        (psnr_value / n_batches),
        (ssim_value / n_batches),
        (loss_value / n_batches),
    )
def _train_model(self, image_dir, nb_images=50000, nb_epochs=20, pre_train=False,
                 load_generative_weights=False, load_discriminator_weights=False,
                 save_loss=True):
    """Train the SRGAN (or pre-train its generator) on images from ``image_dir``.

    With ``pre_train``, only the generator plus VGG branch is fit via
    ``bypass_fit``. Otherwise each batch runs a discriminator step followed by
    a generator step on ``self.srgan_model_``. Whenever ``iteration`` is a
    non-zero multiple of 50, the batch is used only for validation. PSNR is
    computed and sample images are written to ``val_images/``. Weights (and
    optionally the loss history as JSON) are saved whenever ``iteration`` is a
    non-zero multiple of 1000, and at the end.

    :param image_dir: directory for ``ImageDataGenerator.flow_from_directory``.
    :param nb_images: number of images per epoch before starting the next one.
    :param nb_epochs: number of epochs.
    :param pre_train: train only the generator.
    :param load_generative_weights: load generator weights first (ignored when
        ``pre_train``).
    :param load_discriminator_weights: load discriminator weights first
        (ignored when ``pre_train``).
    :param save_loss: record losses / PSNR and dump them to
        ``'pretrain losses.json'`` or ``'fulltrain losses.json'``.
    """
    assert self.img_width >= 16, "Minimum image width must be at least 16"
    assert self.img_height >= 16, "Minimum image height must be at least 16"

    if not pre_train:
        if load_generative_weights:
            self.generative_model_.load_weights(self.generative_network.sr_weights_path)

        if load_discriminator_weights:
            self.discriminative_network.load_gan_weights(self.srgan_model_)

    datagen = ImageDataGenerator(rescale=1. / 255)
    # High-resolution (target) size is 4x the generator input size.
    img_width = self.img_width * 4
    img_height = self.img_height * 4

    early_stop = False
    iteration = 0
    prev_improvement = -1

    if save_loss:
        if pre_train:
            loss_history = {'generator_loss': [],
                            'val_psnr': [], }
        else:
            loss_history = {'discriminator_loss': [],
                            'generator_loss': [],
                            'val_psnr': [], }

    def _dump_loss_history():
        # Write loss_history to the JSON file that matches the training mode.
        print("Saving loss history")
        loss_path = 'pretrain losses.json' if pre_train else 'fulltrain losses.json'
        with open(loss_path, 'w') as f:
            json.dump(loss_history, f)
        print("Saved loss history")

    y_vgg_dummy = np.zeros((self.batch_size * 2, 3, img_width // 32, img_height // 32))  # 5 Max Pools = 2 ** 5 = 32

    if not pre_train:
        y_gan = [0] * self.batch_size + [1] * self.batch_size
        y_gan = np.asarray(y_gan, dtype=np.float32).reshape(-1, 1)

    print("Training SRGAN network")
    for i in range(nb_epochs):
        print()
        print("Epoch : %d" % (i + 1))

        # NOTE(review): flow_from_directory can yield a final batch smaller than
        # self.batch_size; the per-sample loops below assume full batches.
        for x in datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size,
                                             target_size=(img_width, img_height)):
            try:
                t1 = time.time()

                if not pre_train:
                    x_vgg = x.copy() * 255  # VGG input [0 - 255 scale]

                # resize images: Gaussian blur, then downscale
                x_temp = x.copy()
                x_temp = x_temp.transpose((0, 2, 3, 1))

                x_generator = np.empty((self.batch_size, self.img_width, self.img_height, 3))

                for j in range(self.batch_size):
                    img = gaussian_filter(x_temp[j], sigma=0.5)
                    img = imresize(img, (self.img_width, self.img_height))
                    x_generator[j, :, :, :] = img

                x_generator = x_generator.transpose((0, 3, 1, 2))

                if iteration % 50 == 0 and iteration != 0:
                    print("Validation image..")
                    output_image_batch = self.generative_network.get_generator_output(
                        x_generator, self.srgan_model_)

                    average_psnr = 0.0
                    for x_i in range(self.batch_size):
                        average_psnr += psnr(x[x_i], np.clip(output_image_batch[x_i], 0, 255) / 255.)

                    average_psnr /= self.batch_size

                    if save_loss:
                        # float() keeps the history JSON-serialisable.
                        loss_history['val_psnr'].append(float(average_psnr))

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
                          (t2 - t1, self.batch_size, average_psnr))

                    for x_i in range(self.batch_size):
                        real_path = "val_images/epoch_%d_iteration_%d_num_%d_real_.png" % (
                            i + 1, iteration, x_i + 1)
                        generated_path = "val_images/epoch_%d_iteration_%d_num_%d_generated.png" % (
                            i + 1, iteration, x_i + 1)

                        val_x = x[x_i].copy() * 255.
                        val_x = val_x.transpose((1, 2, 0))
                        val_x = np.clip(val_x, 0, 255).astype('uint8')

                        output_image = output_image_batch[x_i]
                        output_image = output_image.transpose((1, 2, 0))
                        output_image = np.clip(output_image, 0, 255).astype('uint8')

                        imsave(real_path, val_x)
                        imsave(generated_path, output_image)

                    # Don't train on validation images for now.
                    # If nb_epochs > 1, validation images may also be used for
                    # training, so this is not strictly a validation measure;
                    # it is a check of what the network has learned.
                    continue

                if pre_train:
                    # Train only generator + vgg network
                    # Use custom bypass_fit to bypass the check for same input and output batch size
                    hist = bypass_fit(self.srgan_model_, [x_generator, x * 255], y_vgg_dummy,
                                      batch_size=self.batch_size, nb_epoch=1, verbose=0)
                    sr_loss = hist.history['loss'][0]

                    if save_loss:
                        loss_history['generator_loss'].append(float(sr_loss))

                    if prev_improvement == -1:
                        prev_improvement = sr_loss

                    improvement = (prev_improvement - sr_loss) / prev_improvement * 100
                    prev_improvement = sr_loss

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Generative Loss : %0.3f" % (iteration, nb_images, improvement, t2 - t1, sr_loss))
                else:
                    # Train only discriminator, disable training of srgan
                    self.discriminative_network.set_trainable(self.srgan_model_, value=True)
                    self.generative_network.set_trainable(self.srgan_model_, value=False)

                    # Use custom bypass_fit to bypass the check for same input and output batch size
                    hist = bypass_fit(self.srgan_model_, [x_generator, x * 255, x_vgg],
                                      [y_gan, y_vgg_dummy],
                                      batch_size=self.batch_size, nb_epoch=1, verbose=0)

                    discriminator_loss = hist.history['loss'][0]

                    # Train only generator, disable training of discriminator
                    self.discriminative_network.set_trainable(self.srgan_model_, value=False)
                    self.generative_network.set_trainable(self.srgan_model_, value=True)

                    # NOTE(review): this generator step reuses y_gan (first half 0,
                    # second half 1), i.e. the same targets as the discriminator
                    # step. Confirm whether the combined model expects flipped
                    # labels here so that generated images are pushed towards "real".
                    # Use custom bypass_fit to bypass the check for same input and output batch size
                    hist = bypass_fit(self.srgan_model_, [x_generator, x * 255, x_vgg],
                                      [y_gan, y_vgg_dummy],
                                      batch_size=self.batch_size, nb_epoch=1, verbose=0)

                    generative_loss = hist.history['loss'][0]

                    if save_loss:
                        loss_history['discriminator_loss'].append(float(discriminator_loss))
                        loss_history['generator_loss'].append(float(generative_loss))

                    if prev_improvement == -1:
                        prev_improvement = discriminator_loss

                    improvement = (prev_improvement - discriminator_loss) / prev_improvement * 100
                    prev_improvement = discriminator_loss

                    iteration += self.batch_size
                    t2 = time.time()
                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Discriminator Loss : %0.3f | Generative Loss : %0.3f" %
                          (iteration, nb_images, improvement, t2 - t1, discriminator_loss, generative_loss))

                if iteration % 1000 == 0 and iteration != 0:
                    print("Saving model weights.")
                    # Save predictive (SR network) weights
                    self.generative_model_.save_weights(self.generative_network.sr_weights_path, overwrite=True)

                    if not pre_train:
                        # Save GAN (discriminative network) weights
                        self.discriminative_network.save_gan_weights(self.srgan_model_)

                    if save_loss:
                        _dump_loss_history()

                if iteration >= nb_images:
                    break

            except KeyboardInterrupt:
                print("Keyboard interrupt detected. Stopping early.")
                early_stop = True
                break

        iteration = 0

        if early_stop:
            break

    print("Finished training SRGAN network. Saving model weights.")

    # Save predictive (SR network) weights. overwrite=True matches the periodic
    # save above, so an existing weight file never triggers a prompt.
    self.generative_model_.save_weights(self.generative_network.sr_weights_path, overwrite=True)

    if not pre_train:
        # Save GAN (discriminative network) weights
        self.discriminative_network.save_gan_weights(self.srgan_model_)

    print("Weights saved in 'weights' directory")

    if save_loss:
        _dump_loss_history()
def _train_model(self, image_dir, nb_images=80000, nb_epochs=10, pre_train_srgan=False,
                 pre_train_discriminator=False, load_generative_weights=False,
                 load_discriminator_weights=False, save_loss=True, disc_train_flip=0.1):
    """Train the SRGAN generator and/or discriminator on images from ``image_dir``.

    The training mode is selected by the ``pre_train_*`` flags:

    * ``pre_train_srgan``: fit only ``self.srgan_model_`` on
      ``[low_res, real * 255]`` against a dummy VGG target (generator pre-training).
    * ``pre_train_discriminator``: fit only ``self.discriminative_model_`` on the
      concatenation of generated and real images with soft, occasionally
      flipped labels.
    * neither flag: per batch, one discriminator update followed by one
      generator update through ``self.srgan_model_``.

    Whenever the processed-image counter is a nonzero multiple of 50 (and the
    discriminator is not being pre-trained), the batch is used for validation
    instead of training. The generator output is written to ``val_images/``
    next to the real image, and the average PSNR is printed and, if
    ``save_loss``, recorded.

    Args:
        image_dir: directory passed to ``ImageDataGenerator.flow_from_directory``.
        nb_images: number of images to process per epoch.
        nb_epochs: number of passes over the data.
        pre_train_srgan: pre-train only the generator path (see above).
        pre_train_discriminator: pre-train only the discriminator (see above).
        load_generative_weights: try to load generator weights first; a
            failure is reported and training continues.
        load_discriminator_weights: try to load discriminator weights first; a
            failure is reported and training continues.
        save_loss: record per-step losses / validation PSNR and hand them to
            ``self._save_loss_history``.
        disc_train_flip: probability of swapping the real/fake labels for a
            discriminator update (label noise).
    """
    assert self.img_width >= 16, "Minimum image width must be at least 16"
    assert self.img_height >= 16, "Minimum image height must be at least 16"

    if load_generative_weights:
        try:
            self.generative_model_.load_weights(self.generative_network.sr_weights_path)
            print("Generator weights loaded.")
        except Exception:
            print("Could not load generator weights.")

    if load_discriminator_weights:
        try:
            self.discriminative_network.load_gan_weights(self.srgan_model_)
            print("Discriminator weights loaded.")
        except Exception:
            print("Could not load discriminator weights.")

    datagen = ImageDataGenerator(rescale=1. / 255)
    # Real (target) images are 4x the low-resolution generator input size.
    img_width = self.img_width * 4
    img_height = self.img_height * 4

    early_stop = False
    iteration = 0
    prev_improvement = -1  # sentinel: no previous loss recorded yet

    # Always bind loss_history: it is passed to _save_loss_history below even
    # when save_loss is False (previously this raised NameError).
    loss_history = {}
    if save_loss:
        if pre_train_srgan:
            loss_history = {'generator_loss': [], 'val_psnr': []}
        elif pre_train_discriminator:
            loss_history = {'discriminator_loss': [], 'discriminator_acc': []}
        else:
            loss_history = {'discriminator_loss': [], 'discriminator_acc': [],
                            'generator_loss': [], 'val_psnr': []}

    # 5 Max Pools = 2 ** 5 = 32
    y_vgg_dummy = np.zeros((self.batch_size * 2, 3, img_width // 32, img_height // 32))

    def make_disc_labels():
        # Generated images are labelled 0 and real images 1. With probability
        # disc_train_flip the labels are swapped (noisy labels), then smoothed.
        if np.random.uniform() > disc_train_flip:
            labels = [0] * self.batch_size + [1] * self.batch_size
        else:
            labels = [1] * self.batch_size + [0] * self.batch_size
        labels = np.asarray(labels, dtype=int).reshape(-1, 1)
        labels = to_categorical(labels, nb_classes=2)
        return smooth_gan_labels(labels)

    print("Training SRGAN network")
    for i in range(nb_epochs):
        print()
        print("Epoch : %d" % (i + 1))

        for x in datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size,
                                             target_size=(img_width, img_height)):
            try:
                # flow_from_directory yields a shorter final batch when the image
                # count is not a multiple of batch_size. The fixed-size buffers
                # and targets below assume full batches, so drop it.
                if x.shape[0] != self.batch_size:
                    continue

                t1 = time.time()

                if not pre_train_srgan and not pre_train_discriminator:
                    x_vgg = x.copy() * 255  # VGG input [0 - 255 scale]

                # Build the low-resolution generator inputs, one image at a time,
                # in channels-last layout.
                x_temp = x.transpose((0, 2, 3, 1))
                x_generator = np.empty((self.batch_size, self.img_width, self.img_height, 3))
                for j in range(self.batch_size):
                    # Blur only the spatial axes, never across colour channels.
                    # NOTE(review): with scipy's default truncate=4.0, sigma=0.1 gives
                    # a kernel radius of 0, so this blur is effectively a no-op; confirm intent.
                    img = gaussian_filter(x_temp[j], sigma=(0.1, 0.1, 0))
                    img = imresize(img, (self.img_width, self.img_height), interp='bicubic')
                    x_generator[j, :, :, :] = img
                x_generator = x_generator.transpose((0, 3, 1, 2))

                if iteration % 50 == 0 and iteration != 0 and not pre_train_discriminator:
                    print("Validation image..")
                    output_image_batch = self.generative_network.get_generator_output(x_generator,
                                                                                      self.srgan_model_)
                    if isinstance(output_image_batch, list):
                        output_image_batch = output_image_batch[0]

                    mean_axis = (0, 2, 3) if K.image_dim_ordering() == 'th' else (0, 1, 2)

                    average_psnr = 0.0

                    print('gen img mean :', np.mean(output_image_batch / 255., axis=mean_axis))
                    print('val img mean :', np.mean(x, axis=mean_axis))

                    for x_i in range(self.batch_size):
                        average_psnr += psnr(x[x_i], np.clip(output_image_batch[x_i], 0, 255) / 255.)

                    average_psnr /= self.batch_size

                    if save_loss:
                        loss_history['val_psnr'].append(average_psnr)

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
                          (t2 - t1, self.batch_size, average_psnr))

                    for x_i in range(self.batch_size):
                        real_path = "val_images/epoch_%d_iteration_%d_num_%d_real_.png" % (i + 1, iteration, x_i + 1)
                        generated_path = "val_images/epoch_%d_iteration_%d_num_%d_generated.png" % (
                            i + 1, iteration, x_i + 1)

                        val_x = (x[x_i] * 255.).transpose((1, 2, 0))
                        val_x = np.clip(val_x, 0, 255).astype('uint8')

                        output_image = output_image_batch[x_i].transpose((1, 2, 0))
                        output_image = np.clip(output_image, 0, 255).astype('uint8')

                        imsave(real_path, val_x)
                        imsave(generated_path, output_image)

                    # Don't train on validation images for now.
                    # Note that if nb_epochs > 1, there is a chance that validation images
                    # may be used for training as well. In that case this isn't strictly a
                    # validation measure, but rather a check of what the network has learned.
                    continue

                if pre_train_srgan:
                    # Train only generator + vgg network.
                    # Use custom bypass_fit to bypass the check for same input and output batch size.
                    hist = bypass_fit(self.srgan_model_, [x_generator, x * 255], y_vgg_dummy,
                                      batch_size=self.batch_size, nb_epoch=1, verbose=0)

                    sr_loss = hist.history['loss'][0]

                    if save_loss:
                        loss_history['generator_loss'].extend(hist.history['loss'])

                    if prev_improvement == -1:
                        prev_improvement = sr_loss

                    improvement = (prev_improvement - sr_loss) / prev_improvement * 100
                    prev_improvement = sr_loss

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Generative Loss : %0.2f" % (iteration, nb_images, improvement, t2 - t1, sr_loss))

                elif pre_train_discriminator:
                    # Train only discriminator on [generated ; real] images.
                    x_pred = self.generative_model_.predict(x_generator, self.batch_size)
                    x_disc = np.concatenate((x_pred, x * 255))
                    y_gan = make_disc_labels()

                    hist = self.discriminative_model_.fit(x_disc, y_gan, batch_size=self.batch_size,
                                                          nb_epoch=1, verbose=0)

                    discriminator_loss = hist.history['loss'][-1]
                    discriminator_acc = hist.history['acc'][-1]

                    if save_loss:
                        loss_history['discriminator_loss'].extend(hist.history['loss'])
                        loss_history['discriminator_acc'].extend(hist.history['acc'])

                    if prev_improvement == -1:
                        prev_improvement = discriminator_loss

                    improvement = (prev_improvement - discriminator_loss) / prev_improvement * 100
                    prev_improvement = discriminator_loss

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Discriminator Loss / Acc : %0.4f / %0.2f" %
                          (iteration, nb_images, improvement, t2 - t1, discriminator_loss, discriminator_acc))

                else:
                    # Train only discriminator, disable training of srgan.
                    self.discriminative_network.set_trainable(self.srgan_model_, value=True)
                    self.generative_network.set_trainable(self.srgan_model_, value=False)

                    x_pred = self.generative_model_.predict(x_generator, self.batch_size)
                    x_disc = np.concatenate((x_pred, x * 255))
                    y_gan = make_disc_labels()

                    hist1 = self.discriminative_model_.fit(x_disc, y_gan, verbose=0,
                                                           batch_size=self.batch_size, nb_epoch=1)

                    discriminator_loss = hist1.history['loss'][-1]

                    # Train only generator, disable training of discriminator.
                    self.discriminative_network.set_trainable(self.srgan_model_, value=False)
                    self.generative_network.set_trainable(self.srgan_model_, value=True)

                    # Soft "real" (1) labels: the generator is pushed toward images the
                    # discriminator classifies as real.
                    y_model = np.asarray([1] * self.batch_size, dtype=int).reshape(-1, 1)
                    y_model = to_categorical(y_model, nb_classes=2)
                    y_model = smooth_gan_labels(y_model)

                    # Use custom bypass_fit to bypass the check for same input and output batch size.
                    # NOTE(review): the second input is `x` (0-1 scale) here, while the
                    # pre-training branch and the discriminator batch above use `x * 255`;
                    # confirm which scale self.srgan_model_ expects.
                    hist2 = bypass_fit(self.srgan_model_, [x_generator, x, x_vgg], [y_model, y_vgg_dummy],
                                       batch_size=self.batch_size, nb_epoch=1, verbose=0)

                    generative_loss = hist2.history['loss'][0]

                    if save_loss:
                        loss_history['discriminator_loss'].extend(hist1.history['loss'])
                        loss_history['discriminator_acc'].extend(hist1.history['acc'])
                        loss_history['generator_loss'].extend(hist2.history['loss'])

                    if prev_improvement == -1:
                        prev_improvement = discriminator_loss

                    improvement = (prev_improvement - discriminator_loss) / prev_improvement * 100
                    prev_improvement = discriminator_loss

                    iteration += self.batch_size
                    t2 = time.time()
                    print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                          "Discriminator Loss : %0.3f | Generative Loss : %0.3f" %
                          (iteration, nb_images, improvement, t2 - t1, discriminator_loss, generative_loss))

                if iteration % 1000 == 0 and iteration != 0:
                    print("Saving model weights.")
                    # Save predictive (SR network) weights
                    self._save_model_weights(pre_train_srgan, pre_train_discriminator)
                    self._save_loss_history(loss_history, pre_train_srgan, pre_train_discriminator, save_loss)

                if iteration >= nb_images:
                    break

            except KeyboardInterrupt:
                print("Keyboard interrupt detected. Stopping early.")
                early_stop = True
                break

        iteration = 0

        if early_stop:
            break

    print("Finished training SRGAN network. Saving model weights.")
    # Save predictive (SR network) weights
    self._save_model_weights(pre_train_srgan, pre_train_discriminator)
    self._save_loss_history(loss_history, pre_train_srgan, pre_train_discriminator, save_loss)
def train_model(self, image_dir, nb_images=50000, nb_epochs=1):
    """Train the SR ResNet ``self.model`` on images read from ``image_dir``.

    Each batch of real images is scaled to [0, 1] and loaded at 4x the model's
    input size. Each image is blurred spatially and downscaled to
    ``(self.img_width, self.img_height)`` to form the model input. The model
    is then fit to reproduce the real images on a [0, 255] scale. This
    variant works in channels-last layout.

    Whenever the processed-image counter is a nonzero multiple of 50, the batch
    is used for validation instead of training. Predictions are saved under
    ``base_val_images_path`` next to the real images, and the average PSNR is
    printed.

    Weights are saved to ``self.weights_path`` every 1000 images. They are
    saved once more when training ends, either normally or after Ctrl-C.

    Args:
        image_dir: directory passed to ``ImageDataGenerator.flow_from_directory``.
        nb_images: number of images to process per epoch.
        nb_epochs: number of passes over the data.
    """
    datagen = ImageDataGenerator(rescale=1. / 255)
    # Real (target) images are 4x the low-resolution model input size.
    img_width = self.img_width * 4
    img_height = self.img_height * 4

    early_stop = False
    iteration = 0
    prev_improvement = -1  # sentinel: no previous metric recorded yet

    print("Training SR ResNet network")
    for i in range(nb_epochs):
        print()
        print("Epoch : %d" % (i + 1))

        for x in datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size,
                                             target_size=(img_width, img_height)):
            try:
                # flow_from_directory yields a shorter final batch when the image
                # count is not a multiple of batch_size. The fixed-size buffer below
                # assumes full batches (x[j] would raise IndexError), so drop it.
                if x.shape[0] != self.batch_size:
                    continue

                t1 = time.time()

                # Build the low-resolution model inputs (channels-last).
                x_generator = np.empty((self.batch_size, self.img_width, self.img_height, 3))
                for j in range(self.batch_size):
                    # Blur only the two spatial axes. A scalar sigma would also
                    # blur across the colour-channel axis and mix RGB values.
                    img = gaussian_filter(x[j], sigma=(0.5, 0.5, 0))
                    img = imresize(img, (self.img_width, self.img_height))
                    x_generator[j, :, :, :] = img

                if iteration % 50 == 0 and iteration != 0:
                    print("Random Validation image..")
                    output_image_batch = self.model.predict_on_batch(x_generator)

                    print("Pred Max / Min: %0.2f / %0.2f" % (output_image_batch.max(),
                                                             output_image_batch.min()))

                    average_psnr = 0.0
                    for x_i in range(self.batch_size):
                        average_psnr += psnr(x[x_i], output_image_batch[x_i] / 255.)
                    average_psnr /= self.batch_size

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
                          (t2 - t1, self.batch_size, average_psnr))

                    for x_i in range(self.batch_size):
                        real_path = base_val_images_path + "epoch_%d_iteration_%d_num_%d_real_.png" % \
                                    (i + 1, iteration, x_i + 1)
                        generated_path = base_val_images_path + \
                                         "epoch_%d_iteration_%d_num_%d_generated.png" % (i + 1, iteration, x_i + 1)

                        val_x = np.clip(x[x_i] * 255., 0, 255).astype('uint8')
                        output_image = np.clip(output_image_batch[x_i], 0, 255).astype('uint8')

                        imsave(real_path, val_x)
                        imsave(generated_path, output_image)

                    # Don't train on validation images for now.
                    # Note that if nb_epochs > 1, there is a chance that validation images
                    # may be used for training as well. In that case this isn't strictly a
                    # validation measure, but rather a check of what the network has learned.
                    continue

                hist = self.model.fit(x_generator, x * 255, batch_size=self.batch_size, epochs=1, verbose=0)
                psnr_loss_val = hist.history['PSNRLoss'][0]

                if prev_improvement == -1:
                    prev_improvement = psnr_loss_val

                improvement = (prev_improvement - psnr_loss_val) / prev_improvement * 100
                prev_improvement = psnr_loss_val

                iteration += self.batch_size
                t2 = time.time()

                print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                      "PSNR : %0.3f" % (iteration, nb_images, improvement, t2 - t1, psnr_loss_val))

                if iteration % 1000 == 0 and iteration != 0:
                    print("Saving weights")
                    self.model.save_weights(self.weights_path, overwrite=True)

                if iteration >= nb_images:
                    break

            except KeyboardInterrupt:
                print("Keyboard interrupt detected. Stopping early.")
                early_stop = True
                break

        iteration = 0

        if early_stop:
            break

    # Previously this message was printed but the weights were never saved, so
    # progress since the last 1000-image checkpoint was lost.
    print("Finished training SR ResNet network. Saving model weights.")
    self.model.save_weights(self.weights_path, overwrite=True)
def train_model(self, image_dir, nb_images=50000, nb_epochs=1):
    """Train the SR ResNet ``self.model`` on images read from ``image_dir``.

    Each batch of real images is scaled to [0, 1] and loaded at 4x the model's
    input size. Each image is blurred spatially and downscaled to
    ``(self.img_width, self.img_height)`` to form the model input. The model
    is then fit to reproduce the real images on a [0, 255] scale. This
    variant feeds the model channels-first arrays; per-image filtering and
    resizing happen in channels-last layout.

    Whenever the processed-image counter is a nonzero multiple of 50, the batch
    is used for validation instead of training. Predictions are saved under
    ``base_val_images_path`` next to the real images, and the average PSNR is
    printed.

    Weights are saved to ``self.weights_path`` every 1000 images. They are
    saved once more when training ends, either normally or after Ctrl-C.

    Args:
        image_dir: directory passed to ``ImageDataGenerator.flow_from_directory``.
        nb_images: number of images to process per epoch.
        nb_epochs: number of passes over the data.
    """
    datagen = ImageDataGenerator(rescale=1. / 255)
    # Real (target) images are 4x the low-resolution model input size.
    img_width = self.img_width * 4
    img_height = self.img_height * 4

    early_stop = False
    iteration = 0
    prev_improvement = -1  # sentinel: no previous metric recorded yet

    print("Training SR ResNet network")
    for i in range(nb_epochs):
        print()
        print("Epoch : %d" % (i + 1))

        for x in datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size,
                                             target_size=(img_width, img_height)):
            try:
                # flow_from_directory yields a shorter final batch when the image
                # count is not a multiple of batch_size. The fixed-size buffer below
                # assumes full batches (x_temp[j] would raise IndexError), so drop it.
                if x.shape[0] != self.batch_size:
                    continue

                t1 = time.time()

                # Channels-first -> channels-last for per-image filtering/resizing.
                x_temp = x.transpose((0, 2, 3, 1))
                x_generator = np.empty((self.batch_size, self.img_width, self.img_height, 3))
                for j in range(self.batch_size):
                    # Blur only the two spatial axes. A scalar sigma would also
                    # blur across the colour-channel axis and mix RGB values.
                    img = gaussian_filter(x_temp[j], sigma=(0.5, 0.5, 0))
                    img = imresize(img, (self.img_width, self.img_height))
                    x_generator[j, :, :, :] = img
                # Back to channels-first for the model.
                x_generator = x_generator.transpose((0, 3, 1, 2))

                if iteration % 50 == 0 and iteration != 0:
                    print("Random Validation image..")
                    output_image_batch = self.model.predict_on_batch(x_generator)

                    print("Pred Max / Min: %0.2f / %0.2f" % (output_image_batch.max(),
                                                             output_image_batch.min()))

                    average_psnr = 0.0
                    for x_i in range(self.batch_size):
                        average_psnr += psnr(x[x_i], output_image_batch[x_i] / 255.)
                    average_psnr /= self.batch_size

                    iteration += self.batch_size
                    t2 = time.time()

                    print("Time required : %0.2f. Average validation PSNR over %d samples = %0.2f" %
                          (t2 - t1, self.batch_size, average_psnr))

                    for x_i in range(self.batch_size):
                        real_path = base_val_images_path + "epoch_%d_iteration_%d_num_%d_real_.png" % \
                                    (i + 1, iteration, x_i + 1)
                        generated_path = base_val_images_path + \
                                         "epoch_%d_iteration_%d_num_%d_generated.png" % (i + 1, iteration, x_i + 1)

                        val_x = (x[x_i] * 255.).transpose((1, 2, 0))
                        val_x = np.clip(val_x, 0, 255).astype('uint8')

                        output_image = output_image_batch[x_i].transpose((1, 2, 0))
                        output_image = np.clip(output_image, 0, 255).astype('uint8')

                        imsave(real_path, val_x)
                        imsave(generated_path, output_image)

                    # Don't train on validation images for now.
                    # Note that if nb_epochs > 1, there is a chance that validation images
                    # may be used for training as well. In that case this isn't strictly a
                    # validation measure, but rather a check of what the network has learned.
                    continue

                hist = self.model.fit(x_generator, x * 255, batch_size=self.batch_size, nb_epoch=1, verbose=0)
                psnr_loss_val = hist.history['PSNRLoss'][0]

                if prev_improvement == -1:
                    prev_improvement = psnr_loss_val

                improvement = (prev_improvement - psnr_loss_val) / prev_improvement * 100
                prev_improvement = psnr_loss_val

                iteration += self.batch_size
                t2 = time.time()

                print("Iter : %d / %d | Improvement : %0.2f percent | Time required : %0.2f seconds | "
                      "PSNR : %0.3f" % (iteration, nb_images, improvement, t2 - t1, psnr_loss_val))

                if iteration % 1000 == 0 and iteration != 0:
                    print("Saving weights")
                    self.model.save_weights(self.weights_path, overwrite=True)

                if iteration >= nb_images:
                    break

            except KeyboardInterrupt:
                print("Keyboard interrupt detected. Stopping early.")
                early_stop = True
                break

        iteration = 0

        if early_stop:
            break

    # Previously this message was printed but the weights were never saved, so
    # progress since the last 1000-image checkpoint was lost.
    print("Finished training SR ResNet network. Saving model weights.")
    self.model.save_weights(self.weights_path, overwrite=True)