Beispiel #1
0
    def visual(self):
        """Restore the saved model and dump visualizations of one conv layer.

        Two images are written via ``vis_square`` into ``self.vi_path``:

        * the tensors in the ``'weight_2'`` graph collection (``type=1``);
        * the tensors in the ``'ac_2'`` collection, evaluated on the first 64
          images of batch 0 from ``self.data_ob`` (``type=0``).
        """
        # Same initializer as the rest of this file; the deprecated
        # tf.initialize_all_variables() is an alias for it.
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)

            # Overwrite the freshly initialized values with the checkpoint.
            self.saver.restore(sess, self.model_path)

            realbatch_array, _ = self.data_ob.getNext_batch(0)
            batch_z = np.random.uniform(-1,
                                        1,
                                        size=[self.batch_size, self.z_dim])
            # Visualize the 'weight_2' collection; change the collection name
            # to inspect another layer.
            conv_weights = sess.run([tf.get_collection('weight_2')])
            vis_square(self.vi_path,
                       conv_weights[0][0].transpose(3, 0, 1, 2),
                       type=1)

            # Visualize the 'ac_2' collection for the first 64 real images.
            ac = sess.run(
                [tf.get_collection('ac_2')],
                feed_dict={
                    self.images: realbatch_array[:64],
                    self.z: batch_z,
                    self.y: sample_label()
                })

            vis_square(self.vi_path, ac[0][0].transpose(3, 1, 2, 0), type=0)

            print("the visualization finish!")
Beispiel #2
0
    def test(self):
        """Restore the saved model, generate one batch and display it.

        The generated batch is tiled into an 8x8 grid and written to
        ``./<sample_dir>/test00_0000.png``. That file is read back as
        grayscale and shown in an OpenCV window until a key is pressed.
        """
        # Same initializer as the rest of this file; the deprecated
        # tf.initialize_all_variables() is an alias for it.
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)

            self.saver.restore(sess, self.model_path)
            # numpy requires low <= high ("officially undefined" otherwise);
            # this matches the (-1, 1) range used everywhere else in the file.
            sample_z = np.random.uniform(-1,
                                         1,
                                         size=[self.batch_size, self.z_dim])

            output = sess.run(self.fake_images,
                              feed_dict={
                                  self.z: sample_z,
                                  self.y: sample_label()
                              })

            test_file = './{}/test{:02d}_{:04d}.png'.format(
                self.sample_dir, 0, 0)
            save_images(output, [8, 8], test_file)

            # Flag 0: load the saved grid as a single-channel image.
            image = cv2.imread(test_file, 0)

            cv2.imshow("test", image)

            cv2.waitKey(-1)

            print("Test finish!")
    def test(self):
        """Restore the saved model, generate one batch and show its first image.

        The first element of the generator output is passed directly to
        ``cv2.imshow``. The window stays open until a key is pressed.
        """
        # Same initializer as the rest of this file; the deprecated
        # tf.initialize_all_variables() is an alias for it.
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)

            self.saver.restore(sess, self.model_path)
            # numpy requires low <= high ("officially undefined" otherwise).
            sample_z = np.random.uniform(-1, 1,
                                         size=[self.batch_size, self.z_dim])

            output = sess.run(self.fake_images,
                              feed_dict={
                                  self.z: sample_z,
                                  self.y: sample_label()
                              })

            # NOTE(review): imshow interprets float images as [0, 1]; confirm
            # the generator's output range before relying on this display.
            cv2.imshow("test", output[0])

            cv2.waitKey(-1)

            print("Test finish!")
Beispiel #4
0
    def train(self):
        """Train the conditional GAN on ``self.ds_train``.

        Each step runs one discriminator update and one generator update
        (Adam, beta1=0.5) on the same batch and noise. Both runs log
        ``tf.summary.merge_all()`` to ``self.log_dir``, and both losses are
        printed every step.

        Whenever ``step % 50 == 1``, an 8x8 grid of samples is written to
        ``./<sample_path>/train_<epoch>_<step>.png`` and the model is saved to
        ``self.model_path``. It is saved once more after the final epoch.
        """
        opti_D = tf.train.AdamOptimizer(learning_rate=self.learning_rate_dis,
                                        beta1=0.5).minimize(
                                            self.loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(learning_rate=self.learning_rate_gen,
                                        beta1=0.5).minimize(
                                            self.G_fake_loss,
                                            var_list=self.g_vars)

        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

            # Floor division: only full batches are drawn. Under Python 3 a
            # plain '/' also requested a trailing partial batch whenever
            # len(self.ds_train) is not a multiple of self.batch_size.
            batches_per_epoch = len(self.ds_train) // self.batch_size

            try:
                batch_num = 0
                e = 0
                step = 0

                # '<=' runs self.max_epoch + 1 epochs.
                while e <= self.max_epoch:

                    # The draw is kept so the numpy RNG stream is unchanged,
                    # but the value handed to getNextBatch is pinned to 0.
                    rand = np.random.randint(0, 100)
                    rand = 0

                    while batch_num < batches_per_epoch:

                        step = step + 1
                        realbatch_array, real_y = MnistData.getNextBatch(
                            self.ds_train, self.label_y, rand, batch_num,
                            self.batch_size)

                        batch_z = np.random.normal(
                            0, 1, size=[self.batch_size, self.sample_size])

                        train_feed = {
                            self.images: realbatch_array,
                            self.z: batch_z,
                            self.y: real_y
                        }

                        # Discriminator update.
                        _, summary_str = sess.run([opti_D, summary_op],
                                                  feed_dict=train_feed)
                        summary_writer.add_summary(summary_str, step)
                        # Generator update on the same batch and noise.
                        _, summary_str = sess.run([opti_G, summary_op],
                                                  feed_dict=train_feed)
                        summary_writer.add_summary(summary_str, step)
                        batch_num += 1

                        # Losses are reported every step.
                        D_loss = sess.run(self.loss, feed_dict=train_feed)
                        fake_loss = sess.run(self.G_fake_loss,
                                             feed_dict={
                                                 self.z: batch_z,
                                                 self.y: real_y
                                             })
                        print(
                            "EPOCH %d step %d: D: loss = %.7f G: loss=%.7f " %
                            (e, step, D_loss, fake_loss))

                        if np.mod(step, 50) == 1:

                            sample_images = sess.run(self.fake_images,
                                                     feed_dict={
                                                         self.z: batch_z,
                                                         self.y: sample_label()
                                                     })
                            save_images(
                                sample_images[0:64], [8, 8],
                                './{}/train_{:02d}_{:04d}.png'.format(
                                    self.sample_path, e, step))
                            # Periodic checkpoint.
                            self.saver.save(sess, self.model_path)

                    e += 1
                    batch_num = 0

                save_path = self.saver.save(sess, self.model_path)
                print("Model saved in file: %s" % save_path)
            finally:
                # Flush pending events to disk and release the writer.
                summary_writer.close()
Beispiel #5
0
def dcgan(operation,
          data_name,
          output_size,
          sample_path,
          log_dir,
          model_path,
          visua_path,
          sample_num=64):
    """Build the conditional DCGAN graph, then train, test or visualize it.

    Only ``data_name == "mnist"`` is handled; any other name just prints
    ``"other dataset!"``.

    Args:
        operation: what to do with the graph.

            * ``0`` trains the model.
            * ``1`` restores the model and writes/shows an 8x8 grid of
              samples.
            * Any other value restores the model and dumps weight/activation
              visualizations through ``vis_square``.
        data_name: dataset name passed to ``load_mnist``.
        output_size: height and width of the ``images`` placeholder and
            generator output.
        sample_path: directory (relative to ``./``) for sample image grids.
        log_dir: TensorBoard summary directory (training only).
        model_path: checkpoint path used by ``tf.train.Saver``.
        visua_path: first argument passed to ``vis_square`` (presumably an
            output location -- confirm).
        sample_num: number of noise vectors fed to the sampling network.

    Module-level settings defined elsewhere are read as globals:
    ``batch_size``, ``sample_size``, ``y_dim``, ``channel``,
    ``learning_rate``, ``EPOCH``, ``display_step``, ``weights``, ``biases``.
    """
    if data_name == "mnist":

        print("you use the mnist dataset")

        data_array, data_y = load_mnist(data_name)

        # Fixed noise, reused for every training-time sample grid and by the
        # visualization branch.
        sample_z = np.random.uniform(-1, 1, size=[sample_num, 100])

        y = tf.placeholder(tf.float32, [None, y_dim])

        # Fixed batch dimension: every fed batch must have batch_size rows.
        images = tf.placeholder(
            tf.float32, [batch_size, output_size, output_size, channel])

        z = tf.placeholder(tf.float32, [None, sample_size])
        z_sum = tf.summary.histogram("z", z)

        fake_images = gern_net(batch_size, z, y, output_size)
        G_image = tf.summary.image("G_out", fake_images)

        sample_img = sample_net(sample_num, z, y, output_size)

        # Discriminator applied to real and to generated images; the last
        # argument presumably toggles variable reuse -- confirm in dis_net.
        D_pro, D_logits = dis_net(images, y, weights, biases, False)
        D_pro_sum = tf.summary.histogram("D_pro", D_pro)

        G_pro, G_logits = dis_net(fake_images, y, weights, biases, True)
        G_pro_sum = tf.summary.histogram("G_pro", G_pro)

        # Discriminator targets: real -> 1, generated -> 0.
        D_fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(G_pro), logits=G_logits))
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_pro),
                                                    logits=D_logits))
        # Generator target: have generated images classified as real (1).
        G_fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(G_pro),
                                                    logits=G_logits))

        loss = real_loss + D_fake_loss

        loss_sum = tf.summary.scalar("D_loss", loss)
        G_loss_sum = tf.summary.scalar("G_loss", G_fake_loss)

        merged_summary_op_d = tf.summary.merge([loss_sum, D_pro_sum])
        merged_summary_op_g = tf.summary.merge(
            [G_loss_sum, G_pro_sum, G_image, z_sum])

        t_vars = tf.trainable_variables()

        # Each optimizer only updates variables whose name contains its tag.
        d_var = [var for var in t_vars if 'dis' in var.name]
        g_var = [var for var in t_vars if 'gen' in var.name]

        saver = tf.train.Saver()

        if operation == 0:  # train

            opti_D = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                            beta1=0.5).minimize(loss,
                                                                var_list=d_var)
            opti_G = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                            beta1=0.5).minimize(G_fake_loss,
                                                                var_list=g_var)

            init = tf.global_variables_initializer()

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True

            with tf.Session(config=config) as sess:

                sess.run(init)
                summary_writer = tf.summary.FileWriter(log_dir,
                                                       graph=sess.graph)
                try:
                    batch_num = 0
                    e = 0
                    step = 0

                    # '<=' runs EPOCH + 1 passes over the data.
                    while e <= EPOCH:
                        data_array, data_y = shuffle_data(data_array, data_y)
                        # Floor division: the 'images' placeholder has a
                        # fixed batch dimension, so the trailing partial
                        # batch that '/' visits under Python 3 cannot be fed.
                        while batch_num < len(data_array) // batch_size:

                            step = step + 1

                            realbatch_array, real_labels = getNext_batch(
                                data_array, data_y, batch_num)

                            batch_z = np.random.uniform(
                                -1, 1, size=[batch_size, sample_size])

                            # One discriminator update ...
                            _, summary_str = sess.run(
                                [opti_D, merged_summary_op_d],
                                feed_dict={
                                    images: realbatch_array,
                                    z: batch_z,
                                    y: real_labels
                                })
                            summary_writer.add_summary(summary_str, step)

                            # ... then one generator update on the same noise.
                            _, summary_str = sess.run(
                                [opti_G, merged_summary_op_g],
                                feed_dict={
                                    z: batch_z,
                                    y: real_labels
                                })
                            summary_writer.add_summary(summary_str, step)

                            batch_num += 1

                            if step % display_step == 0:

                                D_loss = sess.run(loss,
                                                  feed_dict={
                                                      images: realbatch_array,
                                                      z: batch_z,
                                                      y: real_labels
                                                  })
                                fake_loss = sess.run(G_fake_loss,
                                                     feed_dict={
                                                         z: batch_z,
                                                         y: real_labels
                                                     })
                                print(
                                    "EPOCH %d step %d: D: loss = %.7f G: loss=%.7f "
                                    % (e, step, D_loss, fake_loss))

                            if np.mod(step, 50) == 1:

                                print("sample!")
                                sample_images = sess.run(sample_img,
                                                         feed_dict={
                                                             z: sample_z,
                                                             y: sample_label()
                                                         })
                                save_images(
                                    sample_images, [8, 8],
                                    './{}/train_{:02d}_{:04d}.png'.format(
                                        sample_path, e, step))
                                # Periodic checkpoint.
                                saver.save(sess, model_path)

                        e = e + 1
                        batch_num = 0

                    save_path = saver.save(sess, model_path)
                    # print() call: the Python 2 print statement used here
                    # made the whole module a SyntaxError under Python 3.
                    print("Model saved in file: %s" % save_path)
                finally:
                    # Flush pending events to disk and release the writer.
                    summary_writer.close()

        elif operation == 1:  # test

            print("Test")

            init = tf.global_variables_initializer()

            with tf.Session() as sess:

                sess.run(init)

                saver.restore(sess, model_path)
                # numpy requires low <= high ("officially undefined" otherwise).
                sample_z = np.random.uniform(-1, 1, size=[sample_num, 100])

                output = sess.run(sample_img,
                                  feed_dict={
                                      z: sample_z,
                                      y: sample_label()
                                  })

                test_file = './{}/test{:02d}_{:04d}.png'.format(
                    sample_path, 0, 0)
                save_images(output, [8, 8], test_file)

                # Flag 0: read the saved grid back as a single-channel image.
                image = cv2.imread(test_file, 0)

                cv2.imshow("test", image)

                cv2.waitKey(-1)

                print("Test finish!")

        else:  # visualize

            print("Visualize")

            init = tf.global_variables_initializer()
            with tf.Session() as sess:

                sess.run(init)

                saver.restore(sess, model_path)

                # Visualize the 'weight_2' collection; change the name to
                # inspect another layer.
                conv_weights = sess.run([tf.get_collection('weight_2')])

                vis_square(visua_path,
                           conv_weights[0][0].transpose(3, 0, 1, 2),
                           type=1)

                # Visualize the 'ac_2' collection for the first 64 images.
                ac = sess.run([tf.get_collection('ac_2')],
                              feed_dict={
                                  images: data_array[:64],
                                  z: sample_z,
                                  y: sample_label()
                              })

                vis_square(visua_path, ac[0][0].transpose(3, 1, 2, 0), type=0)

                print("the visualization finish!")

    else:
        print("other dataset!")
Beispiel #6
0
    def train(self):
        """Train the conditional GAN for 10001 steps (step 0 through 10000).

        Each step runs one discriminator update and one generator update
        (Adam, beta1=0.5, lr ``self.learn_rate``) on a batch from
        ``self.data_ob``, and logs both merged summaries to ``self.log_dir``.

        Losses are printed when ``step % 50 == 0``. When ``step % 50 == 1``,
        an 8x8 grid of samples is written to ``./<sample_dir>/train_<step>.png``
        and the model is saved. It is saved once more after the final step.
        """
        opti_D = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
                                        beta1=0.5).minimize(
                                            self.D_loss, var_list=self.d_var)
        opti_G = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
                                        beta1=0.5).minimize(
                                            self.G_loss, var_list=self.g_var)
        init = tf.global_variables_initializer()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            summary_writer = tf.summary.FileWriter(self.log_dir,
                                                   graph=sess.graph)
            try:
                step = 0
                while step <= 10000:

                    realbatch_array, real_labels = self.data_ob.getNext_batch(
                        step)

                    batch_z = np.random.uniform(
                        -1, 1, size=[self.batch_size, self.z_dim])

                    # The discriminator sees real images; the generator
                    # graph only needs noise and labels.
                    d_feed = {
                        self.images: realbatch_array,
                        self.z: batch_z,
                        self.y: real_labels
                    }
                    g_feed = {self.z: batch_z, self.y: real_labels}

                    _, summary_str = sess.run(
                        [opti_D, self.merged_summary_op_d], feed_dict=d_feed)
                    summary_writer.add_summary(summary_str, step)

                    _, summary_str = sess.run(
                        [opti_G, self.merged_summary_op_g], feed_dict=g_feed)
                    summary_writer.add_summary(summary_str, step)

                    if step % 50 == 0:

                        D_loss = sess.run(self.D_loss, feed_dict=d_feed)
                        fake_loss = sess.run(self.G_loss, feed_dict=g_feed)
                        print("Step %d: D: loss = %.7f G: loss=%.7f " %
                              (step, D_loss, fake_loss))

                    if np.mod(step, 50) == 1 and step != 0:

                        sample_images = sess.run(self.fake_images,
                                                 feed_dict={
                                                     self.z: batch_z,
                                                     self.y: sample_label()
                                                 })
                        save_images(
                            sample_images, [8, 8],
                            './{}/train_{:04d}.png'.format(
                                self.sample_dir, step))

                        # Periodic checkpoint.
                        self.saver.save(sess, self.model_path)

                    step = step + 1

                save_path = self.saver.save(sess, self.model_path)
                print("Model saved in file: %s" % save_path)
            finally:
                # Flush pending events to disk and release the writer.
                summary_writer.close()
Beispiel #7
0
    def train(self, args):
        """Train the generator (G), discriminator (D) and ``c`` networks.

        All three Adam optimizers use ``args.lr`` and ``args.beta1``. Each
        step updates D, G and C once, in that order, on the same batch and
        noise. The loop runs ``self.training_step + 1`` steps.

        When ``step % 50 == 1``, sample/lung/mediastinum/mask grids are
        written for three inputs:

        * the current training batch, into ``self.train_dir``;
        * ``sample_masks()``, into ``self.eval_dir``;
        * ``sample_masks_test()``, into ``self.test_dir``.

        A checkpoint tagged with the step is then saved. A final untagged
        checkpoint is saved after the loop.
        """
        opti_D = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
                          .minimize(self.d_loss, var_list=self.d_vars)
        opti_G = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
                          .minimize(self.g_loss, var_list=self.g_vars)
        opti_C = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
                          .minimize(self.d_c_loss, var_list=self.c_vars)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        def _save_grids(out_dir, step, sample, lung, mediastinum, mask):
            """Write the four 8x8 grids of one image set into out_dir."""
            for tag, batch in (('sample', sample), ('lung', lung),
                               ('mediastinum', mediastinum), ('mask', mask)):
                save_images(batch, [8, 8],
                            './{}/{:04d}_{}.png'.format(out_dir, step, tag))

        with tf.Session(config=config) as sess:
            sess.run(init)
            if self.load:
                # Load a pretrained model.
                # NOTE(review): import_meta_graph imports the saved graph into
                # the already-built default graph and replaces self.saver;
                # restoring with the existing saver may suffice -- confirm.
                print('loading:')
                self.saver = tf.train.import_meta_graph(
                    './model/model.ckpt-20201.meta'
                )  # default to save all variable
                self.saver.restore(sess,
                                   tf.train.latest_checkpoint('./model/'))

            # NOTE(review): self.writer is left open because it is an instance
            # attribute other code may still use.
            self.writer = tf.summary.FileWriter("./logs", sess.graph)
            summary_writer = tf.summary.FileWriter(self.log_dir,
                                                   graph=sess.graph)

            fetches = [
                self.fake_B, self.fake_lungwindow, self.fake_mediastinumwindow
            ]

            try:
                step = 0
                while step <= self.training_step:
                    (realbatch_array, real_lungs, real_mediastinums,
                     realmasks, real_labels) = self.data_ob.getNext_batch(
                         step, batch_size=self.batch_size)
                    batch_z = np.random.uniform(
                        -1, 1, size=[self.batch_size, self.z_dim])

                    train_feed = {
                        self.images: realbatch_array,
                        self.lungwindow: real_lungs,
                        self.mediastinumwindow: real_mediastinums,
                        self.masks: realmasks,
                        self.z: batch_z,
                        self.y: real_labels
                    }
                    # One update each for D, G and C on the same batch.
                    sess.run([opti_D], feed_dict=train_feed)
                    sess.run([opti_G], feed_dict=train_feed)
                    sess.run([opti_C], feed_dict=train_feed)

                    if np.mod(step, 50) == 1 and step != 0:
                        print('Saving...')
                        sample_images, lungwindow, mediastinumwindow = sess.run(
                            fetches, feed_dict=train_feed)
                        _save_grids(self.train_dir, step, sample_images,
                                    lungwindow, mediastinumwindow, realmasks)

                        print('save eval image')
                        eval_labels = sample_label()
                        eval_masks = sample_masks()
                        sample_images, lungwindow, mediastinumwindow = sess.run(
                            fetches,
                            feed_dict={
                                self.masks: eval_masks,
                                self.y: eval_labels
                            })
                        _save_grids(self.eval_dir, step, sample_images,
                                    lungwindow, mediastinumwindow, eval_masks)

                        print('save test image')
                        test_labels = sample_label()
                        test_masks = sample_masks_test()
                        sample_images, lungwindow, mediastinumwindow = sess.run(
                            fetches,
                            feed_dict={
                                self.masks: test_masks,
                                self.y: test_labels
                            })
                        _save_grids(self.test_dir, step, sample_images,
                                    lungwindow, mediastinumwindow, test_masks)

                        # Checkpoint every 50 steps, tagged with the step.
                        self.saver.save(sess, self.model_path,
                                        global_step=step)

                    step = step + 1

                save_path = self.saver.save(sess, self.model_path)
                print("Model saved in file: %s" % save_path)
            finally:
                # Flush pending events to disk and release the local writer.
                summary_writer.close()
def dcgan(operation,
          data_name,
          output_size,
          sample_path,
          log_dir,
          model_path,
          visua_path,
          sample_num=64):
    global data_array, data_y
    if data_name == "mnist":
        data_array, data_y = load_mnist(data_name)
        print("mnist")
    elif data_name == "celebA":
        print("celebA")
        data = glob(os.path.join("./data", "img_align_celeba", "*.jpg"))
        sample_files = data[0:64]
        sample = [
            get_image_celebA(sample_file,
                             input_height=108,
                             input_width=108,
                             resize_height=28,
                             resize_width=28,
                             is_crop=False,
                             is_grayscale=False)
            for sample_file in sample_files
        ]
        #sample = tf.reshape(sample, [-1, 28, 28, 1])
        data_array = np.array(sample).astype(np.float32)
        data_y = np.zeros(len(data_array))
    else:
        print("other dataset!")

    #print(len(data_array))
    #print("++++++++++++++++++++++++++++++++++++++++++++++")

    sample_z = np.random.uniform(-1, 1, size=[sample_num, 100])

    y = tf.placeholder(tf.float32, [None, y_dim])
    z = tf.placeholder(tf.float32, [None, sample_size])
    images = tf.placeholder(tf.float32,
                            [batch_size, output_size, output_size, channel])

    fake_images = gern_net(batch_size, z, y, output_size)
    sample_img = sample_net(sample_num, z, y, output_size)
    """
    the loss of gerenate network 
    tf.zeros_like, tf.ones_like生成0和1的矩阵
    discriminator: real images are labelled as 1
    discriminator: images from generator (fake) are labelled as 0
    generator: try to make the the fake images look real (1)
    sigmoid_cross_entropy_with_logits:可以对比1和(x,y)经过sigmoid后得出的概率,这里扩充多维。(某某分布属于标签1(0)的概率)
    """
    D_pro, D_logits = dis_net(images, y, weights, biases, False)
    G_pro, G_logits = dis_net(fake_images, y, weights, biases, True)
    D_real_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_pro),
                                                logits=D_logits))
    D_fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(G_pro),
                                                logits=G_logits))
    # 判别器的loss,能分真和假 --> ones_like(D_pro) 和 zeros_like(G_pro)
    D_loss = D_real_loss + D_fake_loss
    # 生成器的loss,能生成逼真的图片 --> ones_like(G_pro)
    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(G_pro),
                                                logits=G_logits))
    """
    公式表示的是:
    对于判别器D : max --> E(log(D(x))) + E(log(1 - D(G(z))))  代码是 ones_like(D_pro) + zeros_like(G_pro)
    对于判别器D : min -->                E(log(1 - D(G(z))))  等价于 max --> log(D(G(Z))) 代码是ones_like(G_pro)
    2014GAN论文原话:We train D to maximize the probability of assigning the correct label to 
    both training examples and samples from G. We simultaneously train G to minimize log(1-D(G(Z))).
    D是max能正确区分标签的概率【max --> E(log(D(x))) + E(log(1 - D(G(z))))】,也就要使loss在训练不断最小,(这里的Max和Min代表的含义是不一样的)
    同时也让G尽量去混淆它,也就是min E(log(1 - D(G(z)))) 用 max log(D(G(Z))) 替代。
    ....
    Rather than training G to minimize log(1 - D(G(z))), we can train G to maximize log(D(G(Z)))
    """
    """
    tf.summary.histogram, tf.summary.scalar可视化显示
    merge 合并在一起显示
    """
    z_sum = tf.summary.histogram("z", z)
    G_image = tf.summary.image("G_out", fake_images)
    D_pro_sum = tf.summary.histogram("D_pro", D_pro)
    G_pro_sum = tf.summary.histogram("G_pro", G_pro)
    loss_sum = tf.summary.scalar("D_loss", D_loss)
    G_loss_sum = tf.summary.scalar("G_loss", G_loss)
    merged_summary_op_d = tf.summary.merge([loss_sum, D_pro_sum])
    merged_summary_op_g = tf.summary.merge(
        [G_loss_sum, G_pro_sum, G_image, z_sum])

    t_vars = tf.trainable_variables()
    d_var = [var for var in t_vars if 'dis' in var.name]
    g_var = [var for var in t_vars if 'gen' in var.name]

    #定义保存模型变量
    saver = tf.train.Saver()
    #if train
    if operation == 0:
        opti_D = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                        beta1=0.5).minimize(D_loss,
                                                            var_list=d_var)
        opti_G = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                        beta1=0.5).minimize(G_loss,
                                                            var_list=g_var)
        init = tf.global_variables_initializer()  # 这句要在所有变量之后
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(init)
            summary_writer = tf.summary.FileWriter(log_dir, graph=sess.graph)
            batch_num = 0  #每次一批,就是控制一批一批的范围
            step = 0  #每个batch的步长
            e = 0  #epoch个数
            while e <= EPOCH:
                #rand = np.random.randint(0, 100)
                rand = 0
                while batch_num < len(data_array) / batch_size:
                    step = step + 1
                    realbatch_array, real_labels = getNext_batch(
                        rand, data_array, data_y, batch_num)
                    batch_z = np.random.uniform(-1,
                                                1,
                                                size=[batch_size, sample_size])
                    # batch_z = np.random.normal(0, 0.2, size=[batch_size, sample_size])
                    _, summary_str_D = sess.run([opti_D, merged_summary_op_d],
                                                feed_dict={
                                                    images: realbatch_array,
                                                    z: batch_z,
                                                    y: real_labels
                                                })
                    _, summary_str_G = sess.run([opti_G, merged_summary_op_g],
                                                feed_dict={
                                                    z: batch_z,
                                                    y: real_labels
                                                })
                    batch_num += 1
                    # average_loss += loss_value
                    """
                    写日志和打印必要信息
                    """
                    summary_writer.add_summary(summary_str_D, step)
                    summary_writer.add_summary(summary_str_G, step)
                    if step % display_step == 0:
                        D_loss_result = sess.run(D_loss,
                                                 feed_dict={
                                                     images: realbatch_array,
                                                     z: batch_z,
                                                     y: real_labels
                                                 })
                        G_loss_result = sess.run(G_loss,
                                                 feed_dict={
                                                     z: batch_z,
                                                     y: real_labels
                                                 })
                        print(
                            "EPOCH %d step %d: D: loss = %.7f G: loss=%.7f " %
                            (e, step, D_loss_result, G_loss_result))
                    if np.mod(step, 50) == 1:
                        sample_images = sess.run(sample_img,
                                                 feed_dict={
                                                     z: sample_z,
                                                     y: sample_label()
                                                 })
                        save_images(
                            sample_images, [8, 8],
                            './{}/train_{:02d}_{:04d}.png'.format(
                                sample_path, e, step))
                        #save_path = saver.save(sess, model_path)
                e = e + 1
                batch_num = 0
            save_path = saver.save(sess, model_path)
            print("Model saved in file: %s" % save_path)

    #test
    elif operation == 1:
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            saver.restore(sess, model_path)
            sample_z = np.random.uniform(1, -1, size=[sample_num, 100])
            output = sess.run(sample_img,
                              feed_dict={
                                  z: sample_z,
                                  y: sample_label()
                              })
            save_images(output, [8, 8],
                        './{}/test{:02d}_{:04d}.png'.format(sample_path, 0, 0))

            image = cv2.imread(
                './{}/test{:02d}_{:04d}.png'.format(sample_path, 0, 0), 0)
            cv2.imshow("test", image)
            cv2.waitKey(-1)
            print('./{}/test{:02d}_{:04d}.png'.format(sample_path, 0, 0))
            print("Test finish!")

    #visualize
    else:
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            saver.restore(sess, model_path)

            # visualize the weights 1 or you can change weights_2 .
            conv_weights = sess.run([tf.get_collection('weight_2')])
            vis_square(visua_path,
                       conv_weights[0][0].transpose(3, 0, 1, 2),
                       type=1)

            # visualize the activation 1
            ac = sess.run([tf.get_collection('ac_2')],
                          feed_dict={
                              images: data_array[:64],
                              z: sample_z,
                              y: sample_label()
                          })
            vis_square(visua_path, ac[0][0].transpose(3, 1, 2, 0), type=0)