Example #1
def main(argv=None):
    """
    tf.app.run first parses the command-line FLAGS and then invokes this main function.
    """
    # Placeholders for the input images, ground-truth annotations, and dropout keep probability
    keep_probability = tf.placeholder(tf.float32, name="keep_probability")
    image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
    annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="annotation")

    pred_annotation, logits = segmentation(image, keep_probability)
    # Add monitoring summaries viewable in TensorBoard
    tf.summary.image("input_image", image, max_outputs = 2)
    tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs = 2)
    tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs = 2)
    # Per-pixel cross-entropy loss: logits have shape [batch, H, W, NUM_CLASSES], and
    # squeezing the annotation's size-1 channel axis yields [batch, H, W] integer labels;
    # reduce_mean then averages the per-pixel losses into a scalar.
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                         labels=tf.squeeze(annotation, axis=[3]),
                                                                         name="entropy"))

    tf.summary.scalar("entropy", loss)
    # Collect the list of trainable variables
    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader...")

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    sess.run(tf.global_variables_initializer())
    name_list = shuffle_namelist(train_Path)
    start_index = 0
    batch_size = FLAGS.batch_size
    if FLAGS.mode == "train":
        for itr in range(MAX_ITERATION):
            train_images, train_annotations, start_index = read_2_namelist(name_list, batch_size, start_index)
            # The dropout keep probability must be positive during training;
            # 0 would zero out every activation. 0.85 is a typical choice here.
            feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85}

            sess.run(train_op, feed_dict=feed_dict)

            if itr % 10 == 0:
                train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
                print("Step: %d, Train_loss:%g" % (itr, train_loss))
                summary_writer.add_summary(summary_str, itr)
            if itr % 500 == 0:
                saver.save(sess, FLAGS.logs_dir + "FCN_model.ckpt", itr)
    if FLAGS.mode == "predict":
        saver.restore(sess, FLAGS.logs_dir + "FCN_model.ckpt")
        predict(sess, pred_annotation)
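
As a worked illustration of the loss computed above, here is a minimal, self-contained sketch (assuming TensorFlow 1.x); the toy shapes and values are invented for demonstration:

import numpy as np
import tensorflow as tf

# logits: one score per class for every pixel -> [batch, H, W, NUM_CLASSES]
logits = tf.constant(np.random.randn(1, 2, 2, 3).astype(np.float32))
# annotation: one integer class label per pixel -> [batch, H, W, 1]
annotation = tf.constant([[[[0], [1]], [[2], [1]]]], dtype=tf.int32)

# Squeezing the size-1 channel axis gives [batch, H, W] labels, the rank
# sparse_softmax_cross_entropy_with_logits expects (logits rank minus one).
per_pixel = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=tf.squeeze(annotation, axis=[3]))
loss = tf.reduce_mean(per_pixel)  # average over all pixels -> scalar

with tf.Session() as sess:
    print(sess.run(loss))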
Example #2
    def create_network(self, generator_dims, discriminator_dims, optimizer="Adam", learning_rate=2e-4,
                       optimizer_param=0.9, improved_gan_loss=True):
        print("Setting up model...")
        self._setup_placeholder()
        tf.summary.histogram("z", self.z_vec)
        self.gen_images = self._generator(self.z_vec, generator_dims, self.train_phase, scope_name="generator")

        tf.summary.image("image_real", self.images, max_outputs=2)
        tf.summary.image("image_generated", self.gen_images, max_outputs=2)

        def leaky_relu(x, name="leaky_relu"):
            return utils.leaky_relu(x, alpha=0.2, name=name)

        discriminator_real_prob, logits_real, feature_real = self._discriminator(self.images, discriminator_dims,
                                                                                 self.train_phase,
                                                                                 activation=leaky_relu,
                                                                                 scope_name="discriminator",
                                                                                 scope_reuse=False)

        discriminator_fake_prob, logits_fake, feature_fake = self._discriminator(self.gen_images, discriminator_dims,
                                                                                 self.train_phase,
                                                                                 activation=leaky_relu,
                                                                                 scope_name="discriminator",
                                                                                 scope_reuse=True)

        # utils.add_activation_summary(tf.identity(discriminator_real_prob, name='disc_real_prob'))
        # utils.add_activation_summary(tf.identity(discriminator_fake_prob, name='disc_fake_prob'))

        # Loss calculation
        self._gan_loss(logits_real, logits_fake, feature_real, feature_fake, use_features=improved_gan_loss)

        train_variables = tf.trainable_variables()

        for v in train_variables:
            utils.add_to_regularization_and_summary(var=v)

        self.generator_variables = [v for v in train_variables if v.name.startswith("generator")]
        self.discriminator_variables = [v for v in train_variables if v.name.startswith("discriminator")]

        optim = self._get_optimizer(optimizer, learning_rate, optimizer_param)

        self.generator_train_op = self._train(self.gen_loss, self.generator_variables, optim)
        self.discriminator_train_op = self._train(self.discriminator_loss, self.discriminator_variables, optim)
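
A hedged usage sketch for create_network; the class name GAN and its constructor are hypothetical stand-ins invented for illustration, since the snippet shows only the method itself:

# Hypothetical wiring (not from the source): dims follow the common
# convention of per-layer channel counts for generator and discriminator.
model = GAN(z_dim=100, batch_size=64)  # assumed constructor
model.create_network(generator_dims=[1024, 512, 256, 128, 3],
                     discriminator_dims=[64, 128, 256, 512, 1],
                     optimizer="Adam",
                     learning_rate=2e-4,
                     optimizer_param=0.9,  # beta1 for Adam
                     improved_gan_loss=True)
# A training loop would then alternate sess.run(model.discriminator_train_op)
# and sess.run(model.generator_train_op) with appropriate feeds.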
Example #3
    def create_network(self,
                       generator_dims,
                       discriminator_dims,
                       optimizer="Adam",
                       learning_rate=2e-4,
                       optimizer_param=0.9,
                       improved_gan_loss=True,
                       trainable_z=False,
                       trainable_image=False):
        print("Setting up model...")
        self._setup_placeholder()
        tf.summary.histogram("z", self.z_vec)

        if trainable_z:
            # make z iterator variable
            self.z_iterator = tf.Variable(np.random.uniform(
                -1.0, 1.0, (self.batch_size, int(
                    self.z_vec.get_shape()[1]))).astype(dtype=np.float32),
                                          name="z_iterator")
            self.init_z_iterator = tf.group(self.z_iterator.assign(self.z_vec))
            # Clamp z to [-1, 1]; the original nested maximum/minimum/maximum
            # had a redundant outer maximum and reduces to a plain clip.
            self.z_iterator_min_max = tf.clip_by_value(self.z_iterator, -1.0, 1.0)
            self.z_vec_in = self.z_iterator_min_max
        else:
            self.z_vec_in = self.z_vec
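        # When trainable_z is set, the workflow implied by the ops defined
        # here is: run init_z_iterator once to seed z_iterator from a fed
        # z_vec, then repeatedly run z_iterator_train_op (created below) to
        # optimize z while the network weights stay fixed.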

        # generator for training
        self.gen_images = self._generator(self.z_vec_in,
                                          generator_dims,
                                          self.train_phase,
                                          scope_name="generator")

        if trainable_image:
            # make image iterator variable
            self.image_iterator = tf.Variable(np.random.uniform(
                0.0, 1.0,
                (self.batch_size, 64, 64, 3)).astype(dtype=np.float32),
                                              name="image_iterator")
            self.init_image_iterator = tf.group(
                self.image_iterator.assign(self.gen_images))
            self.image_iterator_min_max = tf.clip_by_value(self.image_iterator, -1.0, 1.0)
            self.gen_images_out = self.image_iterator_min_max
        else:
            self.gen_images_out = self.gen_images

        tf.summary.image("image_real", self.images)
        tf.summary.image("image_generated", self.gen_images_out)

        def leaky_relu(x, name="leaky_relu"):
            return utils.leaky_relu(x, alpha=0.2, name=name)

        discriminator_real_prob, logits_real, feature_real = self._discriminator(
            self.images,
            discriminator_dims,
            self.train_phase,
            activation=leaky_relu,
            scope_name="discriminator",
            scope_reuse=False)

        discriminator_fake_prob, logits_fake, feature_fake = self._discriminator(
            self.gen_images_out,
            discriminator_dims,
            self.train_phase,
            activation=leaky_relu,
            scope_name="discriminator",
            scope_reuse=True)

        # utils.add_activation_summary(tf.identity(discriminator_real_prob, name='disc_real_prob'))
        # utils.add_activation_summary(tf.identity(discriminator_fake_prob, name='disc_fake_prob'))

        # Loss calculation
        self._gan_loss(logits_real,
                       logits_fake,
                       feature_real,
                       feature_fake,
                       use_features=improved_gan_loss)

        train_variables = tf.trainable_variables()

        for v in train_variables:
            utils.add_to_regularization_and_summary(var=v)

        # get variable lists for everything
        self.generator_variables = [
            v for v in train_variables if v.name.startswith("generator")
        ]
        self.discriminator_variables = [
            v for v in train_variables if v.name.startswith("discriminator")
        ]
        self.image_iterator_variables = [
            v for v in train_variables if v.name.startswith("image_iterator")
        ]
        self.z_iterator_variables = [
            v for v in train_variables if v.name.startswith("z_iterator")
        ]

        # set optimizer
        optim = self._get_optimizer(optimizer, learning_rate, optimizer_param)
        optim_z = self._get_optimizer(optimizer, learning_rate * 10.0,
                                      optimizer_param)

        # make train ops
        if not trainable_image and not trainable_z:
            self.generator_train_op = self._train(self.gen_loss,
                                                  self.generator_variables,
                                                  optim)
            self.discriminator_train_op = self._train(
                self.discriminator_loss, self.discriminator_variables, optim)
        if trainable_image:
            self.image_iterator_train_op = self._train(
                self.gen_loss, self.image_iterator_variables, optim)
        if trainable_z:
            self.z_iterator_train_op = self._train(self.gen_loss,
                                                   self.z_iterator_variables,
                                                   optim_z)
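
A rough usage sketch for the trainable_z path above, e.g. for latent-space search against gen_loss; sess, model, and feed are illustrative stand-ins, and the step count is arbitrary:

# Seed z_iterator from a fed z_vec, then optimize only the z variable
# (via z_iterator_train_op) while all network weights are left untouched.
sess.run(model.init_z_iterator, feed_dict=feed)
for step in range(1000):
    sess.run(model.z_iterator_train_op, feed_dict=feed)
result_images = sess.run(model.gen_images, feed_dict=feed)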