import time

import numpy as np
import tensorflow as tf
from prettytensor import bookkeeper  # provides GraphKeys.TEST_VARIABLES used in train()

import utils

# Conv2dVariationalAutoencoder is assumed to be defined in this module/package.


def VAE_encode_dataset(model_path, dataset, factor_dim, dim_z,
                       M1_hidden_layers_px, M1_hidden_layers_qz):
    """Encode a dataset with a trained Conv2D VAE and return the per-example
    latent means and variances, stacked over all batches."""
    batch_size = 100
    VAE = Conv2dVariationalAutoencoder(dim_x=factor_dim,
                                       dim_z=dim_z,
                                       batch_size=batch_size,
                                       hidden_layers_px=M1_hidden_layers_px,
                                       hidden_layers_qz=M1_hidden_layers_qz)
    enc_x_mean = None
    enc_x_var = None
    with VAE.session:
        VAE.saver.restore(VAE.session, model_path)
        for x_batch in utils.feed_numpy(batch_size, dataset):
            if enc_x_mean is None:
                enc_x_mean, enc_x_var = VAE.encode(x_batch)
            else:
                tmp_mean, tmp_var = VAE.encode(x_batch)
                enc_x_mean = np.vstack([enc_x_mean, tmp_mean])
                enc_x_var = np.vstack([enc_x_var, tmp_var])
    print("Shape of enc_x_mean is ", np.shape(enc_x_mean))
    print("Shape of enc_x_var is ", np.shape(enc_x_var))
    return enc_x_mean, enc_x_var
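# A minimal usage sketch for VAE_encode_dataset. The checkpoint path, data
# file, and dimensions below are hypothetical placeholders; the dataset size
# must be a multiple of the hard-coded batch size of 100.
#
#   data = np.load('data/train_x.npy')  # hypothetical file, shape (N, factor_dim)
#   z_mean, z_var = VAE_encode_dataset(
#       model_path='checkpoints/model_CONV2D-VAE_example.cpkt',  # hypothetical
#       dataset=data, factor_dim=784, dim_z=50,
#       M1_hidden_layers_px=[600, 600], M1_hidden_layers_qz=[600, 600])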
def train(self,
          x,
          x_valid,
          epochs,
          # num_batches,
          save_path=None,
          print_every=1,
          learning_rate=3e-4,
          beta1=0.9,
          beta2=0.999,
          seed=31415,
          stop_iter=100,
          load_path=None,
          draw_img=1):
    """Train the VAE, checkpointing whenever the validation log-likelihood
    improves and stopping early after stop_iter epochs without improvement."""

    self.num_examples = x.shape[0]
    # self.num_batches = num_batches
    assert self.num_examples % self.batch_size == 0, \
        '#Examples % batch_size != 0'
    # self.batch_size = self.num_examples // self.num_batches

    # x_size = x.shape[0]
    # x_valid_size = x_valid.shape[0]
    # x = np.reshape(x, [x_size, self.dim_x, 1, 1])
    # x_valid = np.reshape(x_valid, [x_valid_size, self.dim_x, 1, 1])

    ''' Session and Summary '''
    if save_path is None:
        self.save_path = 'checkpoints/model_CONV2D-VAE_{}.cpkt'.format(
            time.strftime("%m-%d-%H%M%S", time.localtime()))
    else:
        self.save_path = save_path

    np.random.seed(seed)
    tf.set_random_seed(seed)

    with self.G.as_default():
        self.optimiser = tf.train.AdamOptimizer(
            learning_rate=learning_rate, beta1=beta1, beta2=beta2)
        self.train_op = self.optimiser.minimize(self.cost)
        init = tf.global_variables_initializer()  # tf.initialize_all_variables()
        self._test_vars = None

    with self.session as sess:
        sess.run(init)  # actually not executed here
        if load_path == 'default':
            self.saver.restore(sess, self.save_path)
        elif load_path is not None:
            self.saver.restore(sess, load_path)

        training_cost = 0.
        best_eval_log_lik = -np.inf
        stop_counter = 0

        for epoch in range(epochs):
            ''' Shuffle Data '''
            np.random.shuffle(x)

            ''' Training '''
            for x_batch in utils.feed_numpy(self.batch_size, x):
                training_result = sess.run([self.train_op, self.cost],
                                           feed_dict={self.x: x_batch})
                training_cost = training_result[1]

            ''' Evaluation '''
            stop_counter += 1  # training: parameters keep being updated

            if epoch % print_every == 0:
                # Re-initialise test-time variables if the collection changed.
                test_vars = tf.get_collection(
                    bookkeeper.GraphKeys.TEST_VARIABLES)
                if test_vars:
                    if test_vars != self._test_vars:
                        self._test_vars = list(test_vars)
                        self._test_var_init_op = tf.initialize_variables(
                            test_vars)
                    self._test_var_init_op.run()

                # Evaluate the validation set in batches and average,
                # instead of in a single full pass:
                # eval_log_lik, x_recon_eval = \
                #     sess.run([self.eval_log_lik, self.x_recon_eval],
                #              feed_dict={self.x: x_valid})
                eval_log_lik = 0
                x_recon_eval = 0
                valid_times = x_valid.shape[0] // self.batch_size
                for x_valid_batch in utils.feed_numpy(self.batch_size,
                                                      x_valid):
                    log_lik, recon_eval = sess.run(
                        [self.eval_log_lik, self.x_recon_eval],
                        feed_dict={self.x: x_valid_batch})
                    eval_log_lik += log_lik
                    x_recon_eval += recon_eval
                eval_log_lik /= valid_times
                x_recon_eval /= valid_times

                if eval_log_lik > best_eval_log_lik:
                    best_eval_log_lik = eval_log_lik
                    self.saver.save(sess, self.save_path)
                    stop_counter = 0

                utils.print_metrics(
                    epoch + 1,
                    ['Training', 'cost', training_cost],
                    ['Validation', 'log-likelihood', eval_log_lik])

                ## Plotting
                # if draw_img > 0 and epoch % draw_img == 0:
                #     import matplotlib
                #     matplotlib.use('Agg')
                #     import pylab
                #     import seaborn as sns
                #
                #     five_random = np.random.random_integers(x_valid.shape[0], size=5)
                #     x_sample = x_valid[five_random]
                #     x_recon_sample = x_recon_eval[five_random]
                #
                #     sns.set_style('white')
                #     f, axes = pylab.subplots(5, 2, figsize=(8, 12))
                #     for i, row in enumerate(axes):
                #         row[0].imshow(x_sample[i].reshape(28, 28), vmin=0, vmax=1)
                #         im = row[1].imshow(x_recon_sample[i].reshape(28, 28), vmin=0, vmax=1,
                #                            cmap=sns.light_palette((1.0, 0.4980, 0.0549),
                #                                                   input="rgb", as_cmap=True))
                #         pylab.setp([a.get_xticklabels() for a in row], visible=False)
                #         pylab.setp([a.get_yticklabels() for a in row], visible=False)
                #
                #     f.subplots_adjust(left=0.0, right=0.9, bottom=0.0, top=1.0)
                #     cbar_ax = f.add_axes([0.9, 0.1, 0.04, 0.8])
                #     f.colorbar(im, cax=cbar_ax, use_gridspec=True)
                #
                #     pylab.tight_layout()
                #     pylab.savefig('img/recon-' + str(epoch) + '.png', format='png')
                #     pylab.clf()
                #     pylab.close('all')

            if stop_counter >= stop_iter:
                print('Stopping VAE training')
                print('No change in validation log-likelihood '
                      'for {} iterations'.format(stop_iter))
                print('Best validation log-likelihood: {}'.format(
                    best_eval_log_lik))
                print('Model saved in {}'.format(self.save_path))
                break
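# A minimal end-to-end training sketch, with every dimension, layer size, and
# path below a hypothetical placeholder (the constructor signature follows the
# call in VAE_encode_dataset above):
#
#   vae = Conv2dVariationalAutoencoder(dim_x=784, dim_z=50, batch_size=100,
#                                      hidden_layers_px=[600, 600],
#                                      hidden_layers_qz=[600, 600])
#   vae.train(x=train_x, x_valid=valid_x, epochs=1000,
#             save_path='checkpoints/my_conv2d_vae.cpkt',  # hypothetical
#             print_every=1, stop_iter=100)
#
# feed_numpy iterates in fixed-size batches, so both train_x and valid_x
# should contain a multiple of batch_size examples.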