# Generate text from a trained character-level RNN checkpoint, feeding the
# model one character at a time with a batch of 1.
import tensorflow as tf
import numpy as np
import my_txtutils

# these must match what was saved !
# (hyperparameters of the network stored in the checkpoint restored below)
ALPHASIZE = my_txtutils.ALPHASIZE
NLAYERS = 3
INTERNALSIZE = 512

# checkpoint prefix passed to Saver.restore()
# NOTE(review): the .meta path below repeats this prefix by hand; keep the two in sync.
author = "checkpoints/rnn_train_1526692149-432000000"

# NOTE(review): ncnt (and ALPHASIZE) are not used anywhere in the visible code;
# the generation loop below appears to continue beyond this chunk -- confirm.
ncnt = 0
with tf.Session() as sess:
    # rebuild the saved graph from its .meta file, then load the trained weights into it
    new_saver = tf.train.import_meta_graph('checkpoints/rnn_train_1526692149-432000000.meta')
    new_saver.restore(sess, author)
    # seed the generation with the character "L"
    x = my_txtutils.convert_from_alphabet(ord("L"))
    x = np.array([[x]])  # shape [BATCHSIZE, SEQLEN] with BATCHSIZE=1 and SEQLEN=1

    # initial values
    y = x
    h = np.zeros([1, INTERNALSIZE * NLAYERS], dtype=np.float32)  # [ BATCHSIZE, INTERNALSIZE * NLAYERS]
    for i in range(1000000000):
        # tensors are addressed by their names in the restored graph; pkeep=1. keeps every unit (no dropout)
        yo, h = sess.run(['Yo:0', 'H:0'], feed_dict={'X:0': y, 'pkeep:0': 1., 'Hin:0': h, 'batchsize:0': 1})

        # If sampling is done from the topn most likely characters, the generated text
        # is more credible and more "english". If topn is not set, it defaults to the full
        # distribution (ALPHASIZE)

        # Recommended: topn = 10 for intermediate checkpoints, topn=2 or 3 for fully trained checkpoints

        c = my_txtutils.sample_from_probabilities(yo, topn=2)
        # NOTE(review): in the visible code `c` is never used and `y` is never updated,
        # so the same input would be fed on every iteration; the rest of this loop body
        # (presumably printing `c` and feeding it back as `y`) seems to be missing
        # from this view -- confirm against the full file.
VALI_SEQLEN = 1*1024 # Sequence length for validation. State will be wrong at the start of each sequence. bsize = len(valitext) // VALI_SEQLEN txt.print_validation_header(len(codetext), bookranges) vali_x, vali_y, _ = next(txt.rnn_minibatch_sequencer(valitext, bsize, VALI_SEQLEN, 1)) # all data in 1 batch vali_nullstate = np.zeros([bsize, INTERNALSIZE*NLAYERS]) feed_dict = {X: vali_x, Y_: vali_y, Hin: vali_nullstate, pkeep: 1.0, # no dropout for validation batchsize: bsize} ls, acc, smm = sess.run([batchloss, accuracy, summaries], feed_dict=feed_dict) txt.print_validation_stats(ls, acc) # save validation data for Tensorboard validation_writer.add_summary(smm, step) # display a short text generated with the current weights and biases (every 150 batches) if step // 3 % _50_BATCHES == 0: txt.print_text_generation_header() ry = np.array([[txt.convert_from_alphabet(ord("K"))]]) rh = np.zeros([1, INTERNALSIZE * NLAYERS]) for k in range(1000): ryo, rh = sess.run([Yo, H], feed_dict={X: ry, pkeep: 1.0, Hin: rh, batchsize: 1}) rc = txt.sample_from_probabilities(ryo, topn=10 if epoch <= 1 else 2) print(chr(txt.convert_to_alphabet(rc)), end="") ry = np.array([[rc]]) txt.print_text_generation_footer() # save a checkpoint (every 500 batches) if step // 10 % _50_BATCHES == 0: saved_file = saver.save(sess, 'checkpoints/rnn_train_' + timestamp, global_step=step) print("Saved file: " + saved_file) # display progress bar progress.step(reset=step % _50_BATCHES == 0)
def fit(self, data, epochs=1000, displayFreq=50, genFreq=150, saveFreq=5000, verbosity=2):
    """Train the network on ``data``, resuming from the model's stored position.

    Every iteration trains on one minibatch of ``data.codetext``. Using the
    batch counter ``nSteps`` (derived from ``self.step``), it also:

    * every ``displayFreq`` batches, re-runs the batch without dropout and logs
      it, then runs a validation pass over ``data.valitext``;
    * every ``genFreq`` batches, prints a short generated text sample;
    * every ``saveFreq`` batches, saves a checkpoint.

    Logging, validation and sampling also run on the first step of each call.
    Pressing Ctrl-C stops training and saves the model.

    Args:
        data: corpus object; this method reads its ``codetext``, ``valitext``,
            ``bookranges`` and ``epoch_size`` attributes.
        epochs: passed to ``txt.rnn_minibatch_sequencer`` as ``nb_epochs``;
            ``self.curEpoch + epochs`` is reported as the last epoch.
        displayFreq: batches between training logs / validation runs; also the
            period passed to ``txt.Progress``.
        genFreq: batches between generated text samples.
        saveFreq: batches between checkpoint saves.
        verbosity: forwarded to ``txt.print_learning_learned_comparison``;
            validation output is printed only when ``verbosity >= 1``.
    """
    progress = txt.Progress(displayFreq, size=111+2, msg="Training on next "+str(displayFreq)+" batches")
    tfStuff = self.tfStuff
    # todo: check if batchSize != data.batchSize or if seqLen != data.seqLen (if so, I think we need to raise an exception?)
    lastEpoch = self.curEpoch + epochs
    # self.step advances by this amount per batch, so self.step // charsPerBatch is the batch counter
    charsPerBatch = self.batchSize * self.seqLen
    isFirstStepInThisFitCall = True
    try:
        with tfStuff.graph.as_default():
            with tf.variable_scope(self.fullName, reuse=tf.AUTO_REUSE):
                sess = tfStuff.sess
                batches = txt.rnn_minibatch_sequencer(data.codetext, self.batchSize, self.seqLen,
                                                      nb_epochs=epochs,
                                                      startBatch=self.curBatch,
                                                      startEpoch=self.curEpoch)
                for x, y_, epoch, batch in batches:
                    nSteps = self.step // charsPerBatch
                    isDisplayStep = nSteps % displayFreq == 0 or isFirstStepInThisFitCall

                    # train on one minibatch, starting from the state the previous batch ended in
                    feed_dict = {tfStuff.X: x, tfStuff.Y_: y_, tfStuff.Hin: tfStuff.istate,
                                 tfStuff.lr: self.learningRate, tfStuff.pkeep: self.dropoutPkeep,
                                 tfStuff.batchsize: self.batchSize}
                    _, ostate = sess.run([tfStuff.train_step, tfStuff.H], feed_dict=feed_dict)

                    if isDisplayStep:
                        # must run before tfStuff.istate is advanced at the end of this iteration,
                        # so the logging pass starts from the same state the batch was trained from
                        self._fitLogTrainingBatch(sess, data, x, y_, epoch, lastEpoch, verbosity)
                        self._fitValidate(sess, data, verbosity)

                    if nSteps % genFreq == 0 or isFirstStepInThisFitCall:
                        self._fitPrintSample(sess, epoch)

                    if isFirstStepInThisFitCall:
                        # when resuming mid-cycle, advance the progress bar to the current position
                        for _ in range(nSteps % displayFreq):
                            progress.step()
                        isFirstStepInThisFitCall = False

                    if nSteps % saveFreq == 0:
                        # NOTE(review): this runs after the weights were updated on this batch but
                        # before step/curEpoch/curBatch/istate are advanced below; if save() persists
                        # those counters, a resumed run may retrain this batch -- confirm.
                        self.save(alreadyInGraph=True)

                    progress.step(reset=nSteps % displayFreq == 0)

                    # loop state around
                    tfStuff.istate = ostate
                    self.step += charsPerBatch
                    self.curEpoch = epoch
                    self.curBatch = batch
    except KeyboardInterrupt:
        print("\npressed ctrl-c, saving")
        self.save()


def _fitLogTrainingBatch(self, sess, data, x, y_, epoch, lastEpoch, verbosity):
    """Re-run a training batch without dropout, print the comparison and log summaries.

    Feeds ``self.tfStuff.istate`` as the initial state, so it must be called
    before ``fit`` advances that state. The summaries go to the TensorBoard
    training writer at ``self.step``.
    """
    tfStuff = self.tfStuff
    feed_dict = {tfStuff.X: x, tfStuff.Y_: y_, tfStuff.Hin: tfStuff.istate,
                 tfStuff.pkeep: 1.0,  # no dropout when measuring
                 tfStuff.batchsize: self.batchSize}
    y, l, bl, acc, smm = sess.run([tfStuff.Y, tfStuff.seqloss, tfStuff.batchloss, tfStuff.accuracy, tfStuff.summaries],
                                  feed_dict=feed_dict)
    txt.print_learning_learned_comparison(x, y, l, data.bookranges, bl, acc, data.epoch_size,
                                          self.step, epoch, lastEpoch, verbosity=verbosity)
    self.tbStuff.summary_writer.add_summary(smm, self.step)


def _fitValidate(self, sess, data, verbosity):
    """Run one validation pass over ``data.valitext`` and log it to TensorBoard.

    The validation text should be a single sequence, but that is too slow
    (1s per 1024 chars!). Instead it is cut into 1024-char pieces that are
    evaluated together as one batch, each starting from a zero state, which is
    slightly inaccurate. Tested: 5K-char pieces are only slightly more accurate
    than 1K, but a lot slower.

    Does nothing when ``data.valitext`` is shorter than one piece. Otherwise
    ``bsize`` would be 0 and an empty validation batch would be requested.
    """
    tfStuff = self.tfStuff
    valiSeqLen = 1024  # state will be wrong at the start of each piece
    bsize = len(data.valitext) // valiSeqLen
    if bsize == 0:
        return
    if verbosity >= 1:
        txt.print_validation_header(len(data.codetext), data.bookranges)
    vali_x, vali_y, _, _ = next(txt.rnn_minibatch_sequencer(data.valitext, bsize, valiSeqLen, 1))  # all data in 1 batch
    vali_nullstate = np.zeros([bsize, self.internalSize * self.nLayers])
    feed_dict = {tfStuff.X: vali_x, tfStuff.Y_: vali_y, tfStuff.Hin: vali_nullstate,
                 tfStuff.pkeep: 1.0,  # no dropout for validation
                 tfStuff.batchsize: bsize}
    ls, acc, smm = sess.run([tfStuff.batchloss, tfStuff.accuracy, tfStuff.summaries], feed_dict=feed_dict)
    if verbosity >= 1:
        txt.print_validation_stats(ls, acc)
    # save validation data for Tensorboard
    self.tbStuff.validation_writer.add_summary(smm, self.step)


def _fitPrintSample(self, sess, epoch, nChars=1000):
    """Print ``nChars`` characters generated with the current weights, seeded with "K".

    Samples among the 10 most likely characters while ``epoch <= 1`` and among
    the top 2 afterwards. Each sampled character is fed back as the next input.
    """
    tfStuff = self.tfStuff
    topn = 10 if epoch <= 1 else 2
    txt.print_text_generation_header()
    ry = np.array([[txt.convert_from_alphabet(ord("K"))]])
    rh = np.zeros([1, self.internalSize * self.nLayers])
    for _ in range(nChars):
        ryo, rh = sess.run([tfStuff.Yo, tfStuff.H],
                           feed_dict={tfStuff.X: ry, tfStuff.pkeep: 1.0, tfStuff.Hin: rh, tfStuff.batchsize: 1})
        rc = txt.sample_from_probabilities(ryo, topn=topn)
        print(chr(txt.convert_to_alphabet(rc)), end="")
        ry = np.array([[rc]])
    txt.print_text_generation_footer()