def main(_):
    """Build the stylization evaluation graph and run the slim evaluation loop.

    Depending on the flags, the graph contains:

    * ``FLAGS.style_grid``: a 'Style Grid' image summary. Its first row is a
      white image followed by the style images; every following row is an
      input image followed by that image rendered in each style. Rows are
      produced for one training example, one uniform-noise image, every
      evaluation image and, when ``FLAGS.style_crossover`` is set, every
      style image.
    * ``FLAGS.learning_curves``: a streaming mean of every entry of the loss
      dictionary returned by ``learning.total_loss``, computed separately for
      each style, exported as scalar summaries and printed.

    Dictionary iteration uses ``items()`` and the update ops are materialized
    with ``list(...)`` so that the function behaves the same on Python 2 and
    Python 3.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If the number of style coefficients differs from
            ``FLAGS.num_styles``.
    """
    with tf.Graph().as_default():
        # Create inputs in [0, 1], as expected by vgg_16.
        inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                FLAGS.image_size)
        evaluation_images = image_utils.load_evaluation_images(
            FLAGS.image_size)

        # Process style and weight flags. The style coefficients are only
        # validated here; they are not used further in this function.
        if FLAGS.style_coefficients is None:
            style_coefficients = [1.0] * FLAGS.num_styles
        else:
            style_coefficients = ast.literal_eval(FLAGS.style_coefficients)
        if len(style_coefficients) != FLAGS.num_styles:
            raise ValueError(
                'number of style coefficients differs from number of styles')
        content_weights = ast.literal_eval(FLAGS.content_weights)
        style_weights = ast.literal_eval(FLAGS.style_weights)

        # Load style images.
        style_images, labels, style_gram_matrices = image_utils.style_image_inputs(
            os.path.expanduser(FLAGS.style_dataset_file),
            batch_size=FLAGS.num_styles,
            image_size=FLAGS.image_size,
            square_crop=True,
            shuffle=False)
        # One label tensor per style, so each style can be applied on its own.
        labels = tf.unstack(labels)

        def _create_normalizer_params(style_label):
            """Creates normalizer parameters from a style label."""
            return {
                'labels': tf.expand_dims(style_label, 0),
                'num_categories': FLAGS.num_styles,
                'center': True,
                'scale': True
            }

        # Dummy call to simplify the reuse logic: it creates the model
        # variables once so that every later call can pass reuse=True.
        model.transform(inputs,
                        reuse=False,
                        normalizer_params=_create_normalizer_params(labels[0]))

        def _style_sweep(inputs):
            """Transfers all styles onto the input one at a time.

            Returns the input image followed by its stylized versions,
            concatenated along the first axis.
            """
            inputs = tf.expand_dims(inputs, 0)
            stylized_inputs = [
                model.transform(
                    inputs,
                    reuse=True,
                    normalizer_params=_create_normalizer_params(style_label))
                for style_label in labels
            ]
            # NOTE(review): tf.concat is called with the axis first, which
            # looks like the pre-1.0 TensorFlow signature -- confirm against
            # the TensorFlow version this file targets.
            return tf.concat(0, [inputs] + stylized_inputs)

        if FLAGS.style_grid:
            # Header row: a white placeholder above the input column, then
            # the style images themselves.
            style_row = tf.concat(0, [
                tf.ones([1, FLAGS.image_size, FLAGS.image_size, 3]),
                style_images
            ])
            stylized_training_example = _style_sweep(inputs[0])
            stylized_evaluation_images = [
                _style_sweep(image) for image in tf.unstack(evaluation_images)
            ]
            stylized_noise = _style_sweep(
                tf.random_uniform([FLAGS.image_size, FLAGS.image_size, 3]))
            stylized_style_images = [
                _style_sweep(image) for image in tf.unstack(style_images)
            ]

            # Three fixed rows (header, training example, noise) plus one row
            # per evaluation image; crossover adds one row per style image.
            grid_rows = ([style_row, stylized_training_example, stylized_noise]
                         + stylized_evaluation_images)
            num_rows = 3 + evaluation_images.get_shape().as_list()[0]
            if FLAGS.style_crossover:
                grid_rows = grid_rows + stylized_style_images
                num_rows += FLAGS.num_styles
            grid = tf.concat(0, grid_rows)

            # Images are in [0, 1]; scale to [0, 255] before the uint8 cast.
            tf.image_summary(
                'Style Grid',
                tf.cast(
                    image_utils.form_image_grid(
                        grid, [num_rows, 1 + FLAGS.num_styles],
                        [FLAGS.image_size, FLAGS.image_size], 3) * 255.0,
                    tf.uint8))

        if FLAGS.learning_curves:
            metrics = {}
            for i, label in enumerate(labels):
                # Keep only the i-th style's Gram matrices, preserving the
                # leading axis by slicing instead of indexing.
                gram_matrices = {
                    key: value[i:i + 1]
                    for key, value in style_gram_matrices.items()
                }
                stylized_inputs = model.transform(
                    inputs,
                    reuse=True,
                    normalizer_params=_create_normalizer_params(label))
                # The first call passes reuse=False, later calls reuse=True.
                _, loss_dict = learning.total_loss(inputs,
                                                   stylized_inputs,
                                                   gram_matrices,
                                                   content_weights,
                                                   style_weights,
                                                   reuse=i > 0)
                for key, value in loss_dict.items():
                    metrics['{}_style_{}'.format(
                        key, i)] = slim.metrics.streaming_mean(value)

            names_values, names_updates = slim.metrics.aggregate_metric_map(
                metrics)
            for name, value in names_values.items():
                # NOTE(review): the empty list is presumably the summary's
                # `collections` argument, so that only the Print-wrapped op
                # below is registered as a summary -- confirm.
                summary_op = tf.scalar_summary(name, value, [])
                print_op = tf.Print(summary_op, [value], name)
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, print_op)
            # Materialize the update ops: on Python 3 dict.values() is a view
            # object, not a list.
            eval_op = list(names_updates.values())
            num_evals = FLAGS.num_evals
        else:
            eval_op = None
            num_evals = 1

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=os.path.expanduser(FLAGS.train_dir),
            logdir=os.path.expanduser(FLAGS.eval_dir),
            eval_op=eval_op,
            num_evals=num_evals,
            eval_interval_secs=FLAGS.eval_interval_secs)
Beispiel #2
0
def main(unused_argv=None):
    """Build the style-transfer training graph and run slim training.

    The graph reads content images and shuffled style labels / Gram matrices,
    rescales the style weights per example by the coefficient of the selected
    style, stylizes the inputs with ``model.transform``, and minimizes the
    total loss from ``learning.total_loss`` with Adam. Variables returned by
    ``slim.get_variables('vgg_16')`` are restored from
    ``vgg.checkpoint_file()`` when training initializes.

    Dictionary iteration uses ``items()`` so that the function behaves the
    same on Python 2 and Python 3.

    Args:
        unused_argv: Unused.

    Raises:
        ValueError: If the number of style coefficients differs from
            ``FLAGS.num_styles``.
    """
    with tf.Graph().as_default():
        # Force all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                    FLAGS.image_size)
            # Load style images and select one at random (for each graph execution, a
            # new random selection occurs)
            _, style_labels, style_gram_matrices = image_utils.style_image_inputs(
                os.path.expanduser(FLAGS.style_dataset_file),
                batch_size=FLAGS.batch_size,
                image_size=FLAGS.image_size,
                square_crop=True,
                shuffle=True)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and weight flags
            num_styles = FLAGS.num_styles
            if FLAGS.style_coefficients is None:
                style_coefficients = [1.0] * num_styles
            else:
                style_coefficients = ast.literal_eval(FLAGS.style_coefficients)
            if len(style_coefficients) != num_styles:
                raise ValueError(
                    'number of style coefficients differs from number of styles'
                )
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Rescale style weights dynamically based on the current style image
            style_coefficient = tf.gather(tf.constant(style_coefficients),
                                          style_labels)
            style_weights = {
                key: style_coefficient * value
                for key, value in style_weights.items()
            }

            # Define the model
            stylized_inputs = model.transform(inputs,
                                              normalizer_params={
                                                  'labels': style_labels,
                                                  'num_categories': num_styles,
                                                  'center': True,
                                                  'scale': True
                                              })

            # Compute losses.
            total_loss, loss_dict = learning.total_loss(
                inputs, stylized_inputs, style_gram_matrices, content_weights,
                style_weights)
            for key, value in loss_dict.items():
                tf.summary.scalar(key, value)

            # Set up training
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Function to restore VGG16 parameters
            # TODO(iansimon): This is ugly, but assign_from_checkpoint_fn doesn't
            # exist yet.
            saver = tf.train.Saver(slim.get_variables('vgg_16'))

            def init_fn(session):
                """Restore the VGG16 variables into the given session."""
                saver.restore(session, vgg.checkpoint_file())

            # Run training
            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_fn,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)