def build_model(self):
        data_path = self.args.train_data_path

        imgs = load_data.get_loader(data_path, self.batch_size, self.img_size)

        style_imgs = load_style_img(self.args.style_data_path)

        with slim.arg_scope(model.arg_scope()):
            gen_img, variables = model.gen_net(imgs,
                                               reuse=False,
                                               name='transform')

            with slim.arg_scope(vgg.vgg_arg_scope()):
                gen_img_processed = [
                    load_data.img_process(image, True) for image in tf.unstack(
                        gen_img, axis=0, num=self.batch_size)
                ]

                f1, f2, f3, f4, exclude = vgg.vgg_16(
                    tf.concat([gen_img_processed, imgs, style_imgs], axis=0))

                gen_f, img_f, _ = tf.split(f3, 3, 0)
                content_loss = tf.nn.l2_loss(gen_f - img_f) / tf.to_float(
                    tf.size(gen_f))

                style_loss = model.styleloss(f1, f2, f3, f4)

                # load vgg model
                vgg_model_path = self.args.vgg_model
                vgg_vars = slim.get_variables_to_restore(include=['vgg_16'],
                                                         exclude=exclude)
                # vgg_init_var = slim.get_variables_to_restore(include=['vgg_16/fc6'])
                init_fn = slim.assign_from_checkpoint_fn(
                    vgg_model_path, vgg_vars)
                init_fn(self.sess)
                # tf.initialize_variables(var_list=vgg_init_var)
                print("vgg's weights loaded")

            self.gen_img = gen_img

            self.global_step = tf.Variable(0,
                                           name="global_step",
                                           trainable=False)

            self.content_loss = content_loss
            self.style_loss = style_loss * self.args.style_w
            self.loss = self.content_loss + self.style_loss
            self.opt = tf.train.AdamOptimizer(0.0001).minimize(
                self.loss, global_step=self.global_step, var_list=variables)

        all_var = tf.global_variables()
        # init_var = [v for v in all_var if 'beta' in v.name or 'global_step' in v.name or 'Adam' in v.name]
        init_var = [v for v in all_var if 'vgg_16' not in v.name]
        init = tf.variables_initializer(var_list=init_var)
        self.sess.run(init)

        self.save = tf.train.Saver(var_list=variables)
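model.styleloss is only called above, not shown. A minimal sketch of the standard Gram-matrix style loss it presumably implements, written against the same TF1/slim API; the helper name gram and the normalization are assumptions, not the author's code:

def gram(layer):
    # Flatten spatial dims, keep channels: (batch, H*W, C).
    shape = tf.shape(layer)
    feats = tf.reshape(layer, [shape[0], -1, shape[3]])
    # Channel-by-channel inner products, normalized by layer size.
    return tf.matmul(feats, feats, transpose_a=True) / tf.to_float(tf.size(layer))

def styleloss(f1, f2, f3, f4):
    loss = 0.
    for f in (f1, f2, f3, f4):
        # Batch order matches the tf.concat([gen, content, style]) above.
        gen_f, _, style_f = tf.split(f, 3, 0)
        loss += tf.nn.l2_loss(gram(gen_f) - gram(style_f)) / tf.to_float(tf.size(gen_f))
    return loss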
Example #2
    def test(self):
        print('test model')
        test_img = tools.load_test_img(TEST_IMAGE_PATH)
        test_img = self.sess.run(test_img)
        with slim.arg_scope(model.arg_scope()):
            gen_img, _ = model.inference(test_img, reuse=False, name='transform')

            vars = slim.get_variables_to_restore(include=['transform'])
            init_fn = slim.assign_from_checkpoint_fn('tmp/model/model.ckpt-50', vars)
            init_fn(self.sess)

            gen_img = self.sess.run(gen_img)
            tools.save_images(gen_img, 'images/test.jpg')
Example #3
    def _build_model(self):
        config, params = self.config, self.params

        is_training = params["is_training"]
        learning_rate = params["learning_rate"]
        global_step = params["global_step"]

        artifact_im = params["artifact_im"]
        reference_im = params["reference_im"]

        # TODO: more elegant way? (e.g. a factory pattern; see the dispatch-table sketch after this function)
        if config.model == "base":
            model_fn = model.base
        elif config.model == "residual":
            model_fn = model.residual
        elif config.model == "base-skip":
            model_fn = model.base_skip
        elif config.model == "residual-skip":
            model_fn = model.residual_skip
        else:
            raise NotImplementedError("There is no such {} model".format(
                config.model))

        with slim.arg_scope(model.arg_scope(is_training)):
            G_dn, G_residual, end_pts = model_fn(
                artifact_im,
                scope="generator",
                num_channels=config.num_channels)

        # Count the trainable parameters (requires `from functools import reduce`
        # and `from operator import mul` at the top of the file).
        num_params = 0
        for var in tf.trainable_variables():
            print(var.name, var.get_shape())
            num_params += reduce(mul, var.get_shape().as_list(), 1)

        print("Total parameters: {}".format(num_params))

        with tf.variable_scope("Loss"):
            R_residual = artifact_im - reference_im
            # Multiply by 1000 so the L2 loss is on a scale comparable to the GAN loss.
            L2_loss = 1000 * tf.losses.mean_squared_error(
                labels=R_residual, predictions=G_residual)

        with tf.variable_scope("Optimizer"):
            optimizer = tf.train.AdamOptimizer(
                learning_rate, beta1=config.beta1,
                epsilon=config.epsilon).minimize(L2_loss, global_step)

        params["denoised"] = G_dn
        params["residual"] = G_residual
        params["L2_loss"] = L2_loss
        params["optimizer"] = optimizer
Example #4
    def test(self):
        print('test model')
        test_img_path = self.args.test_data_path
        test_img = load_test_img(test_img_path)
        # test_img = tf.random_uniform(shape=(1, 500, 800, 3), minval=0, maxval=1.)
        test_img = self.sess.run(test_img)
        with slim.arg_scope(model.arg_scope()):

            gen_img, _ = model.gen_net(test_img, reuse=False, name='transform')

            # load model
            model_path = self.args.transfer_model

            vars = slim.get_variables_to_restore(include=['transform'])
            # vgg_init_var = slim.get_variables_to_restore(include=['vgg_16/fc6'])
            init_fn = slim.assign_from_checkpoint_fn(model_path, vars)
            init_fn(self.sess)
            # tf.initialize_variables(var_list=vgg_init_var)
            print('transform weights loaded')

            gen_img = self.sess.run(gen_img)
            save_img.save_images(gen_img, self.args.new_img_name)
Example #5
    def build_model(self):
        train_imgs = tools.load_train_img(TRAIN_DATA_DIR, self.batch_size, self.img_size)
        style_imgs = tools.load_style_img(STYLE_IMAGE_PATH)

        with slim.arg_scope(model.arg_scope()):
            gen_img, variables = model.inference(train_imgs, reuse=False, name='transform')

            with slim.arg_scope(vgg.vgg_arg_scope()):
                gen_img_processed = [tf.image.per_image_standardization(image) for image in
                                     tf.unstack(gen_img, axis=0, num=self.batch_size)]

                f1, f2, f3, f4, exclude = vgg.vgg_16(tf.concat([gen_img_processed, train_imgs, style_imgs], axis=0))

                gen_f, img_f, _ = tf.split(f4, 3, 0)
                content_loss = tf.nn.l2_loss(gen_f - img_f) / tf.to_float(tf.size(gen_f))

                style_loss = model.styleloss(f1, f2, f3, f4)

                vgg_model_path = VGG_MODEL_PATH
                vgg_vars = slim.get_variables_to_restore(include=['vgg_16'], exclude=exclude)
                init_fn = slim.assign_from_checkpoint_fn(vgg_model_path, vgg_vars)
                init_fn(self.sess)
                print("vgg's weights load done")

            self.gen_img = gen_img
            self.global_step = tf.Variable(0, name="global_step", trainable=False)
            self.content_loss = content_loss
            self.style_loss = style_loss * self.style_w
            self.loss = self.content_loss + self.style_loss
            self.learn_rate = tf.train.exponential_decay(self.learn_rate_base, self.global_step, 1,
                                                         self.learn_rate_decay, staircase=True)
            self.opt = tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss, global_step=self.global_step,
                                                                        var_list=variables)

        all_var = tf.global_variables()
        init_var = [v for v in all_var if 'vgg_16' not in v.name]
        init = tf.variables_initializer(var_list=init_var)
        self.sess.run(init)
        self.save = tf.train.Saver(var_list=variables)
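One thing to note about the schedule above: with decay_steps=1 and staircase=True, tf.train.exponential_decay reduces to a per-step geometric decay, lr(step) = learn_rate_base * learn_rate_decay ** step. A quick check with illustrative values (not from the source):

learn_rate_base, learn_rate_decay = 1e-3, 0.99  # illustrative values only
for step in (0, 100, 1000):
    print(step, learn_rate_base * learn_rate_decay ** step)
# 0 -> 1e-3, 100 -> ~3.7e-4, 1000 -> ~4.3e-8: the rate vanishes quickly
# unless learn_rate_decay is very close to 1.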
Example #6
def main(_):

    import setproctitle
    setproctitle.setproctitle('python_train')

    is_training = tf.placeholder(tf.bool, shape=())
    left_in = tf.placeholder(tf.float32,
                             [None, FLAGS.patch_size, FLAGS.patch_size, 1],
                             name='left')
    righ_in = tf.placeholder(tf.float32,
                             [None, FLAGS.patch_size, FLAGS.patch_size, 1],
                             name='right')
    with tf.name_scope('similarity'):
        label = tf.placeholder(tf.int32, [
            None,
        ], name='label')

    arg_scope = model.arg_scope(weight_decay=FLAGS.weight_decay,
                                data_format=DATA_FORMAT)
    global_steps = tf.Variable(0, trainable=False)

    with tf.device(gpus):
        with tf.name_scope('model'):
            with slim.arg_scope(arg_scope):
                left_out = model.siamese_net(left_in, reuse=False)
                righ_out = model.siamese_net(righ_in, reuse=True)
                cls_out = model.cls_net(left_out, righ_out)

        loss = model.net_loss(cls_out, label)
        tf.summary.scalar('loss', loss)

        lr = FLAGS.learning_rate
        mm = FLAGS.momentum
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,
                                               momentum=mm,
                                               use_nesterov=True)
        train_op = optimizer.minimize(loss=loss, global_step=global_steps)
        # Passing global_step=global_steps to the optimizer's minimize() makes it
        # increment global_step by 1 automatically on every training step, so
        # global_steps accumulates the step count across all epochs.

    variables_to_train = tf.trainable_variables()
    for var in variables_to_train:
        tf.summary.histogram(var.op.name, var)

    saver = tf.train.Saver(max_to_keep=500)

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = False
    config.gpu_options.per_process_gpu_memory_fraction = 0.6
    config.log_device_placement = False
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    # Add summary writers
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('log/train', sess.graph)
    eval_writer = tf.summary.FileWriter('log/eval')

    # Init variables
    if not os.listdir(ckpt_dir):  # checkpoint directory is empty
        init = tf.global_variables_initializer()
        # Run the init operation
        sess.run(init, {is_training: True})
    else:
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))

    # A dictionary of the important graph handles passed to the train/eval helpers.
    # These are all operations/tensors in the graph; nothing runs until sess.run.
    ops = {
        'is_training_pl': is_training,
        'left_in_pl': left_in,
        'righ_in_pl': righ_in,
        'label_pl': label,
        'loss': loss,
        'train_op': train_op,
        'merged': merged,
        'global_steps': global_steps
    }

    # train iter
    for i in range(FLAGS.train_iter):
        np.random.shuffle(DATA_PATHS)
        np.random.shuffle(EVAL_DATA_PATHS)

        for b in range(n_batches):
            step, train_loss = train_one_step(b,
                                              sess,
                                              ops,
                                              train_writer,
                                              is_training=True)
            eval_loss = eval_one_step(b,
                                      sess,
                                      ops,
                                      eval_writer,
                                      is_training=False)
            print "global step %d: train loss = %f; eval loss = %f" % (
                step, train_loss, eval_loss)

            saver.save(sess, "model_with_eval/model.ckpt", global_steps)
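train_one_step and eval_one_step are not shown in this example. A minimal sketch of the feed-dict pattern they presumably follow, given the ops dictionary above; next_batch is a hypothetical loader standing in for whatever reads DATA_PATHS:

def train_one_step(batch_idx, sess, ops, writer, is_training):
    # Hypothetical loader; the real code draws patch pairs from DATA_PATHS.
    left, right, labels = next_batch(batch_idx)
    feed = {ops['left_in_pl']: left,
            ops['righ_in_pl']: right,
            ops['label_pl']: labels,
            ops['is_training_pl']: is_training}
    # Run the train op and fetch loss, summaries, and the current step.
    _, loss, summary, step = sess.run(
        [ops['train_op'], ops['loss'], ops['merged'], ops['global_steps']],
        feed_dict=feed)
    writer.add_summary(summary, step)
    return step, loss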