Example #1
def main(unused_argv):
    summ = Summaries()

    if FLAGS.data_dir == '' or not os.path.exists(FLAGS.data_dir):
        raise ValueError('invalid data directory {}'.format(FLAGS.data_dir))

    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(FLAGS.output_dir))
    elif not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)   

    event_log_dir = os.path.join(FLAGS.output_dir, '')
    
    checkpoint_path = os.path.join(FLAGS.output_dir, 'model.ckpt')

    print('Constructing models.')

    dim = [FLAGS.batch_size, FLAGS.img_height, FLAGS.img_width, FLAGS.num_chans]
    image_upsampler = Upsampler(dim)
    wm_upsampler = Upsampler([1] + dim[1:])
    image_downsampler = Downsampler(dim)
    blender = Blender(dim)
    extractor = Extractor(dim)

    train_loss, train_op, train_summ_op = \
        train(FLAGS.data_dir, image_upsampler, wm_upsampler, blender, image_downsampler, extractor, summ)
    val_summ_op = val(FLAGS.data_dir, image_upsampler, wm_upsampler, blender, image_downsampler, extractor, summ)

    print('Constructing saver.')
    saver = tf.train.Saver()

    # Start running operations on the Graph. allow_soft_placement must be set to
    # True because some of the ops do not have GPU implementations.
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)

    assert (FLAGS.gpus != ''), 'invalid GPU specification'
    config.gpu_options.visible_device_list = FLAGS.gpus

    # Build an initialization operation to run below.
    init = [tf.global_variables_initializer(), tf.local_variables_initializer()]

    with tf.Session(config=config) as sess:
        sess.run(init)

        writer = tf.summary.FileWriter(event_log_dir, graph = sess.graph)

        # Run training.
        for itr in range(FLAGS.num_iterations):
            cost, _, train_summ_str = sess.run([train_loss, train_op, train_summ_op])
            # Print info: iteration #, cost.
            print(str(itr) + ' ' + str(cost))

            writer.add_summary(train_summ_str, itr)

            if itr % FLAGS.validation_interval == 1:
                # Run through validation set.
                val_summ_str = sess.run(val_summ_op)
                writer.add_summary(val_summ_str, itr)

        tf.logging.info('Saving model.')
        saver.save(sess, checkpoint_path)
        tf.logging.info('Training complete')
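
These examples read a number of command-line flags (FLAGS.data_dir, FLAGS.output_dir, FLAGS.gpus, FLAGS.batch_size, and so on) that are defined elsewhere in the module. A minimal sketch of how they might be declared and wired to main() with TF 1.x's tf.app.flags; the names match the FLAGS.* attributes used above, while the defaults and help strings are placeholders:

# Hypothetical flag definitions assumed by the examples (TF 1.x style).
# Names match the FLAGS.* attributes used in main(); defaults are placeholders.
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_dir', '', 'directory with training data')
tf.app.flags.DEFINE_string('output_dir', '', 'directory for checkpoints and event logs')
tf.app.flags.DEFINE_string('gpus', '0', 'comma-separated list of visible GPU ids')
tf.app.flags.DEFINE_integer('batch_size', 32, 'training batch size')
tf.app.flags.DEFINE_integer('img_height', 256, 'input image height')
tf.app.flags.DEFINE_integer('img_width', 256, 'input image width')
tf.app.flags.DEFINE_integer('num_chans', 3, 'number of image channels')
tf.app.flags.DEFINE_integer('num_iterations', 100000, 'number of training iterations')
tf.app.flags.DEFINE_integer('validation_interval', 100, 'iterations between validation runs')

if __name__ == '__main__':
    tf.app.run()  # parses the flags and calls the main(unused_argv) defined above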
Example #2
def main(unused_argv):
    if FLAGS.checkpoint_dir == '' or not os.path.exists(FLAGS.checkpoint_dir):
        raise ValueError('invalid checkpoint directory {}'.format(
            FLAGS.checkpoint_dir))

    checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, '')

    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(
            FLAGS.output_dir))
    elif not os.path.exists(FLAGS.output_dir):
        assert FLAGS.output_dir != FLAGS.checkpoint_dir
        os.makedirs(FLAGS.output_dir)

    print('reconstructing models and inputs.')
    image = Image('/data/yuming/watermark-data/image_paths.mat',
                  FLAGS.image_seq)()
    wm = Watermark('/data/yuming/watermark-data/watermark.mat')()

    dim = [1, FLAGS.img_height, FLAGS.img_width, FLAGS.num_chans]
    image_upsampler = Upsampler(dim)
    wm_upsampler = Upsampler([1] + dim[1:])
    downsampler = Downsampler(dim)
    blender = Blender(dim)
    extractor = Extractor(dim)

    image_upsampled = image_upsampler(image)
    wm_upsampled = wm_upsampler(wm)
    image_blended = blender(image_upsampled, wm_upsampled)
    image_downsampled = downsampler(image_blended)

    mask = Mask(FLAGS.img_height, FLAGS.img_width, 80)()
    mask = tf.cast(mask, tf.complex64)
    freqimage = FreqImage(mask)

    image_freqfiltered = freqimage(image_downsampled)
    wm_extracted = extractor(image_freqfiltered)

    enhance = Enhance(sharpen=True)
    wm_extracted = enhance(wm_extracted)

    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(FLAGS.output_dir, tf.get_default_graph())

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)
    assert (FLAGS.gpus != ''), 'invalid GPU specification'
    config.gpu_options.visible_device_list = FLAGS.gpus

    with tf.Session(config=config) as sess:
        sess.run(tf.local_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/cifar10_train/model.ckpt-0,
            # extract global_step from it.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
        else:
            print('No checkpoint file found')
            return

        wm_val, image_downsampled_val, image_freqfiltered_val, wm_extracted_val = \
            sess.run([wm, image_downsampled, image_freqfiltered, wm_extracted])

        images = [{
            'data':
            np.squeeze(image_downsampled_val[0, :, :, :].astype(np.uint8)),
            'title':
            "watermarked image"
        }, {
            'data':
            np.squeeze(image_freqfiltered_val[0, :, :, :].astype(np.uint8)),
            'title':
            "filtered image"
        }, {
            'data': np.squeeze(wm_val[0, :, :, :].astype(np.uint8)),
            'title': "original watermark"
        }, {
            'data':
            np.squeeze(wm_extracted_val[0, :, :, :].astype(np.uint8)),
            'title':
            "extracted watermark"
        }]

        image_str = draw_image(images)
        writer.add_summary(image_str, global_step=0)

        np.set_printoptions(threshold=sys.maxsize)
        print(np.squeeze(wm_extracted_val))

    writer.close()
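
FreqImage is applied with a mask cast to tf.complex64, which suggests the watermarked image is filtered in the frequency domain before extraction. A minimal sketch of such a callable, assuming a per-channel 2-D FFT and an element-wise mask (the actual class in this project may differ):

# Hypothetical sketch of a frequency-domain filter like FreqImage (TF 1.x).
# Assumes inputs of shape [batch, height, width, channels] and a complex64
# mask of shape [height, width]; the real implementation may differ.
import tensorflow as tf

class FreqImageSketch(object):
    def __init__(self, mask):
        self._mask = mask                   # complex64, [height, width]

    def __call__(self, images):
        x = tf.cast(images, tf.complex64)
        x = tf.transpose(x, [0, 3, 1, 2])   # FFT runs over the innermost two dims
        spectrum = tf.fft2d(x)              # per-channel 2-D FFT
        filtered = spectrum * self._mask    # element-wise frequency mask
        x = tf.real(tf.ifft2d(filtered))    # back to the spatial domain
        return tf.transpose(x, [0, 2, 3, 1])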
Example #3
def main(unused_argv):
    if FLAGS.checkpoint_dir == '' or not os.path.exists(FLAGS.checkpoint_dir):
        raise ValueError('invalid checkpoint directory {}'.format(FLAGS.checkpoint_dir))

    checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, '')

    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(FLAGS.output_dir))
    elif not os.path.exists(FLAGS.output_dir):
        assert FLAGS.output_dir != FLAGS.checkpoint_dir
        os.makedirs(FLAGS.output_dir)

    print('reconstructing models and inputs.')
    image = Image('/data/yuming/watermark-data/image_paths.mat', FLAGS.image_seq)()
    wm = Watermark('/data/yuming/watermark-data/watermark.mat')()

    dim = [1, FLAGS.img_height, FLAGS.img_width, FLAGS.num_chans]
    image_upsampler = Upsampler(dim)
    wm_upsampler = Upsampler([1] + dim[1:])
    downsampler = Downsampler(dim)
    blender = Blender(dim)
    extractor = Extractor(dim)

    image_upsampled = image_upsampler(image)
    wm_upsampled = wm_upsampler(wm)
    image_blended = blender(image_upsampled, wm_upsampled)
    image_downsampled = downsampler(image_blended)
    wm_extracted = extractor(image_downsampled)

    # Calculate the psnr of the model.
    psnr = PSNR()
    image_psnr = psnr(image, image_downsampled)
    wm_psnr = psnr(wm, wm_extracted)
    
    summ_psnr_op = tf.summary.merge([tf.summary.text('image_psnr', tf.as_string(image_psnr)),
                                     tf.summary.text('wm_psnr', tf.as_string(wm_psnr))])
    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(FLAGS.output_dir, tf.get_default_graph()) 
    
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    assert (FLAGS.gpus != ''), 'invalid GPU specification'
    config.gpu_options.visible_device_list = FLAGS.gpus

    with tf.Session(config=config) as sess:
        sess.run(tf.local_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Assuming model_checkpoint_path looks something like:
            #   /my-favorite-path/cifar10_train/model.ckpt-0,
            # extract global_step from it.
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('No checkpoint file found')
            return

        summ_psnr_str, image_val, image_downsampled_val = \
            sess.run([summ_psnr_op, image, image_downsampled])

        writer.add_summary(summ_psnr_str, global_step=0)

        '''
        images = [{'data': np.squeeze(image_val[0, :, :, :].astype(np.uint8)), 'title': "original image"},
                  {'data': np.squeeze(image_downsampled_val[0, :, :, :].astype(np.uint8)), 'title': "watermarked image"}]
        '''

        images = [{'data': np.squeeze(image_val[0, :, :, :].astype(np.uint8)), 'title': ""},
                  {'data': np.squeeze(image_downsampled_val[0, :, :, :].astype(np.uint8)), 'title': ""}]

        
        image_str = draw_image(images)
        writer.add_summary(image_str, global_step=0)

    writer.close()
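
Example #3 measures reconstruction quality with a PSNR class applied to the image/downsampled pair and the watermark/extracted pair. A minimal sketch under the assumption that it wraps tf.image.psnr (available in TF 1.x) with an 8-bit dynamic range; the real class may compute PSNR differently:

# Hypothetical sketch of the PSNR helper used in Example #3 (TF 1.x).
# Assumes 8-bit image tensors of shape [batch, height, width, channels].
import tensorflow as tf

class PSNRSketch(object):
    def __call__(self, reference, reconstructed):
        # tf.image.psnr returns one value per batch element; average them
        # into the scalar that tf.as_string() expects in the summary above.
        return tf.reduce_mean(tf.image.psnr(reference, reconstructed, max_val=255.0))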