def visualize():
    """Restore a trained VAE checkpoint and render latent-space plots.

    Loads the MNIST test split, rebuilds the train-mode and generate-mode
    graphs under a shared 'VAE' variable scope, restores the checkpoint
    selected by ``FLAGS.load``, then writes a 2-D latent scatter plot and a
    grid of generated samples to SAVE_PATH.
    """
    FLAGS = get_args()
    plot_size = 20

    # Test split is used purely for plotting latent codes of real digits.
    valid_data = MNISTData(
        'test',
        data_dir=DATA_PATH,
        shuffle=True,
        pf=preprocess_im,
        batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    # Two graphs share one set of weights: the second scope reuses the
    # variables created by the first.
    with tf.variable_scope('VAE'):
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=400)

    visualizer = Visualizer(model, save_path=SAVE_PATH)
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    # Grid of latent codes covering the 2-D latent space.
    z = np.reshape(
        distribution.interpolate(plot_size=plot_size),
        (plot_size * plot_size, 2))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        # Initialize first so any variable missing from the checkpoint
        # still has a value; restore then overwrites the saved ones.
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, '{}vae-epoch-{}'.format(SAVE_PATH, FLAGS.load))
        visualizer.viz_2Dlatent_variable(sess, valid_data)
        generator.generate_samples(sess, plot_size=plot_size, z=z)
def train():
    """Train the VAE on MNIST, checkpointing and sampling every 10 epochs.

    Builds a training graph and a weight-sharing generation graph, then runs
    ``FLAGS.maxepoch`` epochs. Every 10th epoch the model is saved and, when
    the latent code is 2-D, sample grids and a latent scatter plot are
    written to SAVE_PATH.
    """
    FLAGS = get_args()

    train_data = MNISTData(
        'train',
        data_dir=DATA_PATH,
        shuffle=True,
        pf=preprocess_im,
        batch_dict_name=['im', 'label'])
    train_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    valid_data = MNISTData(
        'test',
        data_dir=DATA_PATH,
        shuffle=True,
        pf=preprocess_im,
        batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    # Training and generation graphs share weights via scope reuse.
    with tf.variable_scope('VAE'):
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=400)

    trainer = Trainer(
        model, valid_model, train_data,
        init_lr=FLAGS.lr, save_path=SAVE_PATH)

    # Latent-space visualizations only make sense for a 2-D code.
    if FLAGS.ncode == 2:
        z = np.reshape(distribution.interpolate(plot_size=20), (400, 2))
        visualizer = Visualizer(model, save_path=SAVE_PATH)
    else:
        z = None
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess, summary_writer=writer)
            trainer.valid_epoch(sess, summary_writer=writer)
            if epoch_id % 10 == 0:
                saver.save(
                    sess, '{}vae-epoch-{}'.format(SAVE_PATH, epoch_id))
                if FLAGS.ncode == 2:
                    generator.generate_samples(
                        sess, plot_size=20, z=z, file_id=epoch_id)
                    visualizer.viz_2Dlatent_variable(
                        sess, valid_data, file_id=epoch_id)
def generate_samples(self, sess, plot_size, manifold=False, file_id=None,
                     z=None):
    """Sample latent codes and visualize the decoded images as a grid.

    Args:
        sess: Active TensorFlow session.
        plot_size (int): Grid is ``plot_size x plot_size`` images.
        manifold (bool): If True, sweep an interpolation over the latent
            space instead of drawing random codes.
        file_id: Suffix used to name the output file(s).
        z: Optional pre-computed latent codes. When given, they are
            visualized directly and all sampling options are ignored.
            Restored for backward compatibility — callers in this file
            invoke ``generate_samples(..., z=z)``.
    """
    # Caller supplied explicit codes: just decode and plot them.
    if z is not None:
        self.viz_samples(sess, z, plot_size, file_id=file_id)
        return

    n_samples = plot_size * plot_size

    label_indices = None
    if self._use_label:
        # One label per grid row, cycling 0..n_labels-1 (row r gets
        # label r % n_labels), repeated across the row. Float values
        # match the original np.ones(...) * label construction.
        label_indices = list(
            np.repeat(np.arange(plot_size) % self._n_labels,
                      plot_size).astype(float))

    if manifold:
        if self._dist == 'gaussian':
            random_code = distribution.interpolate(
                plot_size=plot_size, interpolate_range=[-3, 3, -3, 3])
            self.viz_samples(sess, random_code, plot_size, file_id=file_id)
        else:
            # Gaussian mixture: sweep each mode separately and emit one
            # image file per mode.
            for mode_id in range(self._n_labels):
                random_code = distribution.interpolate_gm(
                    plot_size=plot_size,
                    interpolate_range=[-1., 1., -0.2, 0.2],
                    mode_id=mode_id,
                    n_mode=self._n_labels)
                self.viz_samples(
                    sess, random_code, plot_size,
                    file_id='{}_{}'.format(file_id, mode_id))
    else:
        if self._dist == 'gaussian':
            random_code = distribution.diagonal_gaussian(
                n_samples, self._g_model.n_code, mean=0, var=1.0)
        else:
            random_code = distribution.gaussian_mixture(
                n_samples,
                n_dim=self._g_model.n_code,
                n_labels=self._n_labels,
                x_var=0.5,
                y_var=0.1,
                label_indices=label_indices)
        self.viz_samples(sess, random_code, plot_size, file_id=file_id)