def train_ddi_vae(dim_z, hidden_layers_px, hidden_layers_qz, save_path,
                  learning_rate):

    #############################
    ''' Experiment Parameters '''
    #############################
    # num_batches = 100     # Number of minibatches in a single epoch; num_examples % self.num_batches == 0
    # dim_z = 50            # Dimensionality of the latent variable (z)
    epochs = 1000           # Number of epochs through the full dataset
    # learning_rate = 3e-3  # Learning rate of ADAM
    l2_loss = 1e-6          # L2 regularisation weight
    seed = 31415            # Seed for the RNG

    # Neural networks parameterising p(x|z) and q(z|x)
    # hidden_layers_px = [600, 1000, 800, 500]
    # hidden_layers_qz = [600, 1000, 800, 500]

    ####################
    ''' Load Dataset '''
    ####################
    # mnist_path = 'mnist/mnist_28.pkl.gz'
    # # Uses the anglepy module from the original paper (linked at top) to load the dataset
    # train_x, train_y, valid_x, valid_y, test_x, test_y = mnist.load_numpy(mnist_path, binarize_y=True)
    #
    # x_train, y_train = train_x.T, train_y.T
    # x_valid, y_valid = valid_x.T, valid_y.T
    # x_test, y_test = test_x.T, test_y.T

    x_train, y_train, x_valid, y_valid, x_test, y_test = load_dataset(
        "/data/cdy/ykq/ddi_dataset/train_dataset")
    utils.print_metrics(
        ['x_train', x_train.shape[0], x_train.shape[1]],
        ['y_train', y_train.shape[0], y_train.shape[1]],
        ['x_valid', x_valid.shape[0], x_valid.shape[1]],
        ['y_valid', y_valid.shape[0], y_valid.shape[1]],
        ['x_test', x_test.shape[0], x_test.shape[1]],
        ['y_test', y_test.shape[0], y_test.shape[1]],
    )
    dim_x = x_train.shape[1]
    dim_y = y_train.shape[1]

    ######################################
    ''' Train Variational Auto-Encoder '''
    ######################################
    VAE = Conv2dVariationalAutoencoder(dim_x=dim_x,
                                       dim_z=dim_z,
                                       hidden_layers_px=hidden_layers_px,
                                       hidden_layers_qz=hidden_layers_qz,
                                       l2_loss=l2_loss)

    # draw_img uses pylab and seaborn to draw images of original vs.
    # reconstruction every n iterations (set to 0 to disable)
    VAE.train(x=x_train,
              x_valid=x_valid,
              epochs=epochs,
              save_path=save_path,
              learning_rate=learning_rate,
              seed=seed,
              stop_iter=30,
              print_every=10,
              draw_img=0)
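# A minimal usage sketch of the entry point above. The hyper-parameter values
# mirror the commented-out defaults inside train_ddi_vae; the save path is
# illustrative only, not taken from the original experiments.
if __name__ == '__main__':
    train_ddi_vae(dim_z=50,
                  hidden_layers_px=[600, 1000, 800, 500],
                  hidden_layers_qz=[600, 1000, 800, 500],
                  save_path='checkpoints/ddi_vae.cpkt',
                  learning_rate=3e-3)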
def predict_labels(self, x_test, y_test):
    test_vars = tf.get_collection(bookkeeper.GraphKeys.TEST_VARIABLES)
    tf.initialize_variables(test_vars).run()

    # The test matrix packs the encoder outputs side by side: means in the
    # first dim_x columns, log-variances (lsgms) in the next dim_x columns.
    x_test_mu = x_test[:, :self.dim_x]
    x_test_lsgms = x_test[:, self.dim_x:2 * self.dim_x]

    accuracy, cross_entropy, precision, recall = self.session.run(
        [self.eval_accuracy, self.eval_cross_entropy,
         self.eval_precision, self.eval_recall],
        feed_dict={self.x_labelled_mu: x_test_mu,
                   self.x_labelled_lsgms: x_test_lsgms,
                   self.y_lab: y_test})

    utils.print_metrics('X',
                        ['Test', 'accuracy', accuracy],
                        ['Test', 'cross-entropy', cross_entropy],
                        ['Test', 'precision', precision],
                        ['Test', 'recall', recall])
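# predict_labels above assumes x_test packs the recognition network's outputs
# side by side. A standalone, module-level helper sketch for building such a
# matrix; the name _pack_encoder_outputs is an assumption, not part of the
# original code.
import numpy as np

def _pack_encoder_outputs(mu, lsgms):
    # mu and lsgms are both (N, dim_x); the result is (N, 2 * dim_x), with
    # means first, matching the slicing that predict_labels performs.
    assert mu.shape == lsgms.shape, 'mu and lsgms must both be (N, dim_x)'
    return np.hstack([mu, lsgms])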
def train(self,
          x,
          x_valid,
          epochs,
          # num_batches,
          save_path=None,
          print_every=1,
          learning_rate=3e-4,
          beta1=0.9,
          beta2=0.999,
          seed=31415,
          stop_iter=100,
          load_path=None,
          draw_img=1):

    self.num_examples = x.shape[0]
    # self.num_batches = num_batches
    assert self.num_examples % self.batch_size == 0, \
        '#Examples % batch_size != 0'
    # self.batch_size = self.num_examples // self.num_batches

    # x_size = x.shape[0]
    # x_valid_size = x_valid.shape[0]
    # x = np.reshape(x, [x_size, self.dim_x, 1, 1])
    # x_valid = np.reshape(x_valid, [x_valid_size, self.dim_x, 1, 1])

    ''' Session and Summary '''
    if save_path is None:
        self.save_path = 'checkpoints/model_CONV2D-VAE_{}.cpkt'.format(
            time.strftime("%m-%d-%H%M%S", time.localtime()))
    else:
        self.save_path = save_path

    np.random.seed(seed)
    tf.set_random_seed(seed)

    with self.G.as_default():
        self.optimiser = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                beta1=beta1,
                                                beta2=beta2)
        self.train_op = self.optimiser.minimize(self.cost)
        init = tf.global_variables_initializer()  # formerly tf.initialize_all_variables()
        self._test_vars = None

    with self.session as sess:
        sess.run(init)  # actually this is not executed until run here
        if load_path == 'default':
            self.saver.restore(sess, self.save_path)
        elif load_path is not None:
            self.saver.restore(sess, load_path)

        training_cost = 0.
        best_eval_log_lik = -np.inf
        stop_counter = 0

        for epoch in range(epochs):

            ''' Shuffle Data '''
            np.random.shuffle(x)

            ''' Training '''
            # Parameters are updated continuously, one minibatch at a time.
            for x_batch in utils.feed_numpy(self.batch_size, x):
                training_result = sess.run([self.train_op, self.cost],
                                           feed_dict={self.x: x_batch})
                training_cost = training_result[1]

            ''' Evaluation '''
            stop_counter += 1

            if epoch % print_every == 0:
                test_vars = tf.get_collection(
                    bookkeeper.GraphKeys.TEST_VARIABLES)
                if test_vars:
                    if test_vars != self._test_vars:
                        self._test_vars = list(test_vars)
                        self._test_var_init_op = tf.initialize_variables(
                            test_vars)
                    self._test_var_init_op.run()

                # eval_log_lik, x_recon_eval = \
                #     sess.run([self.eval_log_lik, self.x_recon_eval],
                #              feed_dict={self.x: x_valid})

                # Evaluate the validation set in batches and average the
                # per-batch log-likelihoods and reconstructions.
                eval_log_lik = 0
                x_recon_eval = 0
                valid_times = x_valid.shape[0] // self.batch_size
                for x_valid_batch in utils.feed_numpy(self.batch_size,
                                                      x_valid):
                    log_lik, recon_eval = sess.run(
                        [self.eval_log_lik, self.x_recon_eval],
                        feed_dict={self.x: x_valid_batch})
                    eval_log_lik += log_lik
                    x_recon_eval += recon_eval
                eval_log_lik /= valid_times
                x_recon_eval /= valid_times

                if eval_log_lik > best_eval_log_lik:
                    best_eval_log_lik = eval_log_lik
                    self.saver.save(sess, self.save_path)
                    stop_counter = 0

                utils.print_metrics(
                    epoch + 1,
                    ['Training', 'cost', training_cost],
                    ['Validation', 'log-likelihood', eval_log_lik])

                ## Plotting
                # if draw_img > 0 and epoch % draw_img == 0:
                #
                #     import matplotlib
                #     matplotlib.use('Agg')
                #     import pylab
                #     import seaborn as sns
                #
                #     five_random = np.random.random_integers(x_valid.shape[0], size=5)
                #     x_sample = x_valid[five_random]
                #     x_recon_sample = x_recon_eval[five_random]
                #
                #     sns.set_style('white')
                #     f, axes = pylab.subplots(5, 2, figsize=(8, 12))
                #     for i, row in enumerate(axes):
                #
                #         row[0].imshow(x_sample[i].reshape(28, 28), vmin=0, vmax=1)
                #         im = row[1].imshow(x_recon_sample[i].reshape(28, 28), vmin=0, vmax=1,
                #                            cmap=sns.light_palette((1.0, 0.4980, 0.0549), input="rgb", as_cmap=True))
                #
                #         pylab.setp([a.get_xticklabels() for a in row], visible=False)
                #         pylab.setp([a.get_yticklabels() for a in row], visible=False)
                #
                #     f.subplots_adjust(left=0.0, right=0.9, bottom=0.0, top=1.0)
                #     cbar_ax = f.add_axes([0.9, 0.1, 0.04, 0.8])
                #     f.colorbar(im, cax=cbar_ax, use_gridspec=True)
                #
                #     pylab.tight_layout()
                #     pylab.savefig('img/recon-' + str(epoch) + '.png', format='png')
                #     pylab.clf()
                #     pylab.close('all')

            if stop_counter >= stop_iter:
                print('Stopping VAE training')
                print('No change in validation log-likelihood for {} iterations'
                      .format(stop_iter))
                print('Best validation log-likelihood: {}'.format(
                    best_eval_log_lik))
                print('Model saved in {}'.format(self.save_path))
                break
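# The loops above rely on utils.feed_numpy to cut an array into fixed-size
# minibatches. A standalone sketch of the contract it is assumed to satisfy
# (the real utils.feed_numpy may differ); nothing is dropped only because
# train asserts num_examples % batch_size == 0 beforehand.
def _feed_numpy_sketch(batch_size, x):
    # Yield consecutive, non-overlapping slices of batch_size rows each.
    num_batches = x.shape[0] // batch_size
    for i in range(num_batches):
        yield x[i * batch_size:(i + 1) * batch_size]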
def train(self,
          x_labelled,
          y,
          x_unlabelled,
          epochs,
          x_valid,
          y_valid,
          print_every=1,
          learning_rate=3e-4,
          beta1=0.9,
          beta2=0.999,
          seed=31415,
          stop_iter=100,
          save_path=None,
          load_path=None):

    ''' Session and Summary '''
    if save_path is None:
        self.save_path = 'checkpoints/model_CONV2D_GC_{}-{}-{}_{}.cpkt'.format(
            self.num_lab, learning_rate, self.batch_size, time.time())
    else:
        self.save_path = save_path

    np.random.seed(seed)
    tf.set_random_seed(seed)

    with self.G.as_default():
        self.optimiser = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                beta1=beta1,
                                                beta2=beta2)
        self.train_op = self.optimiser.minimize(self.cost)
        init = tf.initialize_all_variables()
        self._test_vars = None

    # Keep labelled inputs and their labels in one array so that a single
    # shuffle keeps the (x, y) pairs aligned.
    _data_labelled = np.hstack([x_labelled, y])
    _data_unlabelled = x_unlabelled
    x_valid_mu = x_valid[:, :int(self.dim_x)]
    x_valid_lsgms = x_valid[:, int(self.dim_x):int(2 * self.dim_x)]

    with self.session as sess:
        sess.run(init)
        if load_path == 'default':
            self.saver.restore(sess, self.save_path)
        elif load_path is not None:
            self.saver.restore(sess, load_path)

        best_eval_accuracy = 0.
        best_train_accuracy = 0.
        stop_counter = 0

        # print("****lab_clf", self.weights['lab_clf'])
        # print("****lab_recon_clf", self.weights['lab_recon_clf'])
        # print("****ulab_clf", self.weights['ulab_clf'])
        # print("****ulab_recon_clf", self.weights['ulab_recon_clf'])
        # print("****L_loss", self.weights['L_loss'])
        # print("****U_loss", self.weights['U_loss'])

        for epoch in range(epochs):

            ''' Shuffle Data '''
            np.random.shuffle(_data_labelled)
            np.random.shuffle(_data_unlabelled)

            ''' Training '''
            # y_batch avoids shadowing the y argument used to build
            # _data_labelled above.
            for x_l_mu, x_l_lsgms, y_batch, x_u_mu, x_u_lsgms in \
                    utils.feed_numpy_semisupervised(
                        self.num_lab_batch, self.num_ulab_batch,
                        _data_labelled[:, :2 * self.dim_x],
                        _data_labelled[:, 2 * self.dim_x:],
                        _data_unlabelled):
                training_result = sess.run(
                    [self.train_op, self.cost],
                    feed_dict={
                        self.x_labelled_mu: x_l_mu,
                        self.x_labelled_lsgms: x_l_lsgms,
                        self.y_lab: y_batch,
                        self.x_unlabelled_mu: x_u_mu,
                        self.x_unlabelled_lsgms: x_u_lsgms
                    })
                training_cost = training_result[1]

            ''' Evaluation '''
            stop_counter += 1

            if epoch % print_every == 0:
                test_vars = tf.get_collection(
                    bookkeeper.GraphKeys.TEST_VARIABLES)
                if test_vars:
                    if test_vars != self._test_vars:
                        self._test_vars = list(test_vars)
                        self._test_var_init_op = tf.initialize_variables(
                            test_vars)
                    self._test_var_init_op.run()

                # eval_accuracy, eval_cross_entropy = \
                #     sess.run([self.eval_accuracy, self.eval_cross_entropy],
                #              feed_dict={self.x_labelled_mu: x_valid_mu,
                #                         self.x_labelled_lsgms: x_valid_lsgms,
                #                         self.y_lab: y_valid})

                # Evaluate the validation set in labelled-sized batches and
                # average the per-batch metrics.
                eval_accuracy = 0
                eval_cross_entropy = 0
                assert x_valid_mu.shape[0] % self.num_lab_batch == 0, \
                    '#Valid % num_lab_batch != 0'
                x_valid_batch_count = x_valid_mu.shape[0] // self.num_lab_batch
                print("valid: self.num_lab_batch: ", self.num_lab_batch)
                for x_valid_mu_batch, x_valid_lsgms_batch, y_valid_batch in \
                        pt.train.feed_numpy(self.num_lab_batch, x_valid_mu,
                                            x_valid_lsgms, y_valid):
                    tmp_accuracy, tmp_cross_entropy = sess.run(
                        [self.eval_accuracy, self.eval_cross_entropy],
                        feed_dict={
                            self.x_labelled_mu: x_valid_mu_batch,
                            self.x_labelled_lsgms: x_valid_lsgms_batch,
                            self.y_lab: y_valid_batch
                        })
                    eval_accuracy += tmp_accuracy
                    eval_cross_entropy += tmp_cross_entropy
                eval_accuracy /= x_valid_batch_count
                eval_cross_entropy /= x_valid_batch_count

                if eval_accuracy > best_eval_accuracy:
                    best_eval_accuracy = eval_accuracy
                    self.saver.save(sess, self.save_path)
                    stop_counter = 0
                    print("accuracy update: ", best_eval_accuracy)

                utils.print_metrics(
                    epoch + 1,
                    ['Training', 'cost', training_cost],
                    ['Validation', 'accuracy', eval_accuracy],
                    ['Validation', 'cross-entropy', eval_cross_entropy])

            if stop_counter >= stop_iter:
                print('Stopping GC training')
                print('No change in validation accuracy for {} iterations'
                      .format(stop_iter))
                print('Best validation accuracy: {}'.format(
                    best_eval_accuracy))
                print('Model saved in {}'.format(self.save_path))
                break
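# A standalone sketch of the interface utils.feed_numpy_semisupervised is
# assumed to expose in the training loop above (the real helper may differ):
# each step yields a labelled minibatch split into (mu, lsgms, y) plus an
# unlabelled minibatch split into (mu, lsgms). Assumes the labelled set size
# is a multiple of num_lab_batch so the cycling slices stay full-sized.
def _feed_numpy_semisupervised_sketch(num_lab_batch, num_ulab_batch,
                                      x_lab, y_lab, x_ulab):
    dim_x = x_lab.shape[1] // 2  # packed width is 2 * dim_x: mu, then lsgms
    steps = x_ulab.shape[0] // num_ulab_batch
    for i in range(steps):
        # Cycle the (smaller) labelled set while walking the unlabelled set
        # once per epoch.
        l0 = (i * num_lab_batch) % x_lab.shape[0]
        l1 = l0 + num_lab_batch
        u0, u1 = i * num_ulab_batch, (i + 1) * num_ulab_batch
        yield (x_lab[l0:l1, :dim_x], x_lab[l0:l1, dim_x:], y_lab[l0:l1],
               x_ulab[u0:u1, :dim_x], x_ulab[u0:u1, dim_x:])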