def main(unused_argv):
    """Reconstruct every image in FLAGS.test_folder and report PSNR.

    Each image is converted to grayscale, normalised with cfgs.mean_value and
    cfgs.scale, reconstructed by the restored Autoencoder through
    overlap_inference, de-normalised, written to FLAGS.reconstruction_folder
    and scored with psnr() against the grayscale input. The per-image and
    mean PSNR are printed.

    Args:
        unused_argv: Leftover command-line arguments; ignored.

    Raises:
        ValueError: If neither FLAGS.snapshot_dir nor FLAGS.model_fname is
            set, or if FLAGS.snapshot_dir contains no checkpoint.
    """
    import os

    # Explicit check instead of `assert`, which disappears under `python -O`.
    if FLAGS.snapshot_dir == "" and FLAGS.model_fname == "":
        raise ValueError('No pretrained model specified')

    # load test images
    test_list = list_image(FLAGS.test_folder)
    if not test_list:
        # Nothing to evaluate; also avoids dividing by zero for the mean PSNR.
        print('No test images found in %s' % FLAGS.test_folder)
        return

    # load model
    model = Autoencoder(cfgs.patch_size * cfgs.patch_size, cfgs, log_dir=None)
    if FLAGS.model_fname != "":
        snapshot_fname = FLAGS.model_fname
    else:
        snapshot_fname = tf.train.latest_checkpoint(FLAGS.snapshot_dir)
        # latest_checkpoint returns None (it does not raise) when no
        # checkpoint is found.
        if snapshot_fname is None:
            raise ValueError('No checkpoint found in %s' % FLAGS.snapshot_dir)
    model.restore(snapshot_fname)
    print('Restored from %s' % snapshot_fname)

    sum_psnr = 0.0
    stride = FLAGS.stride
    for img_fname in test_list:
        orig_img = load_image(os.path.join(FLAGS.test_folder, img_fname))

        # pre-process image: same normalisation the model is fed with
        gray_img = toGrayscale(orig_img)
        img = gray_img.astype(np.float32)
        img -= cfgs.mean_value
        img *= cfgs.scale

        # make measurement and reconstruct image, then undo the normalisation
        recon_img = overlap_inference(model, img, bs=cfgs.batch_size, stride=stride)
        recon_img /= cfgs.scale
        recon_img += cfgs.mean_value

        # save reconstruction; clip first because casting out-of-range floats
        # to uint8 does not saturate (the result is platform-dependent and
        # typically wraps, e.g. 256 -> 0).
        out_fname = os.path.join(
            FLAGS.reconstruction_folder,
            '%sOI_%d_%s' % (FLAGS.prefix, stride, img_fname))
        saved = cv.imwrite(out_fname, np.clip(recon_img, 0, 255).astype(np.uint8))
        if not saved:
            print('Warning: failed to write %s' % out_fname)

        # PSNR is measured on the unclipped float reconstruction, as before.
        psnr_ = psnr(gray_img.astype(np.float32), recon_img)
        print('Image %s, psnr: %f' % (img_fname, psnr_))
        sum_psnr += psnr_

    mean_psnr = sum_psnr / len(test_list)
    print('---------------------------')
    print('Mean PSNR: %f' % mean_psnr)
def main(unused_argv):
    """Train the Autoencoder on data from FLAGS.db_fname.

    Builds a DataLoader and an Autoencoder (with quantized weights). It
    optionally restores FLAGS.pretrained_fname, then trains one batch at a
    time with model.partial_fit.

    Whenever the loader reports an epoch boundary, the loop:
      * validates the model with val();
      * decays the learning rate according to cfgs.lr_update ('val' or 'step');
      * saves a snapshot to FLAGS.output_dir every FLAGS.save_every_n_epochs
        epochs.

    Training stops after FLAGS.n_epochs epochs, saving a final snapshot if
    one was not just taken. Text progress goes through log_train() to
    <FLAGS.log_dir>/text_log.txt.

    Args:
        unused_argv: Leftover command-line arguments; ignored.

    Raises:
        ValueError: If --output_dir is empty.
    """
    import os

    val_losses = []
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if not FLAGS.output_dir:
        raise ValueError("--output_dir is required")

    # Create training directory.
    output_dir = FLAGS.output_dir
    if not tf.gfile.IsDirectory(output_dir):
        tf.gfile.MakeDirs(output_dir)

    dl = DataLoader(FLAGS.db_fname, mean=cfgs.mean_value, scale=cfgs.scale,
                    n_vals=FLAGS.n_vals)
    dl.prepare()
    x_dim = dl.get_data_dim()
    model = Autoencoder(x_dim, cfgs, log_dir=FLAGS.log_dir)
    model.quantize_weights()

    # os.path.join supplies the separator that plain concatenation dropped
    # when --log_dir had no trailing slash.
    txt_log_fname = os.path.join(FLAGS.log_dir, 'text_log.txt')
    # `with` guarantees the log is closed even if training raises.
    with open(txt_log_fname, 'w') as log_fout:
        if FLAGS.pretrained_fname:
            log_train(log_fout, 'Resume from %s' % (FLAGS.pretrained_fname))
            try:
                model.restore(FLAGS.pretrained_fname)
            except Exception as err:
                # Best effort: continue from the fresh initialisation, but
                # record why the restore failed (and let Ctrl-C through).
                log_train(log_fout, 'Cannot restore from %s' % (FLAGS.pretrained_fname))
                log_train(log_fout, 'Restore error: %r' % (err,))

        lr = cfgs.initial_lr
        epoch_counter = 0
        ite = 0
        while True:
            start = time.time()
            # `flag` marks an epoch boundary (it is what advances epoch_counter).
            x, flag = dl.next_batch(cfgs.batch_size, 'train')
            load_data_time = time.time() - start
            if flag:
                epoch_counter += 1

            do_log = (ite % FLAGS.log_every_n_steps == 0) or flag
            do_snapshot = (flag and epoch_counter > 0
                           and epoch_counter % FLAGS.save_every_n_epochs == 0)
            val_loss = -1

            # train one step; partial_fit also returns the updated iteration count
            start = time.time()
            loss, _, summary, ite = model.partial_fit(x, lr, do_log)
            one_iter_time = time.time() - start

            # writing outs
            if do_log:
                log_train(log_fout, 'Iteration %d, (lr=%f) training loss : %f'
                          % (ite, lr, loss))
                if FLAGS.log_time:
                    log_train(log_fout,
                              'Iteration %d, data loading: %f(s) ; one iteration: %f(s)'
                              % (ite, load_data_time, one_iter_time))
                model.log(summary)
                if flag:
                    val_loss = val(model, dl)
                    val_losses.append(val_loss)
                    log_train(log_fout, '----------------------------------------------------')
                    if ite == 0:
                        log_train(log_fout, 'Initial validation loss: %f' % (val_loss))
                    else:
                        log_train(log_fout, 'Epoch %d, validation loss: %f'
                                  % (epoch_counter, val_loss))
                    log_train(log_fout, '----------------------------------------------------')
                    # NOTE(review): the same `summary` is logged a second time
                    # here; kept as-is in case model.log also records state
                    # updated by val() -- confirm whether this is intended.
                    model.log(summary)

            if do_snapshot:
                log_train(log_fout, 'Snapshotting')
                model.save(FLAGS.output_dir)

            if flag:
                # 'val' decays only when the newest validation loss is no better
                # than the *worst* of the previous four.
                # NOTE(review): plateau tests usually compare against min();
                # confirm that max() is intended.
                if (cfgs.lr_update == 'val' and len(val_losses) >= 5
                        and val_loss >= max(val_losses[-5:-1])):
                    lr = lr * cfgs.lr_decay_factor
                    log_train(log_fout, 'Decay learning rate to %f' % lr)
                elif (cfgs.lr_update == 'step'
                      and epoch_counter % cfgs.num_epochs_per_decay == 0):
                    lr = lr * cfgs.lr_decay_factor
                    log_train(log_fout, 'Decay learning rate to %f' % lr)

            # `>=` matches `==` for any n_epochs >= 0 (epoch_counter steps by 1)
            # but cannot loop forever if n_epochs is negative.
            if epoch_counter >= FLAGS.n_epochs:
                if not do_snapshot:
                    log_train(log_fout, 'Final snapshotting')
                    model.save(FLAGS.output_dir)
                break