def plot_all_in_latent(model, mnist):
    """Project the train/validation/test splits into the model's latent space."""
    names = ("train", "validation", "test")
    datasets = (mnist.train, mnist.validation, mnist.test)
    for name, dataset in zip(names, datasets):
        plot.plotInLatent(model, dataset.images, dataset.labels, name=name,
                          outdir=PLOTS_DIR)
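# Hedged usage sketch (added for illustration, not part of the original source):
# `trained_vae` is a placeholder for an already-trained model object that
# plot.plotInLatent accepts; MNIST is assumed to come from the TF 1.x tutorial
# reader, which provides the .train/.validation/.test splits used above.
def _example_plot_all_splits(trained_vae):
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("./MNIST_data")  # downloads the data on first use
    plot_all_in_latent(trained_vae, mnist)             # one latent-space scatter per split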
def train(self, X, max_iter=np.inf, max_epochs=np.inf, cross_validate=True,
          verbose=True, save=True, outdir="./out", plots_outdir="./png",
          plot_latent_over_time=False):
    print("Entering training function!")
    if save:
        saver = tf.train.Saver(tf.all_variables())

    try:
        err_train = 0
        now = datetime.now().isoformat()[11:]
        print("------- Training begin: {} -------\n".format(now))

        if plot_latent_over_time:  # plot latent space over log_BASE time
            BASE = 2
            INCREMENT = 0.5
            pow_ = 0

        print("Entering training loop:")
        while True:
            x, _ = X.train.next_batch(self.batch_size)
            print("Just got a batch of size " + str(self.batch_size))
            print("The batch shape is " + str(np.shape(x)))

            # debug: dump the first example of the batch as a PNG for inspection
            # (note: scipy.misc.imsave was removed in SciPy 1.2)
            image_to_save = np.reshape(x[0], newshape=[28, 28])
            scipy.misc.imsave('current_image.png', image_to_save)
            # break  # uncomment to stop right after inspecting the first batch

            feed_dict = {self.x_in: x, self.dropout_: self.dropout}
            fetches = [self.x_reconstructed, self.cost, self.global_step, self.train_op]
            x_reconstructed, cost, i, _ = self.sess.run(fetches, feed_dict)

            err_train += cost

            if plot_latent_over_time:
                while int(round(BASE ** pow_)) == i:
                    plot.exploreLatent(self, nx=30, ny=30, ppf=True, outdir=plots_outdir,
                                       name="explore_ppf30_{}".format(pow_))

                    names = ("train", "validation", "test")
                    datasets = (X.train, X.validation, X.test)
                    for name, dataset in zip(names, datasets):
                        plot.plotInLatent(self, dataset.images, dataset.labels,
                                          range_=(-6, 6), title=name, outdir=plots_outdir,
                                          name="{}_{}".format(name, pow_))

                    print("{}^{} = {}".format(BASE, pow_, i))
                    pow_ += INCREMENT

            if i % 1000 == 0 and verbose:
                print("round {} --> avg cost: ".format(i), err_train / i)

            if i % 2000 == 0 and verbose:  # and i >= 10000:
                # visualize `n` examples of current minibatch inputs + reconstructions
                plot.plotSubset(self, x, x_reconstructed, n=10, name="train",
                                outdir=plots_outdir)

                if cross_validate:
                    x, _ = X.validation.next_batch(self.batch_size)
                    feed_dict = {self.x_in: x}
                    fetches = [self.x_reconstructed, self.cost]
                    x_reconstructed, cost = self.sess.run(fetches, feed_dict)

                    print("round {} --> CV cost: ".format(i), cost)
                    plot.plotSubset(self, x, x_reconstructed, n=10, name="cv",
                                    outdir=plots_outdir)

            if i >= max_iter or X.train.epochs_completed >= max_epochs:
                print("final avg cost (@ step {} = epoch {}): {}".format(
                    i, X.train.epochs_completed, err_train / i))
                now = datetime.now().isoformat()[11:]
                print("------- Training end: {} -------\n".format(now))

                if save:
                    outfile = os.path.join(os.path.abspath(outdir), "{}_vae_{}".format(
                        self.datetime, "_".join(map(str, self.architecture))))
                    saver.save(self.sess, outfile, global_step=self.step)

                try:
                    self.logger.flush()
                    self.logger.close()
                except AttributeError:  # not logging
                    pass  # fall through so the break below still ends training
                break

    except KeyboardInterrupt:
        print("final avg cost (@ step {} = epoch {}): {}".format(
            i, X.train.epochs_completed, err_train / i))
        now = datetime.now().isoformat()[11:]
        print("------- Training end: {} -------\n".format(now))
        sys.exit(0)
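# Hedged note (added): scipy.misc.imsave, used in the debug block above, was
# deprecated in SciPy 1.0 and removed in 1.2, so the image dump fails on newer
# installs. Assuming the imageio package is available, this sketch is a near
# drop-in replacement; names here are illustrative only.
def _save_debug_image(image_2d, path="current_image.png"):
    import imageio
    imageio.imwrite(path, image_2d)  # expects a 2-D (or HxWxC) array, e.g. the reshaped 28x28 digit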
def train(self, X, max_iter=np.inf, max_epochs=np.inf, cross_validate=True,
          verbose=True, save=True, outdir="./out", plots_outdir="./png",
          plot_latent_over_time=False, control_plots=False):
    if save:
        saver = tf.train.Saver(tf.global_variables())

    try:
        err_train = 0
        cost_finalBatch = 0
        now = datetime.now().isoformat()[11:]
        print("------- Training begin: {} -------\n".format(now))

        if control_plots and plot_latent_over_time:  # plot latent space over log_BASE time
            BASE = 2
            INCREMENT = 0.5
            pow_ = 0

        nBatches = 0
        total_updates = 0
        while total_updates < max_iter and X.train.epochs_completed < max_epochs:
            nBatches += 1
            x, _ = X.train.next_batch(self.batch_size)
            feed_dict = {self.x_in: x, self.dropout_prob: self.dropout}
            fetches = [self.x_reconstructed, self.cost, self.total_updates,
                       self.train_op, self.merged_summaries]
            x_reconstructed, cost, total_updates, _, summary = self.sesh.run(
                fetches, feed_dict)

            err_train += cost
            cost_finalBatch = cost

            if control_plots and plot_latent_over_time:
                while int(round(BASE ** pow_)) == total_updates:
                    plot.exploreLatent(self, nx=30, ny=30, ppf=True, outdir=plots_outdir,
                                       name="explore_ppf30_{}".format(pow_))

                    names = ("train", "validation", "test")
                    datasets = (X.train, X.validation, X.test)
                    for name, dataset in zip(names, datasets):
                        plot.plotInLatent(self, dataset.images, dataset.labels,
                                          range_=(-6, 6), title=name, outdir=plots_outdir,
                                          name="{}_{}".format(name, pow_))

                    print("{}^{} = {}".format(BASE, pow_, total_updates))
                    pow_ += INCREMENT

            if total_updates % 10 == 0:
                # run_metadata = tf.RunMetadata()
                # self.logger.add_run_metadata(run_metadata, 'step%03d' % total_updates)
                self.logger.add_summary(summary, total_updates)

            if total_updates % 50 == 0 and verbose:
                print("  iteration {} --> current cost: {}".format(total_updates, cost))

            # TO DO: np.dot(a, b), np.linalg.norm(a, axis=1)
            if total_updates % 500 == 0 and verbose:
                print("\tMean element-wise row sum of inputs: {}".format(
                    np.average(np.sum(x, 1))))

            if total_updates % 1000 == 0 and verbose:
                print("\titeration {} --> total avg cost: {}".format(
                    total_updates, err_train / total_updates))

            if total_updates % 1000 == 0 and verbose:  # and total_updates >= 10000:
                # visualize `n` examples of current minibatch inputs + reconstructions
                if control_plots:
                    plot.plotSubset(self, x, x_reconstructed, n=10, name="train",
                                    outdir=plots_outdir)

                if cross_validate:
                    x, _ = X.validation.next_batch(self.batch_size)
                    feed_dict = {self.x_in: x}
                    fetches = [self.x_reconstructed, self.cost]
                    x_reconstructed, cost = self.sesh.run(fetches, feed_dict)

                    print("  iteration {} --> CV cost: ".format(total_updates), cost)
                    if control_plots:
                        plot.plotSubset(self, x, x_reconstructed, n=10, name="cv",
                                        outdir=plots_outdir)

        now = datetime.now().isoformat()[11:]
        print("\n------- Training end: {} -------\n".format(now))
        print(" >>> Processed %d epochs in %d batches of size %d, i.e. %d data samples.\n"
              % (X.train.epochs_completed, nBatches, self.batch_size,
                 nBatches * self.batch_size))
        print("Final avg cost: {}".format(err_train / total_updates))
        print("Cost of final batch: {}\n".format(cost_finalBatch))

        # Test dataset
        print("\n Testing\n -------")
        x = X.train.getTestData()
        feed_dict = {self.x_in: x, self.dropout_prob: 1.0}
        fetches = [self.x_reconstructed, self.cost, self.mse_autoencoderTest,
                   self.sqrt_mse_autoencoderTest, self.cosSim_autoencoderTest]
        x_reconstructed, cost, mse_autoencoderTest, sqrt_mse_autoencoderTest, \
            cosSim_autoencoderTest = self.sesh.run(fetches, feed_dict)

        print(" Input:")
        for row in x[:10]:
            print("  " + ", ".join([repr(el) for el in row[:20]]) + " ...")
        print("\n Prediction:")
        for row in x_reconstructed[:10]:
            print("  " + ", ".join([repr(el) for el in row[:20]]) + " ...")

        print("\n Cost:      {}".format(cost))
        print(" MSE:       {}".format(mse_autoencoderTest))
        print(" sqrt(MSE): {}".format(sqrt_mse_autoencoderTest))
        print(" cosSim:    {}".format(cosSim_autoencoderTest))

        if save:
            outfile = os.path.join(os.path.abspath(outdir), "{}_vae_{}".format(
                self.datetime, "_".join(map(str, self.architecture))))
            saver.save(self.sesh, outfile, global_step=self.step)

        try:
            self.logger.flush()
            self.logger.close()
        except AttributeError:  # not logging
            pass

    except KeyboardInterrupt:
        print("final avg cost (@ step {} = epoch {}): {}".format(
            total_updates, X.train.epochs_completed, err_train / total_updates))
        now = datetime.now().isoformat()[11:]
        print("------- Training end: {} -------\n".format(now))
        sys.exit(0)
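# Hedged sketch (added): a plain-NumPy reference for the per-batch test metrics the
# graph nodes above (mse_autoencoderTest, sqrt_mse_autoencoderTest,
# cosSim_autoencoderTest) are assumed to compute, useful for sanity-checking the
# TensorFlow values. It also spells out the "TO DO: np.dot(a, b),
# np.linalg.norm(a, axis=1)" note in the loop; the exact reductions (global MSE,
# mean row-wise cosine similarity) are assumptions, not confirmed by the source.
def _reference_reconstruction_metrics(x, x_reconstructed, eps=1e-12):
    import numpy as np
    mse = np.mean((x - x_reconstructed) ** 2)                      # MSE over all elements
    rmse = np.sqrt(mse)                                            # its square root
    dots = np.sum(x * x_reconstructed, axis=1)                     # row-wise dot products
    norms = np.linalg.norm(x, axis=1) * np.linalg.norm(x_reconstructed, axis=1)
    cos_sim = np.mean(dots / np.maximum(norms, eps))               # mean row-wise cosine similarity
    return mse, rmse, cos_sim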