Example #1
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MkDir(FLAGS.output_dir)

    with tf.Graph().as_default(), tf.Session() as sess:
        # Defines a placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        # Defines a placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
        tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Gets the list of the input style images.
        style_img_list = tf.gfile.Glob(FLAGS.style_images_paths)
        if len(style_img_list) > FLAGS.maximum_styles_to_evaluate:
            np.random.seed(1234)
            style_img_list = np.random.permutation(style_img_list)
            style_img_list = style_img_list[:FLAGS.maximum_styles_to_evaluate]

        # Gets list of input content images.
        content_img_list = tf.gfile.Glob(FLAGS.content_images_paths)

        for content_i, content_img_path in enumerate(content_img_list):
            content_img_np = image_utils.load_np_image_uint8(
                content_img_path)[:, :, :3]
            content_img_name = os.path.basename(content_img_path)[:-4]

            # Saves preprocessed content image.
            inp_img_croped_resized_np = sess.run(
                content_img_preprocessed,
                feed_dict={content_img_ph: content_img_np})
            image_utils.save_np_image(
                inp_img_croped_resized_np,
                os.path.join(FLAGS.output_dir, '%s.jpg' % (content_img_name)))

            # Computes bottleneck features of the style prediction network for the
            # identity transform.
            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            for style_i, style_img_path in enumerate(style_img_list):
                if style_i > FLAGS.maximum_styles_to_evaluate:
                    break
                style_img_name = os.path.basename(style_img_path)[:-4]
                style_image_np = image_utils.load_np_image_uint8(
                    style_img_path)[:, :, :3]

                if style_i % 10 == 0:
                    tf.logging.info(
                        'Stylizing (%d) %s with (%d) %s' %
                        (content_i, content_img_name, style_i, style_img_name))

                # Saves preprocessed style image.
                style_img_croped_resized_np = sess.run(
                    style_img_preprocessed,
                    feed_dict={style_img_ph: style_image_np})
                image_utils.save_np_image(
                    style_img_croped_resized_np,
                    os.path.join(FLAGS.output_dir,
                                 '%s.jpg' % (style_img_name)))

                # Computes bottleneck features of the style prediction network for the
                # given style image.
                style_params = sess.run(
                    bottleneck_feat, feed_dict={style_img_ph: style_image_np})

                interpolation_weights = ast.literal_eval(
                    FLAGS.interpolation_weights)
                # Interpolates between the parameters of the identity transform and
                # style parameters of the given style image.
                for interp_i, wi in enumerate(interpolation_weights):
                    stylized_image_res = sess.run(
                        stylized_images,
                        feed_dict={
                            bottleneck_feat:
                            identity_params * (1 - wi) + style_params * wi,
                            content_img_ph:
                            content_img_np
                        })

                    # Saves stylized image.
                    image_utils.save_np_image(
                        stylized_image_res,
                        os.path.join(
                            FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                            (content_img_name, style_img_name, interp_i)))
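
The inner loop above is a plain linear interpolation between two bottleneck vectors. A minimal NumPy sketch of just that step, with random (1, 100) vectors standing in for the sess.run(bottleneck_feat, ...) outputs (the width matches style_prediction_bottleneck=100; the shape is assumed for illustration):

import numpy as np

# Dummy stand-ins for the bottleneck features computed by the style
# prediction network.
identity_params = np.random.rand(1, 100).astype(np.float32)
style_params = np.random.rand(1, 100).astype(np.float32)

for interp_i, wi in enumerate([0.0, 0.5, 1.0]):
    # wi == 0 reproduces the identity transform (content unchanged);
    # wi == 1 applies the full style.
    blended = identity_params * (1 - wi) + style_params * wi
    print(interp_i, wi, blended.shape)
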
Example #2
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            # Loads content images.
            content_inputs_, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Loads style images.
            [style_inputs_, _,
             style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
                 FLAGS.style_dataset_file,
                 batch_size=FLAGS.batch_size,
                 image_size=FLAGS.image_size,
                 shuffle=True,
                 center_crop=FLAGS.center_crop,
                 augment_style_images=FLAGS.augment_style_images,
                 random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model
            stylized_images, total_loss, loss_dict, _ = build_model.build_model(
                content_inputs_,
                style_inputs_,
                trainable=True,
                is_training=True,
                inception_end_point='Mixed_6e',
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight)

            # Add scalar summaries to TensorBoard.
            for key, value in loss_dict.items():
                tf.summary.scalar(key, value)

            # Add image summaries to TensorBoard.
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_,
                             3)
            tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/3_stylized_images', stylized_images, 3)

            # Set up training
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Function to restore VGG16 parameters.
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            # Function to restore Inception_v3 parameters.
            inception_variables_dict = {
                var.op.name: var
                for var in slim.get_model_variables('InceptionV3')
            }
            init_fn_inception = slim.assign_from_checkpoint_fn(
                FLAGS.inception_v3_checkpoint, inception_variables_dict)

            # Function to restore VGG16 and Inception_v3 parameters.
            def init_sub_networks(session):
                init_fn_vgg(session)
                init_fn_inception(session)

            # Run training
            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_sub_networks,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
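
The content_weights and style_weights flags are Python dict literals parsed with ast.literal_eval. A self-contained sketch of that parsing step, using the default weights that appear later in Example #6:

import ast

content_weights = ast.literal_eval('{"vgg_16/conv3": 1}')
style_weights = ast.literal_eval(
    '{"vgg_16/conv1": 0.5e-3, "vgg_16/conv2": 0.5e-3,'
    ' "vgg_16/conv3": 0.5e-3, "vgg_16/conv4": 0.5e-3}')

print(content_weights['vgg_16/conv3'])  # 1
print(style_weights['vgg_16/conv1'])    # 0.0005
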
Example #3
def main(unused_argv=None):
    print('timer start')
    start = time.time()

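    # Japanese material names (fabric, plant, glass, leather, metal, paper,
    # plastic, stone, water, wood, resin, acrylic, aluminum, cowhide, brick,
    # silk); each word keys a pickled style-parameter file loaded below.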
    words = [
        '布', '植物', 'ガラス', '革', '金属', '紙', 'プラスチック', '石', '水', '木', '樹脂',
        'アクリル', 'アルミニウム', '牛皮', 'レンガ', '絹'
    ]

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    tf.logging.set_verbosity(tf.logging.INFO)
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MkDir(FLAGS.output_dir)

    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        # Defines a placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        # Defines a placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
        tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Gets the list of the input style images.
        style_img_list = tf.gfile.Glob(FLAGS.style_images_paths)
        if len(style_img_list) > FLAGS.maximum_styles_to_evaluate:
            np.random.seed(1234)
            style_img_list = np.random.permutation(style_img_list)
            style_img_list = style_img_list[:FLAGS.maximum_styles_to_evaluate]

        # Gets list of input content images.
        content_img_list = tf.gfile.Glob(FLAGS.content_images_paths)

        for content_i, content_img_path in enumerate(content_img_list):
            content_img_np = image_utils.load_np_image_uint8(
                content_img_path)[:, :, :3]
            content_img_name = os.path.basename(content_img_path)[:-4]

            # Saves preprocessed content image.
            inp_img_croped_resized_np = sess.run(
                content_img_preprocessed,
                feed_dict={content_img_ph: content_img_np})
            image_utils.save_np_image(
                inp_img_croped_resized_np,
                os.path.join(FLAGS.output_dir, '%s.jpg' % (content_img_name)))

            if FLAGS.color_preserve:
                print('color preserve mode!')
                # Convert the content image from RGB to YCrCb.
                height, width, channels = inp_img_croped_resized_np[
                    0].shape[:3]
                # print(inp_img_croped_resized_np)
                # The decoder pads image dimensions up to a multiple of 4;
                # compute the horizontal and vertical padding to crop later.
                wgap = 4 - (width % 4)
                hgap = 4 - (height % 4)
                inp_img_croped_resized_np = inp_img_croped_resized_np * 255
                content_img_np_ycc = cv2.cvtColor(inp_img_croped_resized_np[0],
                                                  cv2.COLOR_RGB2YCR_CB)
                # print(content_img_np_ycc)
                zeros = np.zeros((height, width), content_img_np_ycc.dtype)
                zeros = zeros + 128  # neutral chroma in YCrCb is 128
                tmp = cv2.cvtColor(content_img_np_ycc, cv2.COLOR_YCR_CB2BGR)
                cv2.imwrite("gray.jpg", tmp)
                # print(zeros)
                Ycontent, Crcontent, Cbcontent = cv2.split(content_img_np_ycc)
                # print(Ycontent.shape, Crcontent.shape, Cbcontent.shape, zeros.shape)
                # print(Crcontent)
                # print(content_img_np_ycc)
                # content_img_np_ycc_y = cv2.merge((Y, zeros, zeros))
                # content_img_np_gry = cv2.cvtColor(content_img_np_ycc_y, cv2.COLOR_YCR_CB2RGB)
                # print(content_img_np_gry)
                # cv2.imwrite("gray.jpg", content_img_np_gry)
                # print(np.shape(content_img_np))
                # content_img_np = content_img_np_gry

            # Computes bottleneck features of the style prediction network for the
            # identity transform.

            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            for word in words:
                print(word)
                # if style_i > FLAGS.maximum_styles_to_evaluate:
                # break
                # style_img_name = os.path.basename(style_img_path)[:-4]
                # style_image_np = image_utils.load_np_image_uint8(style_img_path)[:, :, :
                # 3]

                # if style_i % 10 == 0:
                # tf.logging.info('Stylizing (%d) %s with (%d) %s' %
                # (content_i, content_img_name, style_i,
                # style_img_name))

                # Saves preprocessed style image.
                # style_img_croped_resized_np = sess.run(
                # style_img_preprocessed, feed_dict={
                # style_img_ph: style_image_np
                # })
                # image_utils.save_np_image(style_img_croped_resized_np,
                # os.path.join(FLAGS.output_dir,
                # '%s.jpg' % (style_img_name)))

                # Computes bottleneck features of the style prediction network for the
                # given style image.
                # style_params_ori = sess.run(
                # bottleneck_feat, feed_dict={style_img_ph: style_image_np})

                # print(np.shape(style_params))
                picklename = 'params/{}_{}.pickle'.format(word, content_i)
                # Pickles must be read in binary mode.
                with open(picklename, 'rb') as f:
                    style_params = pickle.load(f)
                # print(style_params)

                # print('diff of original para and made para:')
                # print(style_params_ori - style_params)

                interpolation_weights = ast.literal_eval(
                    FLAGS.interpolation_weights)
                # Interpolates between the parameters of the identity transform and
                # style parameters of the given style image.
                for interp_i, wi in enumerate(interpolation_weights):
                    stylized_image_res = sess.run(
                        stylized_images,
                        feed_dict={
                            bottleneck_feat:
                            identity_params * (1 - wi) + style_params * wi,
                            content_img_ph:
                            content_img_np
                        })

                    if FLAGS.color_preserve:
                        # print(stylized_image_res[0].shape)
                        stylized_image_res_ycc = cv2.cvtColor(
                            stylized_image_res[0], cv2.COLOR_RGB2YCR_CB)
                        Ystylized, Crstylized, Cbstylized = cv2.split(
                            stylized_image_res_ycc)
                        # Crop off the decoder padding (a gap of 4 means the
                        # original size was already a multiple of 4) and
                        # rescale to 0-255.
                        if wgap == 4:
                            Ystylized_crop = Ystylized * 255
                        else:
                            Ystylized_crop = Ystylized[:, :-wgap] * 255
                        if hgap != 4:
                            Ystylized_crop = Ystylized_crop[:-hgap, :]
                        print(Ystylized_crop.shape, Cbcontent.shape)
                        # print(wgap)
                        swapped_ycc = cv2.merge(
                            (Ystylized_crop, Crcontent, Cbcontent))
                        # print(swapped_ycc)
                        stylized_image_res = cv2.cvtColor(
                            swapped_ycc, cv2.COLOR_YCR_CB2BGR)
                        # print(stylized_image_res)
                        cv2.imwrite(
                            os.path.join(
                                FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                                (content_img_name, word, interp_i)),
                            stylized_image_res)

                    # Saves stylized image.
                    else:
                        image_utils.save_np_image(
                            stylized_image_res,
                            os.path.join(
                                FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                                (content_img_name, word, interp_i)))

    elapsed_time = time.time() - start
    print("timer stop")
    print(elapsed_time)
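
The color-preservation branch above reduces to a luminance swap in YCrCb space: keep the stylized Y channel, reuse the content image's Cr/Cb. A standalone OpenCV sketch of the same idea, assuming two equal-sized uint8 RGB arrays and omitting the padding/crop bookkeeping:

import cv2
import numpy as np

def preserve_color(content_rgb, stylized_rgb):
    # Split both images in YCrCb space.
    content_ycc = cv2.cvtColor(content_rgb, cv2.COLOR_RGB2YCR_CB)
    stylized_ycc = cv2.cvtColor(stylized_rgb, cv2.COLOR_RGB2YCR_CB)
    y_stylized, _, _ = cv2.split(stylized_ycc)
    _, cr_content, cb_content = cv2.split(content_ycc)
    # Recombine the stylized luminance with the original chroma.
    merged = cv2.merge((y_stylized, cr_content, cb_content))
    return cv2.cvtColor(merged, cv2.COLOR_YCR_CB2RGB)

content = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
stylized = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
print(preserve_color(content, stylized).shape)  # (64, 64, 3)
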
Example #4
def styleParam(style_img_list):
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    slim = tf.contrib.slim
    checkpoint_in = 'arbitrary_style_transfer/model.ckpt'
    output_dir = 'outputs_style'
    image_size = 256
    content_square_crop = False
    style_image_size = 256
    style_square_crop = False
    maximum_styles_to_evaluate = 1024
    interpolation_weights_in = '[1.0]'

    tf.logging.set_verbosity(tf.logging.INFO)
    style_param_matrix = []
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MkDir(output_dir)

    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        # Defines a placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, style_image_size)

        # Defines a placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        if tf.gfile.IsDirectory(checkpoint_in):
            checkpoint = tf.train.latest_checkpoint(checkpoint_in)
        else:
            checkpoint = checkpoint_in
        tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Gets the list of the input style images.
        # style_img_list = tf.gfile.Glob(style_images_paths)
        # if len(style_img_list) > maximum_styles_to_evaluate:
        # np.random.seed(1234)
        # style_img_list = np.random.permutation(style_img_list)
        # style_img_list = style_img_list[:maximum_styles_to_evaluate]

        count = 0
        for style_i, style_img_path in enumerate(style_img_list):
            # if style_i > maximum_styles_to_evaluate:
            # break
            style_image_np = image_utils.load_np_image_uint8(
                style_img_path)[:, :, :3]
            # Computes bottleneck features of the style prediction network for the
            # given style image.
            style_params = sess.run(bottleneck_feat,
                                    feed_dict={style_img_ph: style_image_np})
            style_param_matrix.append(style_params)
            count += 1
            # if count % 100 == 0:
            #     print(count, '/', len(style_img_list))

    # style_param_matrix holds one bottleneck vector (length 100, per
    # style_prediction_bottleneck=100) for each input style image.
    return style_param_matrix
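
styleParam returns the raw bottleneck vectors but does not persist them. A minimal sketch of writing them to disk with pickle, in binary mode so the pickle.load in Example #3 can read them back; the file names and image paths here are placeholders, not the exact params/{word}_{index}.pickle layout Example #3 expects:

import pickle

style_param_matrix = styleParam(['fabric.jpg', 'silk.jpg'])
for idx, params in enumerate(style_param_matrix):
    # Write in binary mode to match pickle.load(open(..., 'rb')).
    with open('params/style_{}.pickle'.format(idx), 'wb') as f:
        pickle.dump(params, f)
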
Example #5
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # Loads content images.
        eval_content_inputs_, _ = image_utils.imagenet_inputs(
            FLAGS.batch_size, FLAGS.image_size)

        # Process style and content weight flags.
        content_weights = ast.literal_eval(FLAGS.content_weights)
        style_weights = ast.literal_eval(FLAGS.style_weights)

        # Loads evaluation style images.
        eval_style_inputs_, _, _ = image_utils.arbitrary_style_image_inputs(
            FLAGS.eval_style_dataset_file,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            center_crop=True,
            shuffle=True,
            augment_style_images=False,
            random_style_image_size=False)

        # Computes stylized noise.
        stylized_noise, _, _, _ = build_model.build_model(
            tf.random_uniform([
                min(4, FLAGS.batch_size), FLAGS.image_size, FLAGS.image_size, 3
            ]),
            tf.slice(eval_style_inputs_, [0, 0, 0, 0],
                     [min(4, FLAGS.batch_size), -1, -1, -1]),
            trainable=False,
            is_training=False,
            reuse=None,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        # Computes stylized images.
        stylized_images, _, loss_dict, _ = build_model.build_model(
            eval_content_inputs_,
            eval_style_inputs_,
            trainable=False,
            is_training=False,
            reuse=True,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=True,
            content_weights=content_weights,
            style_weights=style_weights,
            total_variation_weight=FLAGS.total_variation_weight)

        # Adds Image summaries to the tensorboard.
        tf.summary.image(
            'image/{}/0_eval_content_inputs'.format(FLAGS.eval_name),
            eval_content_inputs_, 3)
        tf.summary.image(
            'image/{}/1_eval_style_inputs'.format(FLAGS.eval_name),
            eval_style_inputs_, 3)
        tf.summary.image(
            'image/{}/2_eval_stylized_images'.format(FLAGS.eval_name),
            stylized_images, 3)
        tf.summary.image('image/{}/3_stylized_noise'.format(FLAGS.eval_name),
                         stylized_noise, 3)

        metrics = {}
        for key, value in loss_dict.items():
            metrics[key] = tf.metrics.mean(value)

        names_values, names_updates = slim.metrics.aggregate_metric_map(
            metrics)
        for name, value in names_values.items():
            slim.summaries.add_scalar_summary(value, name, print_summary=True)
        eval_op = list(names_updates.values())
        num_evals = FLAGS.num_evaluation_styles // FLAGS.batch_size

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_dir,
            eval_op=eval_op,
            num_evals=num_evals,
            eval_interval_secs=FLAGS.eval_interval_secs)
Example #6
def main():
    tf.reset_default_graph()
    BATCH_SIZE = 1
    with tf.Graph().as_default(), tf.Session() as sess:
        # train.MonitoredTraining
        # Loads content images.

        # TODO: load test images from a separate test path; a small set is
        # enough. Use the same content/style images as the GAN test set.

        # TODO check the path here.
        test_content_path = '../testset/content/'
        test_style_path = '../testset/style/'

        # TODO Fix paths?
        content_path = '/home/noah/magenta/data/coco0/'  #magenta/data/coco0'
        style_path = '/home/noah/magenta/data/painter/painter0/'
        content_inputs_ = data_utils_all.load_random_images(
            content_path, batch_size=BATCH_SIZE)
        style_inputs_ = data_utils_all.load_random_images(
            style_path, batch_size=BATCH_SIZE)

        # akhil_dummy_path = '/Users/akhiljalan/Documents/trainingdata_stylized_500/stylized500/'

        # todo insert content_path, style_path...
        # content_inputs_ = data_utils_all.load_random_images(akhil_dummy_path, batch_size=BATCH_SIZE)

        # Loads evaluation style images.
        # style_inputs_ = data_utils_all.load_random_images(akhil_dummy_path, batch_size=BATCH_SIZE)

        # Default style, content, and variation weights from magenta.

        content_weights = {"vgg_16/conv3": 1}
        style_weights = {
            "vgg_16/conv1": 0.5e-3,
            "vgg_16/conv2": 0.5e-3,
            "vgg_16/conv3": 0.5e-3,
            "vgg_16/conv4": 0.5e-3
        }
        total_variation_weight = 1e4

        stylized_images, total_loss_pass_1, loss_dict_pass_1, _ = build_model.build_model(
            content_inputs_,
            style_inputs_,
            reuse=tf.AUTO_REUSE,
            trainable=True,
            is_training=True,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=True,
            content_weights=content_weights,
            style_weights=style_weights,
            total_variation_weight=total_variation_weight)

        unstylized_images, total_loss_pass_2, loss_dict_pass_2 = None, 0.0, {}

        # unstylized_images, total_loss_pass_2, loss_dict_pass_2, _ = build_model.build_model(
        # 	stylized_images, #stylized as content
        # 	content_inputs_, #original content as style
        # 	reuse=True,
        # 	trainable=True,
        # 	is_training=True,
        # 	inception_end_point='Mixed_6e',
        # 	style_prediction_bottleneck=100,
        # 	adds_losses=True,
        # 	content_weights=content_weights,
        # 	style_weights=style_weights,
        # 	total_variation_weight=total_variation_weight)

        # Log all losses to TensorBoard.
        loss_dict = {}
        loss_dict.update(loss_dict_pass_1)
        loss_dict.update(loss_dict_pass_2)
        for key, value in loss_dict.items():
            tf.summary.scalar(key, value)

        # Log images to tensorboard
        tf.summary.image('image/0_content_inputs', content_inputs_, 3)
        # tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_, 3)
        tf.summary.image('image/1_style_inputs', style_inputs_, 3)
        tf.summary.image('image/2_stylized_images', stylized_images, 3)
        # tf.summary.image('image/3_unstylized_images', unstylized_images, 3)

        # discrim_predictions = discriminator_network(content_inputs_)

        # Generate label tensors on the fly.

        real_labels = data_utils_all.gen_labels(is_real=True,
                                                batch_size=BATCH_SIZE)
        fake_labels = data_utils_all.gen_labels(is_real=False,
                                                batch_size=BATCH_SIZE)
        # discrim_loss = slim.losses.softmax_cross_entropy(discrim_predictions, real_labels)
        # gen_fooling_loss = slim.losses.softmax_cross_entropy(discrim_predictions, fake_labels)

        gen_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2)
        gen_train_op = slim.learning.create_train_op(
            total_loss_pass_1,  #+ total_variation_weight * gen_fooling_loss, # + total_loss_pass_2? 
            gen_optimizer,
            summarize_gradients=False)

        # discr_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2)
        # discr_train_op = slim.learning.create_train_op(
        # 	discrim_loss, # + total_loss_pass_2?
        # 	discr_optimizer,
        # 	summarize_gradients=True)

        # todo merge train ops

        # Get checkpoint files.
        # See above for the inception, vgg checkpoints.

        # TODO change checkpoint path...
        gen_checkpoint = '/home/noah/arbitrary_style_transfer/model.ckpt'
        # gen_checkpoint = '../../magenta/arbitrary_style_transfer/model.ckpt'
        # discrim_checkpoint = './logdir/model.ckpt-85'

        model_vars = slim.get_variables_to_restore()
        # No saved model yet!
        # discrim_var_names = [var for var in model_vars if 'discriminator' in var.name]
        # gen_var_names = [var for var in model_vars if 'beta' not in var.name]
        # gen_var_names = [var for var in model_vars if var.shape == ()]
        gen_var_names = model_vars
        # print(gen_var_names)

        # gen_assign_op, gen_feed_dict = slim.assign_from_checkpoint(gen_checkpoint, gen_var_names) #TODO change this...
        # discrim_assign_op, discrim_feed_dict = slim.assign_from_checkpoint(discrim_checkpoint,
        #                                            discrim_var_names)

        init_fn = slim.assign_from_checkpoint_fn(gen_checkpoint, gen_var_names)

        def init_assign_func(sess):
            # sess.run(gen_assign_op, gen_feed_dict)
            init_fn(sess)

        slim.learning.train(
            train_op=gen_train_op,  # TODO: replace with a merged train op
            logdir='./logdir01/',
            number_of_steps=2,
            save_summaries_secs=1,
            save_interval_secs=1,
            init_fn=init_assign_func)
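
Examples #2, #6, and #7 all hand one scalar loss to the same TF-Slim harness: create_train_op pairs the loss with an optimizer and a global step, and slim.learning.train owns the session loop, checkpointing, and summaries. A toy TF 1.x sketch of that contract, independent of the style-transfer model (the logdir path is arbitrary):

import tensorflow as tf

slim = tf.contrib.slim

# Toy quadratic objective, minimized at x == 3.
x = tf.get_variable('x', shape=[], initializer=tf.zeros_initializer())
loss = tf.square(x - 3.0)

optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
# create_train_op wires the loss to the optimizer and a global step.
train_op = slim.learning.create_train_op(loss, optimizer)

# slim.learning.train runs the session, steps the optimizer, and
# checkpoints into logdir.
slim.learning.train(train_op=train_op,
                    logdir='/tmp/toy_logdir',
                    number_of_steps=100)
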
Example #7
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            # Loads content images.
            # content_inputs_, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
            #                                                  FLAGS.image_size)
            # Changed: load content image pairs from FLAGS.content_dir instead.
            content_inputs_, _, temp = load_xy_pairs(
                fake_directory=FLAGS.content_dir,
                real_directory=FLAGS.content_dir,
                batch_size=FLAGS.batch_size,
                prob_of_real=1)

            # Loads style images.
            # [style_inputs_, _,
            #  style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
            #      FLAGS.style_dataset_file,
            #      batch_size=FLAGS.batch_size,
            #      image_size=FLAGS.image_size,
            #      shuffle=True,
            #      center_crop=FLAGS.center_crop,
            #      augment_style_images=FLAGS.augment_style_images,
            #      random_style_image_size=FLAGS.random_style_image_size)
            style_inputs_, _, temp2 = load_xy_pairs(
                fake_directory=FLAGS.style_dir,
                real_directory=FLAGS.style_dir,
                batch_size=FLAGS.batch_size,
                prob_of_real=1)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model
            stylized_images, total_loss, loss_dict, _ = build_model.build_model(
                content_inputs_,
                style_inputs_,
                trainable=True,
                is_training=True,
                inception_end_point='Mixed_6e',
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight)

            if tf.gfile.IsDirectory(FLAGS.checkpoint):
                checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
            else:
                checkpoint = FLAGS.checkpoint
            tf.logging.info(
                'loading latest checkpoint file: {}'.format(checkpoint))

            init_fn = slim.assign_from_checkpoint_fn(
                checkpoint, slim.get_variables_to_restore())
            # sess.run([tf.local_variables_initializer()])
            # init_fn(sess)

            # Add scalar summaries to TensorBoard.
            for key, value in loss_dict.items():
                tf.summary.scalar(key, value)

            # Add image summaries to TensorBoard.
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            # tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_, 3)
            tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/3_stylized_images', stylized_images, 3)

            # Set up training
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Function to restore VGG16 parameters.
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            # Function to restore Inception_v3 parameters.
            inception_variables_dict = {
                var.op.name: var
                for var in slim.get_model_variables('InceptionV3')
            }
            init_fn_inception = slim.assign_from_checkpoint_fn(
                FLAGS.inception_v3_checkpoint, inception_variables_dict)

            # Function to restore VGG16 and Inception_v3 parameters.
            def init_sub_networks(session):
                # init_fn_vgg(session)
                # init_fn_inception(session)
                # session.run(tf.local_variables_initializer())
                init_fn(session)

            config = tf.ConfigProto()
            # Optional GPU memory settings:
            # config.gpu_options.allow_growth = True
            # config.gpu_options.allocator_type = 'BFC'
            # config.gpu_options.per_process_gpu_memory_fraction = 0.98

            # Run training
            slim.learning.train(session_config=config,
                                train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_sub_networks,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
Example #8
def run(content_images_paths, style_images_paths, style_square_crop=False,
        style_image_size=400, content_square_crop=False, image_size=400,
        checkpoint='arbitrary_style_transfer/model.ckpt',
        maximum_styles_to_evaluate=1024, interpolation_weights='[1.0]'):
  # Read the AWS access keys from the second row of the CSV.
  keys = []
  with open('docs/accessKeys.csv', 'r') as f:
    reader = csv.reader(f, delimiter=',')
    for i, row in enumerate(reader):
      if i == 1:
        keys = row
  s3 = boto3.client('s3',
                    aws_access_key_id=keys[0],
                    aws_secret_access_key=keys[1])
  start = time.time()
  # Parse the interpolation weights once; re-evaluating the string inside
  # the style loop below would fail on a second pass.
  interpolation_weights = ast.literal_eval(interpolation_weights)
  front = content_images_paths.split('.')[0]
  # Download from S3: bucket name, key within the bucket, local file name.
  s3.download_file(config.bucketname, 'target1/' + content_images_paths,
                   content_images_paths)
  convertimage(content_images_paths, False)
  content = 'c_' + front + '.jpg'

  s3.download_file(config.bucketname, 'style1/'+style_images_paths, style_images_paths)
  convertimage(style_images_paths, True)


  style = 's_'+front+'.jpg'
  # Start evaluation.
  tf.logging.set_verbosity(tf.logging.INFO)
  # if not tf.gfile.Exists(output_dir):
  #   tf.gfile.MkDir(output_dir)

  with tf.Graph().as_default(), tf.Session() as sess:
    # Defines a placeholder for the style image.
    style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if style_square_crop:
      style_img_preprocessed = image_utils.center_crop_resize_image(
          style_img_ph, style_image_size)
    else:
      style_img_preprocessed = image_utils.resize_image(style_img_ph,
                                                        style_image_size)

    # Defines a placeholder for the content image.
    content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if content_square_crop:
      content_img_preprocessed = image_utils.center_crop_resize_image(
          content_img_ph, image_size)
    else:
      content_img_preprocessed = image_utils.resize_image(
          content_img_ph, image_size)

    # Defines the model.
    stylized_images, _, _, bottleneck_feat = build_model.build_model(
        content_img_preprocessed,
        style_img_preprocessed,
        trainable=False,
        is_training=False,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=False)

    if tf.gfile.IsDirectory(checkpoint):
      checkpoint = tf.train.latest_checkpoint(checkpoint)
    tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))

    init_fn = slim.assign_from_checkpoint_fn(checkpoint,
                                             slim.get_variables_to_restore())
    sess.run([tf.local_variables_initializer()])
    init_fn(sess)

    # Gets the list of the input style images.
    style_img_list = tf.gfile.Glob(style)
    if len(style_img_list) > maximum_styles_to_evaluate:
      np.random.seed(1234)
      style_img_list = np.random.permutation(style_img_list)
      style_img_list = style_img_list[:maximum_styles_to_evaluate]

    # Gets list of input content images.
    content_img_list = tf.gfile.Glob(content)

    for content_i, content_img_path in enumerate(content_img_list):
      content_img_np = image_utils.load_np_image_uint8(content_img_path)[:, :, :3]
      content_img_name = os.path.basename(content_img_path)[:-4]

      # Saves preprocessed content image.
      inp_img_croped_resized_np = sess.run(
          content_img_preprocessed, feed_dict={
              content_img_ph: content_img_np
          })
      # image_utils.save_np_image(inp_img_croped_resized_np,
      #                           os.path.join(output_dir,
      #                                        '%s.jpg' % (content_img_name)))

      # Computes bottleneck features of the style prediction network for the
      # identity transform.
      identity_params = sess.run(
          bottleneck_feat, feed_dict={style_img_ph: content_img_np})

      for style_i, style_img_path in enumerate(style_img_list):
        if style_i > maximum_styles_to_evaluate:
          break
        style_img_name = os.path.basename(style_img_path)[:-4]
        style_image_np = image_utils.load_np_image_uint8(style_img_path)[:, :, :3]

        if style_i % 10 == 0:
          tf.logging.info('Stylizing (%d) %s with (%d) %s' %
                          (content_i, content_img_name, style_i,
                           style_img_name))

        # Saves preprocessed style image.
        style_img_croped_resized_np = sess.run(
            style_img_preprocessed, feed_dict={
                style_img_ph: style_image_np
            })
        # image_utils.save_np_image(style_img_croped_resized_np,
        #                           os.path.join(output_dir,
        #                                        '%s.jpg' % (style_img_name)))

        # Computes bottleneck features of the style prediction network for the
        # given style image.
        style_params = sess.run(
            bottleneck_feat, feed_dict={style_img_ph: style_image_np})

        # Interpolates between the parameters of the identity transform and
        # style parameters of the given style image (weights parsed above).
        for interp_i, wi in enumerate(interpolation_weights):
          stylized_image_res = sess.run(
              stylized_images,
              feed_dict={
                  bottleneck_feat:
                      identity_params * (1 - wi) + style_params * wi,
                  content_img_ph:
                      content_img_np
              })

          # Saves stylized image.
          image_utils.save_np_image(
              stylized_image_res,
              front+'.jpg')
  upload_file(front+'.jpg','styletransferimage','output1/'+front+'.jpg')
  print("finished upload")
  os.remove('c_'+front+'.jpg')
  os.remove('s_'+front+'.jpg')
  os.remove(front+'.jpg')
  status = False
  retryCounter = 2
  while retryCounter > 0:
    try:
      channel = grpc.insecure_channel('54.164.44.43:50051')
      stub = uid_management_pb2_grpc.UidManagementStub(channel)
      status = stub.TransferCompleted(uid_management_pb2.Id(id=front))
      retryCounter = 0
    except grpc.RpcError:
      # Retry on transport errors instead of swallowing every exception.
      retryCounter -= 1
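
The retry loop above retries immediately and without backoff. A sketch of a slightly more defensive variant (the uid_management_pb2 / uid_management_pb2_grpc modules are the example's generated gRPC bindings and are assumed importable; the address and attempt count are illustrative):

import time

import grpc

def notify_completed(uid, address='54.164.44.43:50051', attempts=3):
    delay = 1.0
    for attempt in range(attempts):
        try:
            channel = grpc.insecure_channel(address)
            stub = uid_management_pb2_grpc.UidManagementStub(channel)
            return stub.TransferCompleted(uid_management_pb2.Id(id=uid))
        except grpc.RpcError:
            if attempt == attempts - 1:
                raise  # give up after the final attempt
            time.sleep(delay)
            delay *= 2  # exponential backoff between retries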