def main(): parser = argparse.ArgumentParser() parser.add_argument('-c', '--continue', dest='continue_path', required=False) args = parser.parse_args() ## load dataset train_batch_gnr, train_set = get_dataset_batch(ds_name='train') test_gnr, test_set = get_dataset_batch(ds_name='test') ## build graph network = Model() placeholders, restored = network.build() gt_size = config.patch_size - config.edge gt = tf.placeholder(tf.float32, shape=(None, ) + (gt_size, gt_size) + (config.nr_channel * config.ratio * config.ratio, ), name='gt') loss_squared = squared_error_loss(gt, restored) loss_reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) loss = loss_reg + loss_squared ## train config global_steps = tf.Variable(0, trainable=False) boundaries = [ train_set.minibatchs_per_epoch * 5, train_set.minibatchs_per_epoch * 40 ] values = [0.0001, 0.0001, 0.0001] lr = tf.train.piecewise_constant(global_steps, boundaries, values) opt = tf.train.AdamOptimizer(lr) # in order to update BN in every iter update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train = opt.minimize(loss) ## init tensorboard tf.summary.scalar('loss_regularization', loss_reg) tf.summary.scalar('loss_error', loss - loss_reg) tf.summary.scalar('loss', loss) tf.summary.scalar('learning_rate', lr) merged = tf.summary.merge_all() train_writer = tf.summary.FileWriter( os.path.join(config.log_dir, 'tf_log', 'train'), tf.get_default_graph()) ## create a session tf.set_random_seed(12345) # ensure consistent results global_cnt = 0 epoch_start = 0 g_list = tf.global_variables() saver = tf.train.Saver(var_list=g_list) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) if args.continue_path: ckpt = tf.train.get_checkpoint_state(args.continue_path) saver.restore(sess, ckpt.model_checkpoint_path) epoch_start = int( ckpt.model_checkpoint_path.split('/')[-1].split('-')[1]) global_cnt = epoch_start * train_set.minibatchs_per_epoch ## training file = open('./psnr_ratioB4', 'w+') for epoch in range(epoch_start + 1, config.nr_epoch + 1): for _ in range(train_set.minibatchs_per_epoch): global_cnt += 1 lr_images, sr_images = sess.run(train_batch_gnr) # 128*7*7*3,128*9*9*27 feed_dict = { placeholders['data']: lr_images[:, :, :, :1], gt: sr_images[:, :, :, :1 * config.ratio * config.ratio], global_steps: global_cnt, placeholders['is_training']: True, } _, loss_v, loss_reg_v, lr_v, summary = sess.run( [train, loss, loss_reg, lr, merged], feed_dict=feed_dict) if global_cnt % config.show_interval == 0: train_writer.add_summary(summary, global_cnt) print( "e:{},{}/{}".format( epoch, global_cnt % train_set.minibatchs_per_epoch, train_set.minibatchs_per_epoch), 'loss: {:.3f}'.format(loss_v), 'loss_reg: {:.3f}'.format(loss_reg_v), 'lr: {:.4f}'.format(lr_v), ) ## save model if epoch % config.snapshot_interval == 0: saver.save(sess, os.path.join(config.log_model_dir, 'epoch-{}'.format(epoch)), global_step=global_cnt) if epoch % config.test_interval == 0: psnrs = [] for _ in range(test_set.testing_minibatchs_per_epoch): lr_image, hr_image = sess.run(test_gnr) feed_dict = { placeholders['data']: lr_image[:, :, :, :1], placeholders['is_training']: False, } restored_v = sess.run([restored], feed_dict=feed_dict) restored_img_y = from_sub_pixel_to_img( restored_v[0][0], config.ratio) if epoch == 199: img = np.clip(restored_img_y[:, :, 0], 0, 1) * 255 img = img.astype('uint8') cv2.imwrite('./output/{}.png'.format(global_cnt), img) global_cnt += 1 edge = int(config.edge / 2 * config.ratio) psnr_y = compare_psnr( hr_image[0, edge:-edge, edge:-edge, :1], restored_img_y) psnrs.append(psnr_y) file.write(str(np.mean(psnrs)) + '\n') print('average psnr is {:2.2f} dB'.format(np.mean(psnrs))) print('Training is done, exit.') file.close()
yield patch.astype(np.float32) / 255.0, hr_patch.astype( np.float32) / 255.0 if __name__ == "__main__": ds = Dataset('train') gen = ds.load().instance_generator() imggrid = [] while True: for i in range(8): img, hr_patch = next(gen) img = cv2.resize(img, ((config.patch_size) * config.ratio, (config.patch_size) * config.ratio)) start_idx = int(config.edge / 2 * config.ratio) end_idx = int((config.patch_size - config.edge / 2) * config.ratio) imggrid.append(img[start_idx:end_idx, start_idx:end_idx, :]) imggrid.append(from_sub_pixel_to_img(hr_patch, config.ratio)) imggrid = np.array(imggrid).reshape( (4, 4, hr_patch.shape[0] * config.ratio, hr_patch.shape[1] * config.ratio, 3)) imggrid = imggrid.transpose((0, 2, 1, 3, 4)).reshape( (4 * hr_patch.shape[0] * config.ratio, 4 * hr_patch.shape[1] * config.ratio, 3)) cv2.imshow('', imggrid) c = chr(cv2.waitKey(0) & 0xff) if c == 'q': exit() imggrid = []