コード例 #1
0
    def build_model(self):
        """Assemble the training graph and prepare the session for training.

        The method performs these steps, in order:

        1. Obtains the training batch via ``load_data.get_loader`` and the
           style images via ``load_style_img``.
        2. Builds the generator with ``model.gen_net`` under the name
           ``'transform'``.
        3. Runs the processed generator output, the training batch and the
           style images through ``vgg.vgg_16`` as one concatenated batch.
        4. Defines the content loss as ``tf.nn.l2_loss`` of the difference
           between the first two thirds of the third VGG feature output,
           divided by the element count, and the style loss via
           ``model.styleloss``.
        5. Restores the ``vgg_16`` variables from ``self.args.vgg_model``.
        6. Creates an Adam optimizer (learning rate 0.0001) that updates only
           the variables returned by ``model.gen_net``.
        7. Initializes every global variable whose name does not contain
           ``'vgg_16'`` and creates a saver for the generator variables.

        Sets ``self.gen_img``, ``self.global_step``, ``self.content_loss``,
        ``self.style_loss``, ``self.loss``, ``self.opt`` and ``self.save``.
        """
        content_batch = load_data.get_loader(self.args.train_data_path,
                                             self.batch_size, self.img_size)
        style_batch = load_style_img(self.args.style_data_path)

        with slim.arg_scope(model.arg_scope()):
            generated, transform_vars = model.gen_net(content_batch,
                                                      reuse=False,
                                                      name='transform')

            with slim.arg_scope(vgg.vgg_arg_scope()):
                # Apply load_data.img_process to each generated image
                # individually.
                per_image = tf.unstack(generated,
                                       axis=0,
                                       num=self.batch_size)
                generated_for_vgg = [
                    load_data.img_process(single, True)
                    for single in per_image
                ]

                # NOTE(review): the first element is a Python list of tensors;
                # this appears to rely on TensorFlow converting it into a
                # single tensor before concatenation -- confirm.
                vgg_input = tf.concat(
                    [generated_for_vgg, content_batch, style_batch], axis=0)
                feat1, feat2, feat3, feat4, vgg_exclude = vgg.vgg_16(vgg_input)

                # The concatenation order above puts generated-image features
                # first and training-image features second; the last third is
                # not used for the content loss.
                # NOTE(review): an even 3-way split presumably requires the
                # style batch to have the same size as the other two --
                # confirm.
                generated_feat, content_feat, _ = tf.split(feat3, 3, 0)
                squared_error = tf.nn.l2_loss(generated_feat - content_feat)
                feat_count = tf.to_float(tf.size(generated_feat))
                content_term = squared_error / feat_count

                style_term = model.styleloss(feat1, feat2, feat3, feat4)

                # Restore the VGG weights from the checkpoint, skipping the
                # variables that vgg_16 reported in its last return value.
                restore_vgg = slim.assign_from_checkpoint_fn(
                    self.args.vgg_model,
                    slim.get_variables_to_restore(include=['vgg_16'],
                                                  exclude=vgg_exclude))
                restore_vgg(self.sess)
                print('vgg s weights load done')

            self.gen_img = generated

            self.global_step = tf.Variable(0,
                                           trainable=False,
                                           name="global_step")

            self.content_loss = content_term
            self.style_loss = style_term * self.args.style_w
            self.loss = self.content_loss + self.style_loss

            # Only the generator variables are optimized.
            optimizer = tf.train.AdamOptimizer(0.0001)
            self.opt = optimizer.minimize(self.loss,
                                          var_list=transform_vars,
                                          global_step=self.global_step)

        # The VGG variables were already assigned from the checkpoint above,
        # so they are left out of the initializer.
        non_vgg_vars = [
            var for var in tf.global_variables() if 'vgg_16' not in var.name
        ]
        self.sess.run(tf.variables_initializer(var_list=non_vgg_vars))

        self.save = tf.train.Saver(var_list=transform_vars)
コード例 #2
0
    def test(self):
        """Run the saved transform network on the test image and save the result.

        Loads the image at ``self.args.test_data_path``, builds the generator
        under the name ``'transform'``, restores its variables from the
        checkpoint at ``self.args.transfer_model``, evaluates the generator in
        ``self.sess`` and passes the result to ``save_img.save_images``
        together with ``self.args.new_img_name``.
        """
        print('test model')
        test_img = load_test_img(self.args.test_data_path)
        # NOTE(review): load_test_img presumably returns a tensor; it is
        # evaluated here so the network is built on the resulting value --
        # confirm.
        test_img = self.sess.run(test_img)

        with slim.arg_scope(model.arg_scope()):
            gen_img, _ = model.gen_net(test_img, reuse=False, name='transform')

            # Restore the generator ('transform') variables from the
            # checkpoint.
            transform_vars = slim.get_variables_to_restore(
                include=['transform'])
            init_fn = slim.assign_from_checkpoint_fn(
                self.args.transfer_model, transform_vars)
            init_fn(self.sess)
            # This step loads the transfer model, not VGG, so the message
            # names the transfer model.
            print('transfer model weights load done')

            gen_img = self.sess.run(gen_img)
            save_img.save_images(gen_img, self.args.new_img_name)