예제 #1
0
파일: dcgan.py 프로젝트: Orrimp/IDNNs
class DCGAN:
    """Deep convolutional GAN trained on MNIST digits (TensorFlow 1.x graph API).

    The constructor builds the generator/discriminator graph and both losses;
    ``train`` creates the optimizers and an interactive session and runs the
    training loop; ``shutdown`` closes that session.
    """

    # Spatial size of the images the networks operate on (MNIST is resized to it).
    _NET_SIZE = 64
    # Length of the latent noise vector fed to the generator.
    _Z_DIM = 100
    # Number of generated samples rendered after every epoch.
    _NUM_SAMPLES = 25

    def __init__(self, args):
        """Build the DCGAN graph.

        :param args: object exposing ``batch_size``, ``learning_rate`` and
            ``train_epoch`` attributes (e.g. an ``argparse.Namespace``).
        """
        np.random.seed(int(time.time()))
        self.store = Store()
        self.visual = Visual(self.store)
        self.image_shape = [28, 28, 1]  # 28x28 pixels and black white
        self.batch_size = args.batch_size
        self.lr = args.learning_rate
        self.train_epoch = args.train_epoch
        # NOTE(review): not wired into any layer built in this class; it is only
        # fed by ``visualize``. Kept so external references keep working.
        self.dropout_keep_probability = tf.placeholder("float")

        # reshape=False keeps the images un-flattened; the former ``reshape=[]``
        # had the same effect only because an empty list is falsy.
        self.mnist = input_data.read_data_sets("MNIST_data/",
                                               one_hot=True,
                                               reshape=False)
        self.is_training = tf.placeholder(dtype=tf.bool)

        self.x = tf.placeholder(tf.float32,
                                shape=(None, self._NET_SIZE, self._NET_SIZE, 1),
                                name="X_Input")
        self.z = tf.placeholder(tf.float32,
                                shape=(None, 1, 1, self._Z_DIM),
                                name="Z")

        self.G_z = define_generator(self.z, self.is_training)

        # The second call passes reuse=True, presumably so real and generated
        # images are scored by the same discriminator variables.
        D_real, D_real_logits = define_discriminator(self.x, self.is_training)
        D_fake, D_fake_logits = define_discriminator(self.G_z,
                                                     self.is_training,
                                                     reuse=True)

        # Discriminator targets: ones for real, zeros for fake; the generator
        # is scored against ones on its fakes.
        D_loss_real = init_loss(D_real_logits, tf.ones, self.batch_size)
        D_loss_fake = init_loss(D_fake_logits, tf.zeros, self.batch_size)
        self.G_loss = init_loss(D_fake_logits, tf.ones, self.batch_size)
        self.D_loss = D_loss_real + D_loss_fake

        self.sess = None

    def _build_optimizers(self):
        """Create and return the (discriminator, generator) Adam update ops."""
        trainable = tf.trainable_variables()
        D_vars = [v for v in trainable if v.name.startswith('discriminator')]
        G_vars = [v for v in trainable if v.name.startswith('generator')]

        # Make every step also run the ops registered in UPDATE_OPS
        # (typically batch-norm moving statistics).
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
                self.D_loss, var_list=D_vars)
            G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
                self.G_loss, var_list=G_vars)
        return D_optim, G_optim

    def _load_train_set(self):
        """Return the MNIST training images resized to the network size, scaled to [-1, 1]."""
        train_set = tf.image.resize_images(
            self.mnist.train.images, [self._NET_SIZE, self._NET_SIZE]).eval()
        return (train_set - 0.5) / 0.5  # normalization; range: -1 ~ 1

    @staticmethod
    def _fixed_result_path(epoch_number):
        """Return the image path for the samples of 1-based ``epoch_number``."""
        # NOTE(review): ``root`` and ``model`` are path prefixes not defined in
        # this class; presumably module-level globals elsewhere in the file.
        return root + 'Fixed_results/' + model + str(epoch_number) + '.png'

    def _save_animation(self):
        """Combine the per-epoch sample images into an animated GIF."""
        images = [imageio.imread(self._fixed_result_path(e + 1))
                  for e in range(self.train_epoch)]
        imageio.mimsave(root + model + 'generation_animation.gif',
                        images,
                        fps=5)

    def train(self):
        """Train the GAN for ``self.train_epoch`` epochs.

        After every epoch a grid of generated samples is saved; afterwards the
        loss-history plot and an animated GIF of those samples are written.
        """
        D_optim, G_optim = self._build_optimizers()

        # Do not leak the session of a previous ``train`` call.
        if self.sess is not None:
            self.sess.close()
        # start the session with multi GPU support
        configuration = tf.ConfigProto(allow_soft_placement=True,
                                       log_device_placement=True)
        self.sess = tf.InteractiveSession(config=configuration)
        tf.global_variables_initializer().run()

        train_set = self._load_train_set()

        # Sample the latent codes once, so every epoch's images (and the GIF
        # built from them) show the same latent points as training progresses.
        fixed_z_ = np.random.normal(0, 1,
                                    (self._NUM_SAMPLES, 1, 1, self._Z_DIM))
        iterations = self.mnist.train.num_examples // self.batch_size

        print('training start!')
        start_time = time.time()

        for epoch in range(self.train_epoch):
            G_losses = []
            D_losses = []
            epoch_start_time = time.time()
            for step in range(iterations):
                x_ = train_set[step * self.batch_size:(step + 1) *
                               self.batch_size]

                # update discriminator
                z_ = np.random.normal(0, 1,
                                      (self.batch_size, 1, 1, self._Z_DIM))
                loss_d, _ = self.sess.run([self.D_loss, D_optim], {
                    self.x: x_,
                    self.z: z_,
                    self.is_training: True
                })

                # update generator; x is still fed because G_optim depends on
                # every UPDATE_OPS op, which may include real-image-branch ops.
                z_ = np.random.normal(0, 1,
                                      (self.batch_size, 1, 1, self._Z_DIM))
                loss_g, _ = self.sess.run([self.G_loss, G_optim], {
                    self.z: z_,
                    self.x: x_,
                    self.is_training: True
                })

                G_losses.append(loss_g)
                D_losses.append(loss_d)
                print("Epoch %s/%s of iteration %s/%s with loss_g %s and "
                      "loss_d %s" % (epoch, self.train_epoch, step,
                                     iterations, loss_g, loss_d))

            per_epoch_ptime = time.time() - epoch_start_time
            print('[%d/%d] - ptime: %.2f loss_d: %.3f, loss_g: %.3f' %
                  ((epoch + 1), self.train_epoch, per_epoch_ptime,
                   np.mean(D_losses), np.mean(G_losses)))

            test_images = self.sess.run(self.G_z, {
                self.z: fixed_z_,
                self.is_training: False
            })
            self.visual.show_result(test_images,
                                    num_epoch=epoch,
                                    show=False,
                                    save=True,
                                    path=self._fixed_result_path(epoch + 1))

            self.store.hist_append('D_losses', np.mean(D_losses))
            self.store.hist_append('G_losses', np.mean(G_losses))
            self.store.hist_append('per_epoch_ptimes', per_epoch_ptime)

        # let it run and save the images
        total_ptime = time.time() - start_time
        self.store.hist_append('total_ptime', total_ptime)

        print('Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f' %
              (np.mean(self.store.retrieve('per_epoch_ptimes')),
               self.train_epoch, total_ptime))
        self.visual.show_train_hist(self.store.retrieve('D_losses'),
                                    self.store.retrieve('G_losses'),
                                    show=False,
                                    save=True,
                                    path=root + model + 'train_hist.png')

        self._save_animation()

    def visualize(self, layer, image_stimulation):
        '''Run one image through the model and plot the activations of ``layer``.

        :param layer: tensor (layer output) whose activations are plotted
        :param image_stimulation: a single image with 28 * 28 pixel values,
            assumed to be in [0, 1] like the MNIST training images (TODO confirm)
        :return: None; the activations are passed to ``Visual.plotNNFilter``
        :raises RuntimeError: if no session is open (``train`` not run yet)
        '''
        #https://medium.com/@awjuliani/visualizing-neural-network-layer-activation-tensorflow-tutorial-d45f8bf7bbc4
        if self.sess is None:
            raise RuntimeError('visualize() needs an open session; '
                               'call train() first')
        # self.x is (None, 64, 64, 1): a flat [1, 784] vector is rejected by
        # TensorFlow, so apply the same resize and [-1, 1] scaling as training.
        image = np.reshape(np.asarray(image_stimulation, dtype=np.float32),
                           [1] + self.image_shape)
        image = tf.image.resize_images(
            image, [self._NET_SIZE, self._NET_SIZE]).eval(session=self.sess)
        image = (image - 0.5) / 0.5
        units = self.sess.run(layer,
                              feed_dict={
                                  self.x: image,
                                  self.dropout_keep_probability: 1.0,
                                  self.is_training: False
                              })
        self.visual.plotNNFilter(units)

    def shutdown(self):
        """Close the TensorFlow session if one is open; safe to call repeatedly."""
        if self.sess is not None:
            self.sess.close()
            self.sess = None