def build_model(self):
        """Build the style-transfer training graph and initialize its variables.

        Runs the transform network (``model.gen_net``) on a batch of training
        images. It then feeds the generated images, the input images and the
        style images through VGG-16 in a single pass. From those features it
        defines a content loss and a style loss, plus an Adam training op
        that only updates the transform network's variables.

        Reads ``self.args.train_data_path``, ``self.args.style_data_path``,
        ``self.args.vgg_model``, ``self.args.style_w``, ``self.batch_size``,
        ``self.img_size`` and ``self.sess``.

        Side effects:
            * restores VGG-16 weights from ``self.args.vgg_model`` into
              ``self.sess`` and prints a message when done;
            * initializes every global variable whose name does not contain
              ``'vgg_16'``;
            * sets ``self.gen_img``, ``self.global_step``,
              ``self.content_loss``, ``self.style_loss``, ``self.loss``,
              ``self.opt`` and ``self.save``.
        """
        imgs = load_data.get_loader(self.args.train_data_path,
                                    self.batch_size, self.img_size)

        style_imgs = load_style_img(self.args.style_data_path)
        # VGG features are later split into three equal parts along the batch
        # axis, so the style batch must have exactly self.batch_size rows.
        # load_style_img returns copies of a single image, so tiling the first
        # copy gives a correctly sized batch for any batch size.
        style_imgs = tf.tile(style_imgs[:1], [self.batch_size, 1, 1, 1])
        # NOTE(review): load_style_img resizes to 256x256 by default. The
        # concat below also needs that to match the spatial size of ``imgs``
        # -- confirm self.img_size corresponds to 256x256.

        with slim.arg_scope(model.arg_scope()):
            gen_img, variables = model.gen_net(imgs,
                                               reuse=False,
                                               name='transform')

            with slim.arg_scope(vgg.vgg_arg_scope()):
                # Preprocess each generated image the same way the style image
                # is preprocessed (img_process(..., True), presumably mean
                # subtraction), then restack into one batch tensor.
                gen_img_processed = tf.stack([
                    load_data.img_process(image, True)
                    for image in tf.unstack(gen_img, axis=0,
                                            num=self.batch_size)
                ])

                # A single VGG pass over [generated; content; style], each part
                # self.batch_size rows, so tf.split(..., 3, 0) can separate them.
                f1, f2, f3, f4, exclude = vgg.vgg_16(
                    tf.concat([gen_img_processed, imgs, style_imgs], axis=0))

                # Content loss: tf.nn.l2_loss (half the summed squared
                # difference) of the f3 features, normalized by element count.
                gen_f, img_f, _ = tf.split(f3, 3, 0)
                content_loss = tf.nn.l2_loss(gen_f - img_f) / tf.to_float(
                    tf.size(gen_f))

                style_loss = model.styleloss(f1, f2, f3, f4)

                # Restore pretrained VGG-16 weights (except those in
                # ``exclude``) into the session.
                vgg_vars = slim.get_variables_to_restore(include=['vgg_16'],
                                                         exclude=exclude)
                init_fn = slim.assign_from_checkpoint_fn(
                    self.args.vgg_model, vgg_vars)
                init_fn(self.sess)
                print('vgg s weights load done')

            self.gen_img = gen_img

            self.global_step = tf.Variable(0,
                                           name="global_step",
                                           trainable=False)

            self.content_loss = content_loss
            self.style_loss = style_loss * self.args.style_w
            self.loss = self.content_loss + self.style_loss
            # var_list restricts updates to the variables returned by gen_net,
            # so VGG weights are not trained.
            self.opt = tf.train.AdamOptimizer(0.0001).minimize(
                self.loss, global_step=self.global_step, var_list=variables)

        # Initialize every non-VGG global variable (e.g. the transform net,
        # global_step, optimizer slots). VGG weights were restored above.
        # NOTE(review): a vgg_16 variable listed in ``exclude`` would be
        # neither restored nor initialized here -- confirm none exist.
        init_var = [v for v in tf.global_variables()
                    if 'vgg_16' not in v.name]
        self.sess.run(tf.variables_initializer(var_list=init_var))

        # Checkpoints hold only the transform network's variables.
        self.save = tf.train.Saver(var_list=variables)
def load_test_img(img_path):
    """Load one image at its native resolution as a batch of size one.

    Args:
        img_path: path to the image file; decoded with
            ``tf.image.decode_jpeg`` as 3 channels.

    Returns:
        The preprocessed image (``load_data.img_process(..., True)``) with a
        leading batch axis of size 1 added. Presumably shaped
        ``[1, height, width, 3]`` -- confirm img_process preserves shape.
    """
    img = tf.image.decode_jpeg(tf.read_file(img_path), 3)
    # decode_jpeg yields uint8. Convert to float32 explicitly instead of
    # "resizing" the image to its own size, which changes no pixel values and
    # only served as a dtype conversion.
    img = tf.cast(img, tf.float32)
    img = load_data.img_process(img, True)
    return tf.expand_dims(img, 0)
def load_style_img(styleImgPath, batch_size=4, size=(256, 256)):
    """Load one style image, preprocess it, and repeat it into a batch.

    Args:
        styleImgPath: path to the style image file; decoded with
            ``tf.image.decode_jpeg`` as 3 channels.
        batch_size: number of copies of the image in the returned batch.
            Defaults to 4 (the previously hard-coded value).
        size: ``(height, width)`` the image is resized to before
            preprocessing. Defaults to ``(256, 256)`` (the previously
            hard-coded value).

    Returns:
        The preprocessed image repeated ``batch_size`` times along a new
        leading axis. Presumably shaped ``[batch_size, height, width, 3]``
        -- confirm ``load_data.img_process`` preserves shape.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

    raw = tf.read_file(styleImgPath)
    style_img = tf.image.decode_jpeg(raw, 3)
    style_img = tf.image.resize_images(style_img, list(size))
    style_img = load_data.img_process(style_img, True)  # True to subtract means

    images = tf.expand_dims(style_img, 0)
    # Repeat the single image along the batch axis.
    return tf.tile(images, [batch_size, 1, 1, 1])