Example #1
    def train(self):

        tf.global_variables_initializer().run()

        # graph inputs for visualizing training results
        self.sample_z = prior.gaussian(self.batch_size, self.z_dim)

        start_epoch, start_batch_id, counter = self.before_train()

        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) *
                                           self.batch_size]
                batch_z = prior.gaussian(self.batch_size, self.z_dim)

                # update autoencoder
                _, summary_str, elbo_loss, nll_loss, kl_loss = self.sess.run(
                    [
                        self.optim, self.merged_summary_op, self.loss,
                        self.neg_loglikelihood, self.KL_divergence
                    ],
                    feed_dict={
                        self.inputs: batch_images,
                        self.z: batch_z
                    })

                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, elbo_loss: %.8f, nll_loss: %.8f, kl_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, elbo_loss, nll_loss, kl_loss))

                # save training results every sample_point steps
                if np.mod(counter, self.sample_point) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                image_path=osp.join(
                                    check_folder(
                                        osp.join(check_folder(self.result_dir),
                                                 self.model_dir)),
                                    self.model_name +
                                    '_train{}_{}.png'.format(epoch, idx)))

            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show intermediate results
            self.visualize_results(epoch)
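
Note: the examples on this page call a few helpers that are not shown here (check_folder, prior.gaussian, save_images, save_scattered_image). As a rough, hypothetical sketch rather than the project's actual code, check_folder is assumed to create a directory if needed and return its path, and prior.gaussian is assumed to draw a batch of latent vectors from a normal distribution:

import os
import numpy as np

def check_folder(path):
    # Create the directory if it does not exist, then return the path,
    # so the nested check_folder(osp.join(...)) calls above work as written.
    if not os.path.exists(path):
        os.makedirs(path)
    return path

def gaussian(batch_size, z_dim, mean=0.0, stddev=1.0):
    # Stand-in for prior.gaussian: a (batch_size, z_dim) batch of N(mean, stddev) noise.
    return np.random.normal(mean, stddev, size=(batch_size, z_dim)).astype(np.float32)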
Example #2
File: base.py Project: SSUHan/All-GANs
    def save(self, checkpoint_dir, step):
        checkpoint_dir = osp.join(
            check_folder(osp.join(check_folder(checkpoint_dir),
                                  self.model_dir)), self.model_name)
        check_folder(checkpoint_dir)
        self.saver.save(self.sess,
                        osp.join(checkpoint_dir, self.model_name + ".model"),
                        global_step=step)
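
The save method above pairs with a restore step, typically run inside before_train. The snippet below is only a sketch of how such a load method could look using the standard tf.train checkpoint API; the project's actual implementation may differ.

    def load(self, checkpoint_dir):
        # Mirror the directory layout used by save().
        checkpoint_dir = osp.join(checkpoint_dir, self.model_dir, self.model_name)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = osp.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, osp.join(checkpoint_dir, ckpt_name))
            # saver.save(..., global_step=step) appends "-<step>" to the file name.
            counter = int(ckpt_name.split('-')[-1])
            return True, counter
        return False, 0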
Example #3
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
        """ random condition, random noise """

        z_sample = prior.gaussian(self.batch_size, self.z_dim)

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                    [image_frame_dim, image_frame_dim],
                    image_path=osp.join(
                        check_folder(
                            osp.join(check_folder(self.result_dir),
                                     self.model_dir)), self.model_name +
                        '_epoch{}_test_all_classes.png'.format(epoch)))
        """ learned manifold """
        if self.z_dim == 2:

            z_tot = None
            id_tot = None
            for idx in range(0, 100):
                # randomly sampling
                _id = np.random.randint(0, self.num_batches)
                batch_images = self.data_X[_id * self.batch_size:(_id + 1) *
                                           self.batch_size]
                batch_labels = self.data_y[_id * self.batch_size:(_id + 1) *
                                           self.batch_size]

                z = self.sess.run(self.mu,
                                  feed_dict={self.inputs: batch_images})

                if idx == 0:
                    z_tot = z
                    id_tot = batch_labels
                else:
                    z_tot = np.concatenate((z_tot, z), axis=0)
                    id_tot = np.concatenate((id_tot, batch_labels), axis=0)

            save_scattered_image(
                z_tot,
                id_tot,
                -4,
                4,
                name=osp.join(
                    check_folder(
                        osp.join(check_folder(self.result_dir),
                                 self.model_dir)), self.model_name +
                    '_epoch{}_learned_manifold.png'.format(epoch)))
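
save_scattered_image is assumed to scatter the 2-D latent codes colored by class label over the given axis range. A minimal matplotlib stand-in, assuming one-hot labels in id_tot, might look like this (the project's version may style the plot differently):

import numpy as np
import matplotlib
matplotlib.use('Agg')  # render to file without a display
import matplotlib.pyplot as plt

def save_scattered_image(z, ids, z_min, z_max, name='scattered_image.png'):
    # z: (N, 2) latent codes; ids: (N, num_classes) one-hot labels.
    labels = np.argmax(ids, axis=1)
    plt.figure(figsize=(8, 6))
    plt.scatter(z[:, 0], z[:, 1], c=labels, cmap='jet', s=8, edgecolors='none')
    plt.colorbar()
    plt.xlim(z_min, z_max)
    plt.ylim(z_min, z_max)
    plt.grid(True)
    plt.savefig(name)
    plt.close()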
Example #4
File: AE.py Project: SSUHan/All-GANs
    def do_sample_point(self, epoch, idx):
        samples = self.sess.run(self.decode_images,
                                feed_dict={self.inputs: self.test_images})
        tot_num_samples = min(self.sample_num, self.batch_size)
        manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
        manifold_w = int(np.floor(np.sqrt(tot_num_samples)))

        save_images(self.test_images[:manifold_h * manifold_w, :, :, :],
                    [manifold_h, manifold_w],
                    image_path=osp.join(
                        check_folder(
                            osp.join(check_folder(self.result_dir),
                                     self.model_dir)),
                        self.model_name + '_origin.png'))

        save_images(samples[:manifold_h * manifold_w, :, :, :],
                    [manifold_h, manifold_w],
                    image_path=osp.join(
                        check_folder(
                            osp.join(check_folder(self.result_dir),
                                     self.model_dir)), self.model_name +
                        '_train{}_{}.png'.format(epoch, idx)))
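
save_images is assumed to tile the first manifold_h * manifold_w samples into a single grid and write it to image_path. A rough sketch, under the assumption that the images arrive as an (N, H, W, C) array with values in [0, 1]:

import numpy as np
import imageio

def save_images(images, size, image_path):
    # images: (N, H, W, C) in [0, 1]; size: [rows, cols] of the output grid.
    n, h, w, c = images.shape
    rows, cols = size
    grid = np.zeros((rows * h, cols * w, c), dtype=images.dtype)
    for idx, img in enumerate(images[:rows * cols]):
        r, col = divmod(idx, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w, :] = img
    if c == 1:
        grid = grid[:, :, 0]  # write grayscale as a 2-D image
    imageio.imwrite(image_path, (grid * 255).astype(np.uint8))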
Example #5
File: ACGAN.py Project: SSUHan/All-GANs
    def train(self):
        tf.global_variables_initializer().run()

        # graph inputs for visualizing training results
        self.sample_z = np.random.uniform(-1,
                                          1,
                                          size=(self.batch_size, self.z_dim))
        self.test_codes = self.data_y[0:self.batch_size]

        start_epoch, start_batch_id, counter = self.before_train()

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) *
                                           self.batch_size]
                batch_codes = self.data_y[idx * self.batch_size:(idx + 1) *
                                          self.batch_size]
                batch_z = np.random.uniform(
                    -1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, smy_str_d, d_loss = self.sess.run(
                    [self.d_optim, self.d_smy, self.d_loss],
                    feed_dict={
                        self.inputs: batch_images,
                        self.y: batch_codes,
                        self.z: batch_z
                    })
                self.writer.add_summary(smy_str_d, counter)

                # update G & C network
                _, smy_str_g, g_loss, _, smy_str_c, c_loss = self.sess.run(
                    [
                        self.g_optim, self.g_smy, self.g_loss, self.c_optim,
                        self.c_smy, self.c_loss
                    ],
                    feed_dict={
                        self.inputs: batch_images,
                        self.y: batch_codes,
                        self.z: batch_z
                    })
                self.writer.add_summary(smy_str_g, counter)
                self.writer.add_summary(smy_str_c, counter)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))

                # save training results every sample_point steps
                if np.mod(counter, self.sample_point) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={
                                                self.z: self.sample_z,
                                                self.y: self.test_codes
                                            })
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                image_path=osp.join(
                                    check_folder(
                                        osp.join(check_folder(self.result_dir),
                                                 self.model_dir)),
                                    self.model_name +
                                    '_train{}_{}.png'.format(epoch, idx)))

            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)
Example #6
    def train(self):
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualizing training results
        self.sample_z = np.random.uniform(-1,
                                          1,
                                          size=(self.batch_size, self.z_dim))
        self.test_labels = self.data_y[:self.batch_size]

        start_epoch, start_batch_id, counter = self.before_train()

        # loop for epoch
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):

            # get batch data
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx * self.batch_size:(idx + 1) *
                                           self.batch_size]
                batch_labels = self.data_y[idx * self.batch_size:(idx + 1) *
                                           self.batch_size]
                batch_z = np.random.uniform(
                    -1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                # update D network
                _, summary_str, d_loss = self.sess.run(
                    [self.d_optim, self.d_smy, self.d_loss],
                    feed_dict={
                        self.inputs: batch_images,
                        self.labels: batch_labels,
                        self.noises: batch_z
                    })

                self.writer.add_summary(summary_str, counter)

                # update G network
                _, summary_str, g_loss = self.sess.run(
                    [self.g_optim, self.g_smy, self.g_loss],
                    feed_dict={
                        self.labels: batch_labels,
                        self.noises: batch_z
                    })
                self.writer.add_summary(summary_str, counter)

                # display training status
                counter += 1
                print("Epoch: [{}] [{}/{}] time: {}, d_loss: {}, g_loss: {}".
                      format(epoch, idx, self.num_batches,
                             time.time() - start_time, d_loss, g_loss))

                # save training results every sample_point steps
                if np.mod(counter, self.sample_point) == 0:
                    samples = self.sess.run(self.fake_images,
                                            feed_dict={
                                                self.noises: self.sample_z,
                                                self.labels: self.test_labels
                                            })
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                image_path=osp.join(
                                    check_folder(
                                        osp.join(check_folder(self.result_dir),
                                                 self.model_dir)),
                                    self.model_name +
                                    '_train{}_{}.png'.format(epoch, idx)))

            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # TODO: show intermediate results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)
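
Each train loop above starts with start_epoch, start_batch_id, counter = self.before_train(). That method is not shown on this page; as a rough sketch, it presumably tries to resume from the latest checkpoint (for example via a load method like the hypothetical one sketched after Example #2) and otherwise starts from scratch. The details are an assumption, not the project's confirmed code.

    def before_train(self):
        # Try to resume from the latest checkpoint; otherwise start fresh.
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = checkpoint_counter // self.num_batches
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed, starting from scratch")
        return start_epoch, start_batch_id, counter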