예제 #1
0
    def test(self):
        """Restore the model, then log test-set ELBO, NLL and an MI estimate.

        Steps:
          1. Initialise the training iterator and restore variables from
             ``self.logdir`` through ``self.saver``.
          2. Draw every ``self.z_mi`` batch until the iterator is exhausted.
             Concatenate them along axis 0 and fit a ``gaussian_kde`` on the
             transposed array.
          3. Initialise the test iterator. For each batch, fetch the ELBO, the
             NLL, ``self.z[0]`` and ``self.log_q_z_x``, and record
             ``mean(log_q_z_x - kde.logpdf(z.T))`` as that batch's
             mutual-information estimate.
          4. Log the means over batches: ``nll + elbo`` (labelled "all"), then
             nll, elbo and mi.
        """
        def drain(fetches):
            # Yield one session result per step until the active dataset
            # iterator reports exhaustion via OutOfRangeError.
            while True:
                try:
                    yield self.sess.run(fetches)
                except tf.errors.OutOfRangeError:
                    return

        # Fit the density model on latent samples from the training split.
        # NOTE(review): save_path=self.logdir is passed straight to restore();
        # confirm it is a checkpoint prefix rather than a bare directory.
        self.sess.run(self.train_init)
        self.saver.restore(sess=self.sess, save_path=self.logdir)
        train_latents = np.concatenate(list(drain(self.z_mi)), axis=0)
        # gaussian_kde expects (dims, samples); the transpose suggests z_mi
        # batches are (batch, dims) -- assumption, confirm against graph.
        kde = gaussian_kde(train_latents.transpose())

        # Evaluate on the test split, collecting one value per batch.
        self.sess.run(self.test_init)
        elbo_per_batch, nll_per_batch, mi_per_batch = [], [], []
        test_fetches = [self.elbo, self.nll, self.z[0], self.log_q_z_x]
        for batch_elbo, batch_nll, batch_z, batch_log_q in drain(test_fetches):
            mi_per_batch.append(np.mean(batch_log_q - kde.logpdf(batch_z.transpose())))
            elbo_per_batch.append(batch_elbo)
            nll_per_batch.append(batch_nll)

        # Keep the numpy means un-cast until after summing, matching the
        # original arithmetic exactly.
        nll_mean = np.mean(nll_per_batch)
        elbo_mean = np.mean(elbo_per_batch)
        summary_values = (
            float(nll_mean + elbo_mean),
            float(nll_mean),
            float(elbo_mean),
            float(np.mean(mi_per_batch)),
        )
        logger.log('Test: all %.4f nll %.4f elbo %.4f mi %.4f' % summary_values)
 def _log(self, it):
     """Every 10 iterations, log loss/NLL/ELBO and write the training summary.

     Args:
         it: Current iteration number. It sets the logging cadence and is
             used as the summary step.
     """
     if it % 10 != 0:
         return
     # Fetch the scalars and the serialized summary in ONE session run.
     # Previously the summary came from a second sess.run. If the graph reads
     # from a dataset iterator or samples noise, that second run described a
     # different execution than the logged numbers, and it doubled the cost.
     loss, nll, elbo, summary = self.sess.run(
         [self.loss, self.nll, self.elbo, self.train_summary])
     logger.log("Iteration %d: loss %.4f nll %.4f elbo %.4f" %
                (it, loss, nll, elbo))
     self.summary_writer.add_summary(summary, it)
예제 #3
0
 def _log(self, it):
     """Every 1000 iterations, fetch diagnostics, log them and write a summary.

     A single session run fetches the loss terms and diagnostics together
     with ``self.trainer``. The trainer's result is discarded. ``logpz`` and
     ``logqzx`` are averaged with ``np.mean`` before printing. The summary is
     then produced by a separate session run and written at step ``it``.

     NOTE(review): because ``self.trainer`` is in the fetch list, a training
     step executes whenever this logs -- confirm that is intended.
     NOTE(review): the value printed under the "logqu" label is the
     NEGATION of ``self.logqu`` -- confirm the sign/label.

     Args:
         it: Current iteration number.
     """
     if it % 1000 != 0:
         return
     fetch_list = [
         self.loss, self.lld, self.vae, self.mi_z_u, self.logqu, self.um,
         self.elbo, self.l1, self.l2, self.l3, self.l4, self.l5,
         self.logpz, self.logqzx, self.mi_z_u0, self.mi_z_u1, self.trainer,
     ]
     results = self.sess.run(fetch_list)
     (loss, lld, vae, mi_z_u, logqu_val, um, elbo,
      l1, l2, l3, l4, l5, logpz, logqzx, mi_z_u0, mi_z_u1,
      _train_op_result) = results
     fmt_args = (
         it, loss, lld, vae, mi_z_u, -logqu_val, um, elbo,
         l1, l2, l3, l4, l5,
         np.mean(logpz), np.mean(logqzx), mi_z_u0, mi_z_u1,
     )
     logger.log('It %d: loss %.4f lld %.4f vae %.4f mi_z_u %.4f logqu %.4f um %.4f elbo %.4f l1 %.2f l2 %.2f l3 %.2f l4 %.2f l5 %.2f logpz %.2f logqzx %.2f mizu0 %.2f mizu1 %.2f' % fmt_args)
     serialized_summary = self.sess.run(self.train_summary)
     self.summary_writer.add_summary(serialized_summary, it)
예제 #4
0
 def _log(self, it):
     """Every 1000 iterations, log loss, NLL, ELBO and error.

     Args:
         it: Current iteration number. Only multiples of 1000 are logged.
     """
     if it % 1000 != 0:
         return
     loss_val, nll_val, elbo_val, err_val = self.sess.run(
         [self.loss, self.nll, self.elbo, self.error])
     message = "Iteration %d: loss %.4f nll %.4f elbo %.4f error %.4f" % (
         it, loss_val, nll_val, elbo_val, err_val)
     logger.log(message)