import math
import os
from statistics import mean

import numpy as np
import scipy.misc
import tensorflow as tf

import inputpipeline
import metric
import ms_ssim
import ssim_matrix
import utils
# analysis_transform, synthesis_transform, quantize_image, load_model,
# encode_image_with_padding, decode_image_with_padding, Psnr, msssim,
# as_img_array and the module-level `args` are defined elsewhere in this repo.


def compare_folder(origin, codes, res):
    """Average bpp / PSNR / MS-SSIM over a whole folder of reconstructions.

    `origin` holds the source images, `codes` the compressed bitstreams and
    `res` the decoded images (same filenames as in `origin`).
    """
    psnr = []
    ssim = []
    total_pixels = 0
    for filename in os.listdir(origin):
        original_i = os.path.join(origin, filename)
        res_i = os.path.join(res, filename)
        psnr.append(metric.psnr(original_i, res_i))
        ssim.append(metric.msssim(original_i, res_i))
        total_pixels += utils.get_pixels(original_i)
        print(psnr[-1], ssim[-1])
    # Rate is measured over the whole codes folder, not per image.
    total_size = utils.get_size_folder(codes)
    bpp = total_size / total_pixels
    print(bpp, mean(psnr), mean(ssim))
    return bpp, mean(psnr), mean(ssim)
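# compare_folder above (and test_valid / test_validation below) lean on a few
# helpers from this repo's `utils` module. Minimal sketches of plausible
# implementations follow; these are assumptions, not the repo's actual code.
# They assume codes are stored as plain files, so the size in bits is 8x the
# byte count reported by the filesystem.
from PIL import Image


def get_pixels_sketch(image_path):
    # Number of pixels (width * height) in one image.
    with Image.open(image_path) as img:
        width, height = img.size
    return width * height


def get_size_folder_sketch(folder):
    # Total size of every file directly inside `folder`, in bits.
    total_bytes = sum(
        os.path.getsize(os.path.join(folder, name))
        for name in os.listdir(folder))
    return total_bytes * 8


def calc_bpp_sketch(codes_path, image_path):
    # Bits per pixel of a single code file relative to its source image.
    return os.path.getsize(codes_path) * 8 / get_pixels_sketch(image_path)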
def get_ssim(res_path, jpeg=False):
    """MS-SSIM of each of the 24 reconstructed Kodak images vs. the originals."""
    ssim = []
    for i in range(24):
        n_id = '%02d' % (i + 1)  # kodim01 ... kodim24
        original = '/home/williamchen/Dataset/Kodak/kodim' + n_id + '.png'
        compared = '{}/{}/00.png'.format(res_path, n_id)
        if jpeg:
            compared = '{}/{}/00.jpg'.format(res_path, n_id)
        ssim.append(metric.msssim(compared, original))
    return ssim
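# Hedged usage sketch; 'res/jpeg_q20' is a placeholder results directory, and
# note the Kodak path above is hardcoded to the author's machine.
def _demo_get_ssim():
    vals = get_ssim('res/jpeg_q20', jpeg=True)
    print('mean MS-SSIM over Kodak: %.4f' % mean(vals))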
def test_valid(model_path, version, root):
    """Encode / decode every image under `root` in-process and return the
    mean bpp / PSNR / MS-SSIM, using the repo's padding-aware codec helpers."""
    os.system('mkdir -p codes_val/{}'.format(version))
    os.system('mkdir -p res_val/{}'.format(version))
    bpp = []
    psnr = []
    ssim = []
    load_model(model_path)
    for filename in os.listdir(root):
        original = os.path.join(root, filename)
        codes_path = 'codes_val/{}'.format(version)
        output_path = 'res_val/{}/{}'.format(version, filename)
        os.system('mkdir -p {}'.format(output_path))
        encode_image_with_padding(root, filename, codes_path)
        codes = codes_path + '/' + filename[:-4] + '.npz'
        filename = filename[:-4]  # strip the extension for the decoder
        decode_image_with_padding(codes_path, output_path, filename)
        compared = output_path + '/' + filename + '.png'
        bpp.append(utils.calc_bpp(codes, original))
        psnr.append(metric.psnr(original, compared))
        ssim.append(metric.msssim(compared, original))
    return mean(bpp), mean(psnr), mean(ssim)
def test_validation(model_path, version, root):
    """Same evaluation as test_valid, but drives encoder.py / decoder.py as
    subprocesses instead of running the model in-process."""
    os.system('mkdir -p codes_val/{}'.format(version))
    os.system('mkdir -p res_val/{}'.format(version))
    bpp = []
    psnr = []
    ssim = []
    for filename in os.listdir(root):
        original = os.path.join(root, filename)
        filename = filename[:-4]
        os.system('mkdir -p res_val/{}/{}'.format(version, filename))
        os.system(
            'python encoder.py --model {}/encoder.pth --input {} --output codes_val/{}/{}'
            .format(model_path, original, version, filename))
        os.system(
            'python decoder.py --model {}/decoder.pth --input codes_val/{}/{}.npz --output res_val/{}/{}'
            .format(model_path, version, filename, version, filename))
        codes = 'codes_val/{}/{}.npz'.format(version, filename)
        compared = 'res_val/{}/{}/00.png'.format(version, filename)
        bpp.append(utils.calc_bpp(codes, original))
        psnr.append(metric.psnr(original, compared))
        ssim.append(metric.msssim(compared, original))
    return mean(bpp), mean(psnr), mean(ssim)
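# Hedged usage sketch for the subprocess pipeline above; the model directory
# and validation root below are placeholders, not paths from this repo.
def _demo_test_validation():
    bpp_, psnr_, msssim_ = test_validation('models/v1', 'v1', 'data/valid')
    print('bpp %.4f, PSNR %.2f dB, MS-SSIM %.4f' % (bpp_, psnr_, msssim_))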
def get_ms_ssim(original, compared):
    """MS-SSIM between two images, converted to arrays by as_img_array."""
    return msssim(as_img_array(original), as_img_array(compared))
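# `as_img_array` is defined elsewhere in this repo. Below is a minimal sketch
# of the behaviour get_ms_ssim appears to assume (round and clip to a uint8
# HxWxC array, loading from disk if given a path); this is an assumption, not
# the actual implementation. Reuses the PIL import from the sketches above.
def as_img_array_sketch(image):
    if isinstance(image, str):
        image = np.array(Image.open(image).convert('RGB'))
    return np.clip(np.round(image), 0, 255).astype(np.uint8)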
def train():
    if not os.path.exists(args.checkpoint_dir):
        # shutil.rmtree(args.checkpoint_dir)
        os.makedirs(args.checkpoint_dir)

    # Log every command-line argument to <checkpoint_dir>/params.log.
    log_name = os.path.join(args.checkpoint_dir, 'params.log')
    if os.path.exists(log_name):
        print('remove file: %s' % log_name)
        os.remove(log_name)
    with open(log_name, 'w') as params:
        for arg in vars(args):
            str_ = '%s: %s.\n' % (arg, getattr(args, arg))
            print(str_)
            params.write(str_)

    tf.logging.set_verbosity(tf.logging.INFO)

    # tf Graph input (only pictures)
    if args.data_set.lower() == 'celeba':
        data_glob = imgs_path = args.img_path + '/*.png'
        print(imgs_path)
    ip_train = inputpipeline.InputPipeline(
        inputpipeline.get_dataset(data_glob),
        args.patch_size,
        batch_size=args.batch_size,
        shuffle=True,
        num_preprocess_threads=6,
        num_crops_per_img=6)
    X = ip_train.get_batch()

    # Construct the VAE: the encoder returns the latent code plus the
    # posterior mean and log-variance (`var` holds log(sigma^2)).
    encoder_op, mean, var = analysis_transform(X, 64)
    if args.split == 'None':
        X_pred = synthesis_transform(encoder_op, 64)
    else:
        # Split mode: decode both the posterior mean and the sampled code.
        X_pred = synthesis_transform(mean, 64)
        X_pred2 = synthesis_transform(encoder_op, 64)

    # Monitoring metrics (not part of the training objective).
    mse_loss = tf.reduce_mean(tf.squared_difference(255 * X, 255 * X_pred))
    msssim_loss = ms_ssim.MultiScaleSSIM(X * 255, X_pred * 255,
                                         data_format='NHWC')

    # First distortion term: original vs. reconstruction.
    if args.loss1 == 'mse':
        d1 = tf.reduce_mean(tf.squared_difference(X, X_pred))
    elif args.loss1 == 'ssim':
        d1 = tf.reduce_mean(ssim_matrix.ssim(
            X * 255, (X - X_pred) * 255, X_pred,
            max_val=255, mode='train', compensation=1))
    else:
        print('error: invalid loss1')
        return -1

    # Second distortion term (split mode only): mean-decode vs. sample-decode.
    if args.split != 'None':
        if args.loss2 == 'mse':
            d2 = tf.reduce_mean(tf.squared_difference(X_pred, X_pred2))
        elif args.loss2 == 'ssim':
            d2 = tf.reduce_mean(ssim_matrix.ssim(
                X_pred * 255, (X_pred - X_pred2) * 255, X_pred2,
                max_val=255, mode='train', compensation=1))

    # KL divergence of the posterior N(mean, exp(var)) from N(0, 1):
    # KL = -0.5 * sum(1 + log(sigma^2) - mean^2 - sigma^2)
    kl_div_loss = 1 + var - tf.square(mean) - tf.exp(var)
    kl_div_loss = -0.5 * tf.reduce_sum(kl_div_loss, 1)
    kl_div_loss = tf.reduce_mean(kl_div_loss)

    # Total loss.
    if args.split != 'None':
        train_loss = args.lambda1 * d1 + args.lambda2 * d2 + kl_div_loss
    else:
        train_loss = args.lambda1 * d1 + kl_div_loss

    step = tf.train.create_global_step()
    # Use a smaller learning rate when fine-tuning from a checkpoint.
    learning_rate = 0.00001 if args.finetune != 'None' else 0.0001
    main_lr = tf.train.AdamOptimizer(learning_rate)
    optimizer = main_lr.minimize(train_loss, global_step=step)

    tf.summary.scalar('loss', train_loss)
    # tf.summary.scalar('bpp', bpp)
    tf.summary.scalar('mse', mse_loss)
    logged_tensors = [
        tf.identity(train_loss, name='train_loss'),
        # tf.identity(bpp, name='train_bpp'),
        tf.identity(msssim_loss, name='ms-ssim')
    ]
    tf.summary.image('original', quantize_image(X))
    tf.summary.image('reconstruction', quantize_image(X_pred))

    hooks = [
        tf.train.StopAtStepHook(last_step=args.num_steps),
        tf.train.NanTensorHook(train_loss),
        tf.train.LoggingTensorHook(logged_tensors, every_n_secs=60),
        tf.train.SummarySaverHook(save_steps=args.save_steps,
                                  summary_op=tf.summary.merge_all()),
        tf.train.CheckpointSaverHook(save_steps=args.save_steps,
                                     checkpoint_dir=args.checkpoint_dir)
    ]

    # uint8 versions of the batch for on-disk previews and metrics.
    X_rec = tf.clip_by_value(X_pred, 0, 1)
    X_rec = tf.round(X_rec * 255)
    X_rec = tf.cast(X_rec, tf.uint8)
    X_ori = tf.clip_by_value(X, 0, 1)
    X_ori = tf.round(X_ori * 255)
    X_ori = tf.cast(X_ori, tf.uint8)

    if args.finetune != 'None':
        init_fn_ae = tf.contrib.framework.assign_from_checkpoint_fn(
            args.finetune,
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

    train_count = 0
    parameter = 'VAE_%s' % (args.loss2)
    with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
        if args.finetune != 'None':
            init_fn_ae(sess)
            print('load from %s' % (args.finetune))
        while not sess.should_stop():
            if args.split != 'None':
                _, train_loss_, d1_, d2_, kl_div_loss_, rec_img, X_ori_ = sess.run(
                    [optimizer, train_loss, d1, d2, kl_div_loss, X_rec, X_ori])
                if (train_count + 1) % args.display_steps == 0:
                    with open('%s/log.csv' % args.checkpoint_dir, 'a') as f_log:
                        f_log.write('%d,loss,%f, kl,%f, d1,%f, d2,%f\n' % (
                            train_count + 1, train_loss_, kl_div_loss_, d1_, d2_))
                    print('%d,loss,%f, kl,%f, d1,%f, d2,%f' % (
                        train_count + 1, train_loss_, kl_div_loss_, d1_, d2_))
            else:
                _, train_loss_, d1_, kl_div_loss_, rec_img, X_ori_ = sess.run(
                    [optimizer, train_loss, d1, kl_div_loss, X_rec, X_ori])
                if (train_count + 1) % args.display_steps == 0:
                    with open('%s/log.csv' % args.checkpoint_dir, 'a') as f_log:
                        f_log.write('%d,loss,%f, kl,%f, d1,%f\n' % (
                            train_count + 1, train_loss_, kl_div_loss_, d1_))
                    print('%d,loss,%f, kl,%f, d1,%f' % (
                        train_count + 1, train_loss_, kl_div_loss_, d1_))

            if (train_count + 1) % args.save_steps == 0:
                # Tile the largest square grid of patches that fits the batch.
                num = math.floor(math.sqrt(rec_img.shape[0]))
                show_img = np.zeros([num * args.patch_size, num * args.patch_size, 3])
                ori_img = np.zeros([num * args.patch_size, num * args.patch_size, 3])
                for i in range(num):
                    for j in range(num):
                        show_img[i * args.patch_size:(i + 1) * args.patch_size,
                                 j * args.patch_size:(j + 1) * args.patch_size, :] = \
                            rec_img[num * i + j, :, :, :]
                        ori_img[i * args.patch_size:(i + 1) * args.patch_size,
                                j * args.patch_size:(j + 1) * args.patch_size, :] = \
                            X_ori_[num * i + j, :, :, :]
                save_name = os.path.join(
                    args.checkpoint_dir,
                    'rec_%s_%s.png' % (parameter, train_count + 1))
                # Note: scipy.misc.imsave was removed in SciPy 1.2;
                # imageio.imwrite is the drop-in replacement.
                scipy.misc.imsave(save_name, show_img)
                psnr_ = Psnr(ori_img, show_img)
                msssim_ = msssim(ori_img, show_img)
                # print('FOR calculation %s_%s_%s_%s_la1%s_la2%s_%s' % (
                #     args.activation, args.dim1, args.dim2, args.z,
                #     args.lambda1, args.lambda2, train_count))
                print('PSNR (dB), %.2f,Multiscale SSIM, %.4f,Multiscale SSIM (dB), %.2f' % (
                    psnr_, msssim_, -10 * np.log10(1 - msssim_)))
                with open('%s/log_ssim_%s.csv' % (args.checkpoint_dir, parameter), 'a') as f_log_ssim:
                    f_log_ssim.write(
                        '%s,%d,PSNR (dB), %.2f,Multiscale SSIM, %.4f,Multiscale SSIM (dB), %.2f\n' % (
                            parameter, train_count + 1, psnr_, msssim_,
                            -10 * np.log10(1 - msssim_)))
            train_count += 1
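# train() reads everything from a module-level `args`. A minimal argparse
# sketch covering every flag referenced above; the default values are
# assumptions for illustration, not this repo's published settings.
def build_args_sketch():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_dir', default='checkpoints')
    parser.add_argument('--data_set', default='celeba')
    parser.add_argument('--img_path', default='data/celeba')
    parser.add_argument('--patch_size', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--loss1', default='mse', choices=['mse', 'ssim'])
    parser.add_argument('--loss2', default='mse', choices=['mse', 'ssim'])
    parser.add_argument('--split', default='None')
    parser.add_argument('--lambda1', type=float, default=1.0)
    parser.add_argument('--lambda2', type=float, default=1.0)
    parser.add_argument('--finetune', default='None')
    parser.add_argument('--num_steps', type=int, default=1000000)
    parser.add_argument('--save_steps', type=int, default=10000)
    parser.add_argument('--display_steps', type=int, default=100)
    return parser.parse_args()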