def main(_):
    """Train the 'vanilla_00' VAE on MNIST, or plot samples drawn from it.

    The single positional argument is not used.
    """
    console.suppress_logging()
    # Run configuration is forced here, overriding whatever FLAGS held before.
    FLAGS.train = True
    FLAGS.overwrite = True

    console.start('MNIST VANILLA VAE')
    model = models.vanilla('vanilla_00')

    if not FLAGS.train:
        # Sampling mode. As written, FLAGS.train is forced to True above, so
        # this branch only runs if that assignment is changed.
        samples = model.generate(sample_num=16)
        console.show_status('{} samples generated'.format(samples.shape[0]))
        imtool.gan_grid_plot(samples, show=True)
    else:
        # Images are loaded flattened, with no validation split and one-hot labels.
        mnist = load_mnist('../../data/MNIST', flatten=True,
                           validation_size=0, one_hot=True)
        hyper_params = dict(epoch=1000, batch_size=128,
                            print_cycle=50, snapshot_cycle=200)
        model.train(training_set=mnist[pedia.training], **hyper_params)

    console.end()
def main(_):
    """Train the 'dcgan_002' DCGAN on MNIST, or plot samples drawn from it.

    The single positional argument is not used.
    """
    console.suppress_logging()
    # Run configuration is forced here, overriding whatever FLAGS held before.
    FLAGS.train = True
    FLAGS.overwrite = False

    console.start("MNIST DCGAN DEMO")
    model = models.dcgan('dcgan_002')

    if not FLAGS.train:
        # Sampling mode (unreachable as written, since FLAGS.train is forced on).
        samples = model.generate(sample_num=16)
        console.show_status('{} samples generated'.format(samples.shape[0]))
        imtool.gan_grid_plot(samples, show=True)
    else:
        # Images are kept unflattened here (flatten=False).
        mnist = load_mnist('../../data/MNIST', flatten=False,
                           validation_size=0, one_hot=True)
        # D_times / G_times are forwarded to model.train unchanged.
        hyper_params = dict(epoch=10, batch_size=128, print_cycle=20,
                            snapshot_cycle=200, D_times=1, G_times=1)
        model.train(training_set=mnist[pedia.training], **hyper_params)

    console.end()
def main(_):
    """Train the 'vanilla_nov9_02_h2_c' model on MNIST, or plot its samples.

    The single positional argument is not used.
    """
    console.suppress_logging()
    # Run configuration is forced here, overriding whatever FLAGS held before.
    FLAGS.overwrite = True
    FLAGS.train = True

    console.start()
    model = models.vanilla('vanilla_nov9_02_h2_c', bn=False)

    if not FLAGS.train:
        # Sampling mode (unreachable as written, since FLAGS.train is forced on).
        samples = model.generate(sample_num=16)
        console.show_status('{} samples generated'.format(samples.shape[0]))
        imtool.gan_grid_plot(samples, show=True)
    else:
        mnist = load_mnist('../../data/MNIST', flatten=True,
                           validation_size=0, one_hot=True)
        # sample_num is forwarded to model.train alongside the usual schedule.
        hyper_params = dict(epoch=1000, batch_size=128, print_cycle=20,
                            snapshot_cycle=150, sample_num=25)
        model.train(training_set=mnist[pedia.training], **hyper_params)

    console.end()
def main(_):
    """Train the 'dcgan_00' DCGAN on CIFAR-10, or plot samples drawn from it.

    The single positional argument is not used.
    """
    console.suppress_logging()
    # Run configuration is forced here, overriding whatever FLAGS held before.
    FLAGS.train = True
    FLAGS.overwrite = False

    console.start('CIFAR-10 DCGAN')
    model = models.dcgan('dcgan_00')

    if not FLAGS.train:
        # Sampling mode (unreachable as written, since FLAGS.train is forced on).
        samples = model.generate(sample_num=16)
        console.show_status('{} samples generated'.format(samples.shape[0]))
        imtool.gan_grid_plot(samples, show=True)
    else:
        cifar10 = load_cifar10('../../data/CIFAR-10', validation_size=0,
                               one_hot=True)
        hyper_params = dict(epoch=20000, batch_size=128,
                            print_cycle=20, snapshot_cycle=2000)
        model.train(training_set=cifar10[pedia.training], **hyper_params)

    console.end()
def interpolate(self, z1=None, z2=None, inter_num=8, via='spherical'):
    """Generate and plot samples along a path between two latent codes.

    Args:
      z1: Start code; drawn with ``self._random_z(1)`` when None.
      z2: End code; drawn with ``self._random_z(1)`` when None.
      inter_num: Number of intermediate codes strictly between z1 and z2.
      via: 'great_circle', 'circle' or 'spherical' to use
        ``interpolations.slerp``; 'straight_line', 'line' or 'linear' for
        straight-line interpolation.

    Returns:
      The figure returned by ``imtool.gan_grid_plot(samples, h=1)``.

    Raises:
      ValueError: If ``via`` matches neither scheme.
    """
    if z1 is None:
        z1 = self._random_z(1)
    if z2 is None:
        z2 = self._random_z(1)

    # Preallocate inter_num + 2 rows, all copies of z1. Row 0 stays z1 and the
    # last row becomes z2. The buffer inherits z1's dtype, and every row
    # written below is cast into it.
    zs = np.stack([z1] * (inter_num + 2))
    zs[-1] = z2

    if via in ('great_circle', 'circle', 'spherical'):
        def interp(mu):
            return interpolations.slerp(mu, z1, z2)
    elif via in ('straight_line', 'line', 'linear'):
        def interp(mu):
            return z1 + mu * (z2 - z1)
    else:
        raise ValueError("Can not resolve '{}'".format(via))

    # Fill the interior rows 1..inter_num at evenly spaced fractions in (0, 1).
    denom = inter_num + 1
    for step in range(1, denom):
        zs[step] = interp(1.0 * step / denom)

    samples = self.generate(zs)
    return imtool.gan_grid_plot(samples, h=1)
def _default_snapshot_function(self):
    """Evaluate ``self._outputs`` on random latent codes and plot the result.

    Draws ``self._sample_num`` codes with ``self._random_z``, feeds them to
    ``self.P.default_input_tensor`` together with the non-training status
    feeds, and evaluates ``self._outputs``.

    Returns:
      The figure returned by ``imtool.gan_grid_plot``.
    """
    assert isinstance(self, VAE)
    z = self._random_z(self._sample_num)
    feed_dict = {self.P.default_input_tensor: z}
    feed_dict.update(self._get_status_feed_dict(is_training=False))
    samples = self._outputs.eval(feed_dict)
    return imtool.gan_grid_plot(samples)
def _default_snapshot_function(self):
    """Evaluate ``self._outputs`` on random generator inputs and plot them.

    Draws ``self._sample_num`` latent codes and feeds them to
    ``self.G.default_input_tensor``. When ``self._conditional`` is set,
    ``self._random_z(..., True)`` also returns one-hot labels, which are fed
    to ``self._targets``. Non-training status feeds are added before
    evaluation.

    Returns:
      The figure returned by ``imtool.gan_grid_plot``.
    """
    assert isinstance(self, GAN)
    feed_dict = {}
    if not self._conditional:
        z = self._random_z(self._sample_num)
    else:
        z, labels = self._random_z(self._sample_num, True)
        feed_dict[self._targets] = labels
    feed_dict[self.G.default_input_tensor] = z
    feed_dict.update(self._get_status_feed_dict(is_training=False))

    samples = self._outputs.eval(feed_dict)
    return imtool.gan_grid_plot(samples)