def main(unused_argv):
    """Train the upsampler/blender/downsampler/extractor models and checkpoint them.

    Configuration is read from FLAGS:
      * ``data_dir``: input data directory; must exist.
      * ``output_dir``: created if missing; receives the TensorBoard event
        log and the ``model.ckpt`` checkpoint.
      * ``batch_size``, ``img_height``, ``img_width``, ``num_chans``: the
        model input shape ``[batch, height, width, channels]``.
      * ``num_iterations``: number of training steps.
      * ``validation_interval``: run the validation summary op every this
        many steps (starting at step 1).
      * ``gpus``: string assigned to ``gpu_options.visible_device_list``.

    Args:
        unused_argv: Remaining command-line arguments; ignored.

    Raises:
        ValueError: If ``data_dir`` is empty or missing, ``output_dir`` is
            empty, ``gpus`` is empty, or ``validation_interval`` is not
            positive.
    """
    # Validate every flag before touching the filesystem or building the graph.
    # Explicit raises (rather than ``assert``) survive ``python -O``.
    if FLAGS.data_dir == '' or not os.path.exists(FLAGS.data_dir):
        raise ValueError('invalid data directory {}'.format(FLAGS.data_dir))
    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(FLAGS.output_dir))
    if FLAGS.gpus == '':
        raise ValueError('invalid GPU specification')
    if FLAGS.validation_interval <= 0:
        # A zero interval would otherwise raise ZeroDivisionError mid-training.
        raise ValueError('invalid validation interval {}'.format(
            FLAGS.validation_interval))
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    summ = Summaries()
    # Trailing '' just appends a path separator to the directory name.
    event_log_dir = os.path.join(FLAGS.output_dir, '')
    checkpoint_path = os.path.join(FLAGS.output_dir, 'model.ckpt')

    print('Constructing models.')
    dim = [FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width, FLAGS.num_chans]
    image_upsampler = Upsampler(dim)
    # The watermark is a single item, so its upsampler uses batch size 1.
    wm_upsampler = Upsampler([1] + dim[1:])
    image_downsampler = Downsampler(dim)
    blender = Blender(dim)
    extractor = Extractor(dim)

    train_loss, train_op, train_summ_op = train(
        FLAGS.data_dir, image_upsampler, wm_upsampler, blender,
        image_downsampler, extractor, summ)
    val_summ_op = val(
        FLAGS.data_dir, image_upsampler, wm_upsampler, blender,
        image_downsampler, extractor, summ)

    print('Constructing saver.')
    saver = tf.train.Saver()

    # allow_soft_placement must be True because some of the ops do not have
    # GPU implementations.
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.visible_device_list = FLAGS.gpus

    init = [tf.global_variables_initializer(), tf.local_variables_initializer()]

    with tf.Session(config=config) as sess:
        sess.run(init)
        writer = tf.summary.FileWriter(event_log_dir, graph=sess.graph)
        try:
            for itr in range(FLAGS.num_iterations):
                cost, _, train_summ_str = sess.run(
                    [train_loss, train_op, train_summ_op])
                # Print info: iteration #, cost.
                print(str(itr) + ' ' + str(cost))
                writer.add_summary(train_summ_str, itr)

                # Validate on steps 1, 1 + k, 1 + 2k, ... (k = interval).
                # For k > 1 this is the same test as ``itr % k == 1``. Unlike
                # that form, it still fires (every step) when k == 1.
                if (itr - 1) % FLAGS.validation_interval == 0:
                    val_summ_str = sess.run(val_summ_op)
                    writer.add_summary(val_summ_str, itr)
                    tf.logging.info('Saving model.')
                    saver.save(sess, checkpoint_path)

            # Always checkpoint the final weights, including the steps after
            # the last validation pass.
            tf.logging.info('Saving model.')
            saver.save(sess, checkpoint_path)
        finally:
            # Flush buffered events to disk even if training raises.
            writer.close()
    tf.logging.info('Training complete')
def _first_as_uint8(batch):
    """Return the first item of a 4-D image batch as a squeezed ``uint8`` array.

    Values are clipped to [0, 255] before the cast. A bare ``astype(np.uint8)``
    does not saturate out-of-range floats: the result is platform-dependent
    and typically wraps, so bright pixels can turn dark.
    """
    return np.squeeze(np.clip(batch[0, :, :, :], 0, 255).astype(np.uint8))


def main(unused_argv):
    """Restore a checkpoint and visualise watermark embedding and extraction.

    Builds a batch-size-1 graph with these steps:
      1. Upsample the image and the watermark.
      2. Blend them, then downsample the blended image.
      3. Filter the result with ``FreqImage`` using a complex-cast ``Mask``.
      4. Extract the watermark with ``Extractor`` and sharpen it with
         ``Enhance``.

    It then restores the latest checkpoint from ``FLAGS.checkpoint_dir``. Four
    panels (watermarked image, filtered image, original watermark, extracted
    watermark) are written as an image summary to ``FLAGS.output_dir``, and
    the full extracted watermark array is printed.

    Args:
        unused_argv: Remaining command-line arguments; ignored.

    Raises:
        ValueError: If ``checkpoint_dir`` is empty or missing, ``output_dir``
            is empty, or ``gpus`` is empty.
    """
    if FLAGS.checkpoint_dir == '' or not os.path.exists(FLAGS.checkpoint_dir):
        raise ValueError('invalid checkpoint directory {}'.format(
            FLAGS.checkpoint_dir))
    checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, '')
    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(
            FLAGS.output_dir))
    # Explicit raise (rather than ``assert``) so the check survives ``python -O``.
    if FLAGS.gpus == '':
        raise ValueError('invalid GPU specification')
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    print('reconstructing models and inputs.')
    image = Image('/data/yuming/watermark-data/image_paths.mat', FLAGS.image_seq)()
    wm = Watermark('/data/yuming/watermark-data/watermark.mat')()

    dim = [1, FLAGS.img_height, FLAGS.img_width, FLAGS.num_chans]
    image_upsampler = Upsampler(dim)
    wm_upsampler = Upsampler([1] + dim[1:])
    downsampler = Downsampler(dim)
    blender = Blender(dim)
    extractor = Extractor(dim)

    image_upsampled = image_upsampler(image)
    wm_upsampled = wm_upsampler(wm)
    image_blended = blender(image_upsampled, wm_upsampled)
    image_downsampled = downsampler(image_blended)

    # NOTE(review): 80 is Mask's third argument (presumably a cutoff/radius);
    # confirm its meaning. The cast to complex64 presumably matches FreqImage's
    # frequency-domain input.
    mask = Mask(FLAGS.img_height, FLAGS.img_width, 80)()
    mask = tf.cast(mask, tf.complex64)
    freqimage = FreqImage(mask)
    image_freqfiltered = freqimage(image_downsampled)
    wm_extracted = extractor(image_freqfiltered)
    enhance = Enhance(sharpen=True)
    wm_extracted = enhance(wm_extracted)

    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(FLAGS.output_dir, tf.get_default_graph())

    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.visible_device_list = FLAGS.gpus

    try:
        with tf.Session(config=config) as sess:
            sess.run(tf.local_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return
            # Restores from checkpoint.
            saver.restore(sess, ckpt.model_checkpoint_path)

            wm_val, image_downsampled_val, image_freqfiltered_val, wm_extracted_val = \
                sess.run([wm, image_downsampled, image_freqfiltered, wm_extracted])

            images = [
                {'data': _first_as_uint8(image_downsampled_val),
                 'title': "watermarked image"},
                {'data': _first_as_uint8(image_freqfiltered_val),
                 'title': "filtered image"},
                {'data': _first_as_uint8(wm_val),
                 'title': "original watermark"},
                {'data': _first_as_uint8(wm_extracted_val),
                 'title': "extracted watermark"},
            ]
            writer.add_summary(draw_image(images), global_step=0)

            # Print the whole array without numpy's "..." elision.
            np.set_printoptions(threshold=sys.maxsize)
            print(np.squeeze(wm_extracted_val))
    finally:
        # Also runs on the no-checkpoint early return, so the writer never leaks.
        writer.close()
def _first_as_uint8(batch):
    """Return the first item of a 4-D image batch as a squeezed ``uint8`` array.

    Values are clipped to [0, 255] before the cast. A bare ``astype(np.uint8)``
    does not saturate out-of-range floats: the result is platform-dependent
    and typically wraps, so bright pixels can turn dark.
    """
    return np.squeeze(np.clip(batch[0, :, :, :], 0, 255).astype(np.uint8))


def main(unused_argv):
    """Restore a checkpoint and log PSNR plus an original/watermarked comparison.

    Builds a batch-size-1 graph with these steps:
      1. Upsample the image and the watermark.
      2. Blend them, then downsample the blended image.
      3. Extract the watermark from the downsampled image.

    It then restores the latest checkpoint from ``FLAGS.checkpoint_dir``. Two
    PSNR values are written to ``FLAGS.output_dir`` as text summaries:
      * original image vs. watermarked image;
      * embedded watermark vs. extracted watermark.
    A two-panel image summary of the original and watermarked images is also
    written there.

    Args:
        unused_argv: Remaining command-line arguments; ignored.

    Raises:
        ValueError: If ``checkpoint_dir`` is empty or missing, ``output_dir``
            is empty, or ``gpus`` is empty.
    """
    if FLAGS.checkpoint_dir == '' or not os.path.exists(FLAGS.checkpoint_dir):
        raise ValueError('invalid checkpoint directory {}'.format(FLAGS.checkpoint_dir))
    checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, '')
    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(FLAGS.output_dir))
    # Explicit raise (rather than ``assert``) so the check survives ``python -O``.
    if FLAGS.gpus == '':
        raise ValueError('invalid GPU specification')
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    print('reconstructing models and inputs.')
    image = Image('/data/yuming/watermark-data/image_paths.mat', FLAGS.image_seq)()
    wm = Watermark('/data/yuming/watermark-data/watermark.mat')()

    dim = [1, FLAGS.img_height, FLAGS.img_width, FLAGS.num_chans]
    image_upsampler = Upsampler(dim)
    wm_upsampler = Upsampler([1] + dim[1:])
    downsampler = Downsampler(dim)
    blender = Blender(dim)
    extractor = Extractor(dim)

    image_upsampled = image_upsampler(image)
    wm_upsampled = wm_upsampler(wm)
    image_blended = blender(image_upsampled, wm_upsampled)
    image_downsampled = downsampler(image_blended)
    wm_extracted = extractor(image_downsampled)

    # Calculate the PSNR of the model, logged as text summaries.
    psnr = PSNR()
    image_psnr = psnr(image, image_downsampled)
    wm_psnr = psnr(wm, wm_extracted)
    summ_psnr_op = tf.summary.merge([
        tf.summary.text('image_psnr', tf.as_string(image_psnr)),
        tf.summary.text('wm_psnr', tf.as_string(wm_psnr)),
    ])

    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(FLAGS.output_dir, tf.get_default_graph())

    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.visible_device_list = FLAGS.gpus

    try:
        with tf.Session(config=config) as sess:
            sess.run(tf.local_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return
            # Restores from checkpoint.
            saver.restore(sess, ckpt.model_checkpoint_path)

            summ_psnr_str, image_val, image_downsampled_val = \
                sess.run([summ_psnr_op, image, image_downsampled])
            writer.add_summary(summ_psnr_str, global_step=0)

            # Untitled panels: left is the original image, right the watermarked one.
            images = [
                {'data': _first_as_uint8(image_val), 'title': ""},
                {'data': _first_as_uint8(image_downsampled_val), 'title': ""},
            ]
            writer.add_summary(draw_image(images), global_step=0)
    finally:
        # The original never closed the writer, so summaries could be lost.
        writer.close()