Ejemplo n.º 1
0
def main(_):
    """Run the MNIST vanilla-VAE demo script.

    Builds the 'vanilla_00' model. When FLAGS.train is set, the model is
    trained on the training split returned by load_mnist; otherwise 16
    samples are generated, reported on the console and shown in a grid.

    Args:
        _: Ignored positional argument (presumably argv handed over by the
            app launcher -- confirm).
    """
    console.suppress_logging()
    # Hard-coded overrides: whatever these flags held before main() ran is
    # replaced here. Set FLAGS.train to False to sample instead of train.
    FLAGS.train = True
    FLAGS.overwrite = True

    console.start('MNIST VANILLA VAE')
    model = models.vanilla('vanilla_00')

    if not FLAGS.train:
        samples = model.generate(sample_num=16)
        console.show_status('{} samples generated'.format(samples.shape[0]))
        imtool.gan_grid_plot(samples, show=True)
    else:
        data = load_mnist('../../data/MNIST', flatten=True,
                          validation_size=0, one_hot=True)
        train_config = dict(epoch=1000, batch_size=128,
                            print_cycle=50, snapshot_cycle=200)
        model.train(training_set=data[pedia.training], **train_config)

    console.end()
Ejemplo n.º 2
0
def main(_):
  """Run the MNIST DCGAN demo script.

  Builds the 'dcgan_002' model. When FLAGS.train is set, the model is
  trained on the (non-flattened) MNIST training split; otherwise 16 samples
  are generated, reported on the console and shown in a grid.

  Args:
    _: Ignored positional argument (presumably argv from the app launcher --
      confirm).
  """
  console.suppress_logging()
  # Hard-coded overrides of whatever the flags held on entry; flip
  # FLAGS.train to False to sample from the model instead of training it.
  FLAGS.train = True
  FLAGS.overwrite = False

  console.start("MNIST DCGAN DEMO")

  model = models.dcgan('dcgan_002')
  # Alternative model: models.dcgan_h3_rs_nbn()

  if not FLAGS.train:
    samples = model.generate(sample_num=16)
    console.show_status('{} samples generated'.format(samples.shape[0]))
    imtool.gan_grid_plot(samples, show=True)
  else:
    data_dir = '../../data/MNIST'
    mnist = load_mnist(data_dir, flatten=False, validation_size=0,
                       one_hot=True)
    train_set = mnist[pedia.training]
    # D_times / G_times: discriminator and generator update counts per step
    # (presumably -- confirm against GAN.train).
    model.train(training_set=train_set,
                epoch=10,
                batch_size=128,
                print_cycle=20,
                snapshot_cycle=200,
                D_times=1,
                G_times=1)

  console.end()
Ejemplo n.º 3
0
def main(_):
    """Run the MNIST vanilla GAN demo (no batch normalization).

    Builds 'vanilla_nov9_02_h2_c' with bn=False. When FLAGS.train is set,
    the model is trained on the flattened MNIST training split; otherwise
    16 samples are generated, reported and shown in a grid.

    Args:
        _: Ignored positional argument (presumably argv -- confirm).
    """
    console.suppress_logging()
    # Overrides applied unconditionally; set FLAGS.train to False to sample.
    FLAGS.overwrite = True
    FLAGS.train = True

    console.start()

    model = models.vanilla('vanilla_nov9_02_h2_c', bn=False)
    # Other models tried here:
    #   models.dcgan('dcgan_c00')
    #   models.vanilla_h3_rs_nbn('vanilla_nov9_01_h3_nbn_opdef')

    if not FLAGS.train:
        samples = model.generate(sample_num=16)
        console.show_status('{} samples generated'.format(samples.shape[0]))
        imtool.gan_grid_plot(samples, show=True)
    else:
        dataset = load_mnist('../../data/MNIST', flatten=True,
                             validation_size=0, one_hot=True)
        train_config = dict(epoch=1000, batch_size=128, print_cycle=20,
                            snapshot_cycle=150, sample_num=25)
        model.train(training_set=dataset[pedia.training], **train_config)

    console.end()
Ejemplo n.º 4
0
def main(_):
    """Run the CIFAR-10 DCGAN demo script.

    Builds the 'dcgan_00' model. When FLAGS.train is set, the model is
    trained on the CIFAR-10 training split returned by load_cifar10;
    otherwise 16 samples are generated, reported and shown in a grid.

    Args:
        _: Ignored positional argument (presumably argv -- confirm).
    """
    console.suppress_logging()
    # Hard-coded overrides; set FLAGS.train to False to sample instead.
    FLAGS.train = True
    FLAGS.overwrite = False

    console.start('CIFAR-10 DCGAN')
    model = models.dcgan('dcgan_00')

    if not FLAGS.train:
        samples = model.generate(sample_num=16)
        console.show_status('{} samples generated'.format(samples.shape[0]))
        imtool.gan_grid_plot(samples, show=True)
    else:
        data = load_cifar10('../../data/CIFAR-10', validation_size=0,
                            one_hot=True)
        train_config = dict(epoch=20000, batch_size=128,
                            print_cycle=20, snapshot_cycle=2000)
        model.train(training_set=data[pedia.training], **train_config)

    console.end()
Ejemplo n.º 5
0
  def interpolate(self, z1=None, z2=None, inter_num=8, via='spherical'):
    """Decode a path between two latent codes and plot it as one grid row.

    Args:
      z1: Start code; drawn with self._random_z(1) when None.
      z2: End code; drawn with self._random_z(1) when None.
      inter_num: Number of intermediate points strictly between z1 and z2.
        The path therefore has inter_num + 2 codes in total.
      via: 'great_circle' / 'circle' / 'spherical' use
        interpolations.slerp; 'straight_line' / 'line' / 'linear' use
        z1 + mu * (z2 - z1).

    Returns:
      The result of imtool.gan_grid_plot(samples, h=1).

    Raises:
      ValueError: if `via` is not one of the names above.
    """
    if z1 is None:
      z1 = self._random_z(1)
    if z2 is None:
      z2 = self._random_z(1)

    # Row 0 stays z1 and the last row is z2; the rows between are
    # overwritten by the loop below.
    zs = np.stack([z1] * (inter_num + 2))
    zs[-1] = z2

    if via in ('great_circle', 'circle', 'spherical'):
      def interp(mu):
        return interpolations.slerp(mu, z1, z2)
    elif via in ('straight_line', 'line', 'linear'):
      def interp(mu):
        return z1 + mu * (z2 - z1)
    else:
      raise ValueError("Can not resolve '{}'".format(via))

    # Interior points sit at mu = k / (inter_num + 1) for k = 1..inter_num.
    denom = float(inter_num + 1)
    for k in range(1, inter_num + 1):
      zs[k] = interp(k / denom)

    samples = self.generate(zs)
    return imtool.gan_grid_plot(samples, h=1)
Ejemplo n.º 6
0
    def _default_snapshot_function(self):
        """Decode self._sample_num random latent codes and plot them.

        The codes are fed to self.P.default_input_tensor, together with the
        non-training status feeds, and self._outputs is evaluated.

        Returns:
            The result of imtool.gan_grid_plot on the decoded samples.
        """
        assert isinstance(self, VAE)

        latent = self._random_z(self._sample_num)
        feeds = {self.P.default_input_tensor: latent}
        feeds.update(self._get_status_feed_dict(is_training=False))

        return imtool.gan_grid_plot(self._outputs.eval(feeds))
Ejemplo n.º 7
0
  def _default_snapshot_function(self):
    """Generate self._sample_num samples from the generator and plot them.

    For a conditional GAN, _random_z also returns one-hot labels, which are
    fed to self._targets. The latent codes go to
    self.G.default_input_tensor, and the non-training status feeds are
    added before evaluating self._outputs.

    Returns:
      The result of imtool.gan_grid_plot on the generated samples.
    """
    assert isinstance(self, GAN)

    feeds = {}
    if not self._conditional:
      latent = self._random_z(self._sample_num)
    else:
      latent, labels = self._random_z(self._sample_num, True)
      feeds[self._targets] = labels
    feeds[self.G.default_input_tensor] = latent
    feeds.update(self._get_status_feed_dict(is_training=False))

    generated = self._outputs.eval(feeds)
    return imtool.gan_grid_plot(generated)