def inference(from_file_path, args):
    """Build the AdaIN style-transfer graph, load its weights, and run it.

    A fresh ``tf.Graph`` is made the default and a ``tf.Session`` is opened
    for the duration of this call; the session is closed when the function
    returns.

    Args:
        from_file_path: When truthy, the restored model is handed to
            ``run_from_file_paths``; otherwise it is handed to
            ``run_from_layers``.
        args: Indexable sequence whose first element is used as the AdaIN
            ``alpha``. The whole sequence is also forwarded to the runner.

    Returns:
        The value returned by ``run_from_layers`` when ``from_file_path`` is
        falsy; ``None`` otherwise (the ``run_from_file_paths`` result is not
        propagated).
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        alpha = args[0]

        encoder = Encoder()
        decoder = Decoder()

        # Single-image batches with free height/width and 3 channels.
        image_shape = (1, None, None, 3)
        content_input = tf.placeholder(tf.float32, shape=image_shape,
                                       name='content_input')
        style_input = tf.placeholder(tf.float32, shape=image_shape,
                                     name='style_input')

        # Reverse the channel axis (RGB -> BGR), then apply the encoder's
        # preprocessing. Each stage creates the content op before the style
        # op, keeping graph construction order unchanged.
        bgr_content, bgr_style = [tf.reverse(img, axis=[-1])
                                  for img in (content_input, style_input)]
        content, style = [encoder.preprocess(img)
                          for img in (bgr_content, bgr_style)]

        # Encode each image under its own name prefix.
        enc_c_net, enc_s_net = [encoder.encode(img, prefix)
                                for img, prefix in ((content, 'content/'),
                                                    (style, 'style/'))]

        # Combine the two encodings with AdaIN, weighted by alpha.
        target_features = AdaIN(enc_c_net.outputs,
                                enc_s_net.outputs,
                                alpha=alpha)

        dec_net = decoder.decode(target_features, prefix="decoder/")

        # deprocess -> BGR back to RGB -> clamp into [0, 255]
        generated_img = tf.clip_by_value(
            tf.reverse(encoder.deprocess(dec_net.outputs), axis=[-1]),
            0.0, 255.0)

        # Global variables must be initialised before restore_model is called.
        sess.run(tf.global_variables_initializer())

        for enc_net in (enc_c_net, enc_s_net):
            encoder.restore_model(sess, ENCODER_PATH, enc_net)
        decoder.restore_model(sess, DECODER_PATH, dec_net)

        model_args = (sess, generated_img, content_input, style_input)
        if not from_file_path:
            return run_from_layers(model_args, args)
        run_from_file_paths(model_args, args)
Example #2
0
    def init_session_handler(self):
        """Open ``self.sess`` and build the AdaIN style-transfer graph.

        Sets ``self.sess``, ``self.content_input``, ``self.style_input`` and
        ``self.generated_img``. It then initialises global variables and
        restores encoder/decoder weights from ``self.encode_path`` and
        ``self.decode_path``. No explicit graph is entered, so all ops are
        added to the current default graph.
        """
        self.sess = tf.Session()

        encoder = Encoder()
        decoder = Decoder()

        # Single-image batches with free height/width and 3 channels.
        image_shape = (1, None, None, 3)
        self.content_input = tf.placeholder(tf.float32, shape=image_shape,
                                            name='content_input')
        self.style_input = tf.placeholder(tf.float32, shape=image_shape,
                                          name='style_input')

        # RGB -> BGR on the channel axis, then encoder preprocessing.
        # Content ops are created before style ops at each stage.
        bgr_images = [tf.reverse(img, axis=[-1])
                      for img in (self.content_input, self.style_input)]
        content, style = [encoder.preprocess(img) for img in bgr_images]

        enc_c_net = encoder.encode(content, 'content/')
        enc_s_net = encoder.encode(style, 'style/')

        # NOTE(review): `alpha` is not defined in this method, so it must
        # resolve from an enclosing or module scope. Confirm this is
        # intended (an instance attribute may have been meant).
        target_features = transfer_util.AdaIN(enc_c_net.outputs,
                                              enc_s_net.outputs,
                                              alpha=alpha)

        dec_net = decoder.decode(target_features, prefix="decoder/")

        # deprocess -> BGR back to RGB -> clamp into [0, 255]
        decoded = encoder.deprocess(dec_net.outputs)
        rgb = tf.reverse(decoded, axis=[-1])
        self.generated_img = tf.clip_by_value(rgb, 0.0, 255.0)

        # Global variables are initialised before the weights are restored.
        self.sess.run(tf.global_variables_initializer())

        for enc_net in (enc_c_net, enc_s_net):
            encoder.restore_model(self.sess, self.encode_path, enc_net)
        decoder.restore_model(self.sess, self.decode_path, dec_net)
Example #3
0
            style_layer_loss.append(l2_mean + l2_sigma)

        style_loss = tf.reduce_sum(style_layer_loss)

        style_weight = 2.0

        # compute the total loss
        loss = content_loss + style_weight * style_loss

        # Training step (Only train the decoder params)
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(
            loss, var_list=stylied_dec_net.all_params)

        sess.run(tf.global_variables_initializer())

        encoder.restore_model(sess, ENCODER_PATH, content_enc_net)
        encoder.restore_model(sess, ENCODER_PATH, style_enc_net)
        encoder.restore_model(sess, ENCODER_PATH, stylied_enc_net)

        # """Start Training"""
        step = 0
        n_batches = int(num_imgs // BATCH_SIZE)

        elapsed_time = datetime.now() - start_time
        print(
            '\nElapsed time for preprocessing before actually train the model: %s'
            % elapsed_time)

        print('Now begin to train the model...\n')
        start_time = datetime.now()
Example #4
0
        dec_net = decoder.decode(target_features, prefix="decoder/")

        generated_img = dec_net.outputs

        # deprocess image
        generated_img = encoder.deprocess(generated_img)

        # switch BGR back to RGB
        generated_img = tf.reverse(generated_img, axis=[-1])

        # clip to 0..255
        generated_img = tf.clip_by_value(generated_img, 0.0, 255.0)

        sess.run(tf.global_variables_initializer())

        encoder.restore_model(sess, ENCODER_PATH, enc_c_net)
        encoder.restore_model(sess, ENCODER_PATH, enc_s_net)
        decoder.restore_model(sess, DECODER_PATH, dec_net)

        start_time = datetime.now()
        image_count = 0
        for s in style_images:
            for c in content_images:
                image_count = image_count + 1
                # Load image from path and add one extra diamension to it.
                content_image = imread(os.path.join(content_path, c),
                                       mode='RGB')
                style_image = imread(os.path.join(style_path, s), mode='RGB')

                content_tensor = np.expand_dims(content_image, axis=0)
                style_tensor = np.expand_dims(style_image, axis=0)