Пример #1
0
 def fit(self, sess, local_):
     """Run ``local_`` optimisation steps, then evaluate on a larger batch.

     Each step draws a batch of ``FLAGS.bz`` inputs via ``next_batch_`` and
     runs ``self.optim`` with that batch fed to ``self.x`` (the second value
     returned by ``next_batch_`` is ignored). Afterwards a fresh batch of
     ``FLAGS.bz * 5`` inputs is drawn and evaluated once.

     Args:
         sess: session whose ``run(fetches, feed_dict)`` executes the graph.
         local_: number of optimisation steps to run before evaluating.

     Returns:
         The result of ``sess.run`` for the fetch list
         ``[self.loss, self.rec_loss, self.kld_loss, self.fit_summary]``.
     """
     for _step in range(local_):
         train_batch, _labels = next_batch_(FLAGS.bz)
         sess.run(self.optim, {self.x: train_batch})

     # Evaluation uses a batch five times larger than a training batch.
     eval_batch, _labels = next_batch_(FLAGS.bz * 5)
     eval_fetches = [self.loss, self.rec_loss, self.kld_loss, self.fit_summary]
     return sess.run(eval_fetches, {self.x: eval_batch})
Пример #2
0
 def fit(self, sess, local_):
     """Run ``local_`` optimisation steps, then evaluate on a larger batch.

     Every ``sess.run`` call feeds the input batch to ``self.x`` together
     with a fresh ``gaussian(FLAGS.bz, FLAGS.z_dim)`` draw fed to ``self.z``.

     Args:
         sess: session whose ``run(fetches, feed_dict)`` executes the graph.
         local_: number of optimisation steps to run before evaluating.

     Returns:
         The result of ``sess.run`` for the fetch list
         ``[self.loss, self.loss_nll, self.loss_mmd, self.fit_summary]``
         evaluated on a batch of ``FLAGS.bz * 5`` inputs.
     """
     def _feed(inputs):
         # A new gaussian draw is generated for every feed dictionary.
         return {self.x: inputs, self.z: gaussian(FLAGS.bz, FLAGS.z_dim)}

     for _step in range(local_):
         train_batch, _labels = next_batch_(FLAGS.bz)
         sess.run(self.optim, _feed(train_batch))

     eval_batch, _labels = next_batch_(FLAGS.bz * 5)
     # NOTE(review): the evaluation batch holds FLAGS.bz * 5 inputs, but only
     # FLAGS.bz samples are fed to self.z. If loss_mmd assumes equal sample
     # counts, this should probably draw FLAGS.bz * 5 samples. Confirm against
     # the self.z placeholder shape before changing it.
     return sess.run(
         [self.loss, self.loss_nll, self.loss_mmd, self.fit_summary],
         _feed(eval_batch))
Пример #3
0
    def fit(self, sess, local_):
        """Run ``local_`` label-conditioned training rounds, then evaluate.

        Each round draws a batch of ``FLAGS.bz`` inputs and labels, one-hot
        encodes the labels with ``one_hot_(labels, FLAGS.y_dim)``, and then:

        1. runs ``self.a_optim`` once on the inputs;
        2. runs ``self.d_optim`` three times, each time feeding a fresh
           ``z_real_(labels)`` value to ``self.real_z``;
        3. runs ``self.g_optim`` once with inputs and one-hot labels.

        Args:
            sess: session whose ``run(fetches, feed_dict)`` executes the graph.
            local_: number of training rounds to run before evaluating.

        Returns:
            The result of ``sess.run`` for the fetch list
            ``[self.a_loss, self.g_loss, self.d_loss, self.fit_summary]``
            evaluated on a batch of ``FLAGS.bz * 5`` inputs.
        """
        for _round in range(local_):
            train_x, train_labels = next_batch_(FLAGS.bz)
            train_y = one_hot_(train_labels, FLAGS.y_dim)

            sess.run(self.a_optim, {self.x: train_x})

            # z_real_ is re-invoked on every discriminator step, so each step
            # gets its own value (presumably a fresh prior sample; confirm).
            for _d_step in range(3):
                d_feed = {
                    self.x: train_x,
                    self.y: train_y,
                    self.real_z: z_real_(train_labels),
                }
                sess.run(self.d_optim, d_feed)

            sess.run(self.g_optim, {self.x: train_x, self.y: train_y})

        eval_x, eval_labels = next_batch_(FLAGS.bz * 5)
        eval_y = one_hot_(eval_labels, FLAGS.y_dim)
        eval_fetches = [self.a_loss, self.g_loss, self.d_loss, self.fit_summary]
        eval_feed = {
            self.x: eval_x,
            self.real_z: z_real_(eval_labels),
            self.y: eval_y,
        }
        return sess.run(eval_fetches, eval_feed)
Пример #4
0
    def fit(self, sess, local_):
        """Run ``local_`` adversarial training rounds, then evaluate.

        Each round draws a batch of ``FLAGS.bz`` inputs (the second value
        returned by ``next_batch_`` is ignored) and then:

        1. runs ``self.a_optim`` once on the inputs;
        2. runs ``self.d_optim`` three times, each time feeding a fresh
           ``z_real_(FLAGS.bz)`` value to ``self.real_z``;
        3. runs ``self.g_optim`` once on the inputs.

        Args:
            sess: session whose ``run(fetches, feed_dict)`` executes the graph.
            local_: number of training rounds to run before evaluating.

        Returns:
            The result of ``sess.run`` for the fetch list
            ``[self.a_loss, self.g_loss, self.d_loss, self.fit_summary]``
            evaluated on a batch of ``FLAGS.bz * 5`` inputs, with
            ``z_real_(FLAGS.bz * 5)`` fed to ``self.real_z``.
        """
        for _round in range(local_):
            train_x, _labels = next_batch_(FLAGS.bz)

            sess.run(self.a_optim, {self.x: train_x})

            # z_real_ is called with the same count as the input batch, and
            # is re-invoked on every discriminator step.
            for _d_step in range(3):
                d_feed = {self.x: train_x, self.real_z: z_real_(FLAGS.bz)}
                sess.run(self.d_optim, d_feed)

            sess.run(self.g_optim, {self.x: train_x})

        eval_size = FLAGS.bz * 5
        eval_x, _labels = next_batch_(eval_size)
        eval_fetches = [self.a_loss, self.g_loss, self.d_loss, self.fit_summary]
        eval_feed = {self.x: eval_x, self.real_z: z_real_(eval_size)}
        return sess.run(eval_fetches, eval_feed)
Пример #5
0
 def latent_z(self, sess, bz):
     """Draw a batch of ``bz`` inputs and evaluate ``self.fake_z`` on it.

     Args:
         sess: session whose ``run(fetches, feed_dict)`` executes the graph.
         bz: batch size passed to ``next_batch_``.

     Returns:
         A 2-tuple ``(codes, labels)``. ``codes`` is the result of running
         ``self.fake_z`` with the batch fed to ``self.x``. ``labels`` is the
         second value returned by ``next_batch_``, unchanged.
     """
     inputs, labels = next_batch_(bz)
     codes = sess.run(self.fake_z, {self.x: inputs})
     return codes, labels