# training loop
# Script-level fragment: each iteration consumes one (x, y_, epoch) triple
# yielded by txt.rnn_minibatch_sequencer over `codetext`.
# NOTE(review): within the lines visible here neither `step` nor `istate` is
# updated (the fetched `ostate` is never fed back) and the validation summary
# `smm` is never written out. The loop body looks truncated -- confirm against
# the full script before relying on this fragment.
for x, y_, epoch in txt.rnn_minibatch_sequencer(codetext, BATCHSIZE, SEQLEN, nb_epochs=20):

    # train on one minibatch
    # Feeds the current input state `istate` and the training dropout setting;
    # fetches the predictions, the output state H and the merged summaries.
    feed_dict = {X: x, Y_: y_, Hin: istate, lr: learning_rate, pkeep: dropout_pkeep, batchsize: BATCHSIZE}
    _, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)

    # save training data for Tensorboard
    summary_writer.add_summary(smm, step)

    # display a visual validation of progress (every 50 batches)
    # Re-runs the same minibatch with pkeep=1.0 and without fetching train_step,
    # so this pass does not update the weights.
    # NOTE(review): `_50_BATCHES` is defined outside this fragment; it is
    # compared against `step`, so presumably both use the same unit -- confirm.
    if step % _50_BATCHES == 0:
        feed_dict = {X: x, Y_: y_, Hin: istate, pkeep: 1.0, batchsize: BATCHSIZE}  # no dropout for validation
        y, l, bl, acc = sess.run([Y, seqloss, batchloss, accuracy], feed_dict=feed_dict)
        txt.print_learning_learned_comparison(x, y, l, bookranges, bl, acc, epoch_size, step, epoch)

    # run a validation step every 50 batches
    # The validation text should be a single sequence but that's too slow (1s per 1024 chars!),
    # so we cut it up and batch the pieces (slightly inaccurate)
    # tested: validating with 5K sequences instead of 1K is only slightly more accurate, but a lot slower.
    if step % _50_BATCHES == 0 and len(valitext) > 0:
        VALI_SEQLEN = 1*1024  # Sequence length for validation. State will be wrong at the start of each sequence.
        # Number of whole VALI_SEQLEN-long pieces in the validation text; this
        # is 0 when the text is shorter than VALI_SEQLEN.
        bsize = len(valitext) // VALI_SEQLEN
        txt.print_validation_header(len(codetext), bookranges)
        vali_x, vali_y, _ = next(txt.rnn_minibatch_sequencer(valitext, bsize, VALI_SEQLEN, 1))  # all data in 1 batch
        # Every validation sequence starts from an all-zero state.
        vali_nullstate = np.zeros([bsize, INTERNALSIZE*NLAYERS])
        feed_dict = {X: vali_x, Y_: vali_y, Hin: vali_nullstate, pkeep: 1.0,  # no dropout for validation
                     batchsize: bsize}
        ls, acc, smm = sess.run([batchloss, accuracy, summaries], feed_dict=feed_dict)
        txt.print_validation_stats(ls, acc)
def main(_):
    """Build a multi-layer GRU character model and train it on FLAGS.text_dir.

    The function reads the training and validation text, builds the graph
    (one-hot inputs -> NLAYERS GRU layers with dropout -> shared linear
    readout -> softmax), then runs the training loop. While training it
    periodically prints a learning/validation report, writes TensorBoard
    summaries for both the training and the validation set, prints a sample
    of generated text and saves a checkpoint.

    The TensorBoard writers and the session are closed when the loop ends or
    is interrupted, so buffered summaries are flushed to disk.

    Args:
        _: unused positional argument (presumably the argv list handed over
            by the application runner -- the call site is not visible here).
    """

    # load data, either shakespeare, or the Python source of Tensorflow itself
    # (e.g. a glob such as "../tensorflow/**/*.py")
    codetext, valitext, bookranges = txt.read_data_files(FLAGS.text_dir,
                                                         validation=True)

    # display some stats on the data
    chars_per_batch = FLAGS.train_batch_size * FLAGS.seqlen
    epoch_size = len(codetext) // chars_per_batch
    txt.print_data_stats(len(codetext), len(valitext), epoch_size)

    #
    # the model (see FAQ in README.md)
    #
    lr = tf.placeholder(tf.float32, name='lr')  # learning rate
    pkeep = tf.placeholder(tf.float32, name='pkeep')  # dropout parameter
    batchsize = tf.placeholder(tf.int32, name='batchsize')

    # inputs
    X = tf.placeholder(tf.uint8, [None, None],
                       name='X')  # [ BATCHSIZE, FLAGS.seqlen ]
    Xo = tf.one_hot(X, ALPHASIZE, 1.0,
                    0.0)  # [ BATCHSIZE, FLAGS.seqlen, ALPHASIZE ]
    # expected outputs = same sequence shifted by 1 since we are trying to predict the next character
    Y_ = tf.placeholder(tf.uint8, [None, None],
                        name='Y_')  # [ BATCHSIZE, FLAGS.seqlen ]
    Yo_ = tf.one_hot(Y_, ALPHASIZE, 1.0,
                     0.0)  # [ BATCHSIZE, FLAGS.seqlen, ALPHASIZE ]
    # input state
    Hin = tf.placeholder(tf.float32, [None, INTERNALSIZE * NLAYERS],
                         name='Hin')  # [ BATCHSIZE, INTERNALSIZE * NLAYERS]

    # NLAYERS layers of GRU cells, unrolled FLAGS.seqlen times;
    # dynamic_rnn infers the sequence length from the size of the inputs Xo.
    #
    # Each layer gets its OWN cell object. Repeating a single cell object
    # ([cell] * NLAYERS) makes all layers refer to the same cell, which newer
    # TensorFlow 1.x releases either reject or treat as weight sharing between
    # layers -- and the first layer has a different input width (ALPHASIZE)
    # than the following ones (INTERNALSIZE).
    dropcells = [
        rnn.DropoutWrapper(rnn.GRUCell(INTERNALSIZE), input_keep_prob=pkeep)
        for _ in range(NLAYERS)
    ]
    multicell = rnn.MultiRNNCell(dropcells, state_is_tuple=False)
    multicell = rnn.DropoutWrapper(multicell, output_keep_prob=pkeep)
    Yr, H = tf.nn.dynamic_rnn(multicell,
                              Xo,
                              dtype=tf.float32,
                              initial_state=Hin)
    # Yr: [ BATCHSIZE, FLAGS.seqlen, INTERNALSIZE ]
    # H:  [ BATCHSIZE, INTERNALSIZE*NLAYERS ] # this is the last state in the sequence

    H = tf.identity(H, name='H')  # just to give it a name

    # Softmax layer implementation:
    # Flatten the first two dimension of the output [ BATCHSIZE, FLAGS.seqlen, ALPHASIZE ] => [ BATCHSIZE x FLAGS.seqlen, ALPHASIZE ]
    # then apply softmax readout layer. This way, the weights and biases are shared across unrolled time steps.
    # From the readout point of view, a value coming from a cell or a minibatch is the same thing

    Yflat = tf.reshape(
        Yr, [-1, INTERNALSIZE])  # [ BATCHSIZE x FLAGS.seqlen, INTERNALSIZE ]
    Ylogits = layers.linear(
        Yflat, ALPHASIZE)  # [ BATCHSIZE x FLAGS.seqlen, ALPHASIZE ]
    Yflat_ = tf.reshape(
        Yo_, [-1, ALPHASIZE])  # [ BATCHSIZE x FLAGS.seqlen, ALPHASIZE ]
    loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=Ylogits, labels=Yflat_)  # [ BATCHSIZE x FLAGS.seqlen ]
    loss = tf.reshape(loss, [batchsize, -1])  # [ BATCHSIZE, FLAGS.seqlen ]
    Yo = tf.nn.softmax(Ylogits,
                       name='Yo')  # [ BATCHSIZE x FLAGS.seqlen, ALPHASIZE ]
    Y = tf.argmax(Yo, 1)  # [ BATCHSIZE x FLAGS.seqlen ]
    Y = tf.reshape(Y, [batchsize, -1], name="Y")  # [ BATCHSIZE, FLAGS.seqlen ]
    train_step = tf.train.AdamOptimizer(lr).minimize(loss)

    # stats for display
    seqloss = tf.reduce_mean(loss, 1)
    batchloss = tf.reduce_mean(seqloss)
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(Y_, tf.cast(Y, tf.uint8)), tf.float32))
    loss_summary = tf.summary.scalar("batch_loss", batchloss)
    acc_summary = tf.summary.scalar("batch_accuracy", accuracy)
    summaries = tf.summary.merge([loss_summary, acc_summary])

    # Init Tensorboard stuff. This will save Tensorboard information into a different
    # folder at each run named '<summaries_dir>/<timestamp>-training' and
    # '<summaries_dir>/<timestamp>-validation'. Two sets of data are saved so that
    # you can compare training and validation curves visually in Tensorboard.
    timestamp = str(math.trunc(time.time()))
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.summaries_dir, timestamp + "-training"))
    validation_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.summaries_dir, timestamp + "-validation"))

    # Init for saving models. Only the last checkpoint is kept.
    # makedirs (rather than mkdir) so that a nested checkpoint path also works.
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    saver = tf.train.Saver(max_to_keep=1)

    # for display: init the progress bar
    # `step` counts characters, not batches, hence the scaling below.
    DISPLAY_FREQ = 50
    _50_BATCHES = DISPLAY_FREQ * chars_per_batch
    progress = txt.Progress(DISPLAY_FREQ,
                            size=111 + 2,
                            msg="Training on next " + str(DISPLAY_FREQ) +
                            " batches")

    # Validation setup (loop-invariant).
    # The validation text should be a single sequence but that's too slow (1s per 1024 chars!),
    # so we cut it up and batch the pieces (slightly inaccurate)
    # tested: validating with 5K sequences instead of 1K is only slightly more accurate, but a lot slower.
    VALI_SEQLEN = 1 * 1024  # Sequence length for validation. State will be wrong at the start of each sequence.
    # Number of whole validation sequences; 0 when the validation text is
    # empty or shorter than one sequence, in which case validation is skipped
    # instead of requesting an empty batch from the sequencer.
    vali_bsize = len(valitext) // VALI_SEQLEN
    vali_nullstate = np.zeros([vali_bsize, INTERNALSIZE * NLAYERS])

    # init
    istate = np.zeros([FLAGS.train_batch_size,
                       INTERNALSIZE * NLAYERS])  # initial zero input state
    init = tf.global_variables_initializer()
    sess = tf.Session()
    try:
        sess.run(init)
        step = 0

        # training loop
        for x, y_, epoch in txt.rnn_minibatch_sequencer(
                codetext,
                FLAGS.train_batch_size,
                FLAGS.seqlen,
                nb_epochs=1000):

            is_display_step = step % _50_BATCHES == 0

            # train on one minibatch
            feed_dict = {
                X: x,
                Y_: y_,
                Hin: istate,
                lr: FLAGS.learning_rate,
                pkeep: FLAGS.dropout_pkeep,
                batchsize: FLAGS.train_batch_size
            }
            _, y, ostate, smm = sess.run([train_step, Y, H, summaries],
                                         feed_dict=feed_dict)

            # save training data for Tensorboard
            summary_writer.add_summary(smm, step)

            # display a visual validation of progress (every 50 batches)
            if is_display_step:
                feed_dict = {
                    X: x,
                    Y_: y_,
                    Hin: istate,
                    pkeep: 1.0,
                    batchsize: FLAGS.train_batch_size
                }  # no dropout for validation
                y, l, bl, acc = sess.run([Y, seqloss, batchloss, accuracy],
                                         feed_dict=feed_dict)
                txt.print_learning_learned_comparison(x, y, l, bookranges, bl,
                                                      acc, epoch_size, step,
                                                      epoch)

            # run a validation step every 50 batches
            if is_display_step and vali_bsize > 0:
                txt.print_validation_header(len(codetext), bookranges)
                vali_x, vali_y, _ = next(
                    txt.rnn_minibatch_sequencer(valitext, vali_bsize,
                                                VALI_SEQLEN,
                                                1))  # all data in 1 batch
                feed_dict = {
                    X: vali_x,
                    Y_: vali_y,
                    Hin: vali_nullstate,
                    pkeep: 1.0,  # no dropout for validation
                    batchsize: vali_bsize
                }
                ls, acc, smm = sess.run([batchloss, accuracy, summaries],
                                        feed_dict=feed_dict)
                txt.print_validation_stats(ls, acc)
                # save validation data for Tensorboard
                validation_writer.add_summary(smm, step)

            # display a short text generated with the current weights and biases (every 150 batches)
            if step // 3 % _50_BATCHES == 0:
                txt.print_text_generation_header()
                # Seed the generator with the letter "K" and a zero state, then
                # feed every sampled character back in as the next input.
                ry = np.array([[txt.convert_from_alphabet(ord("K"))]])
                rh = np.zeros([1, INTERNALSIZE * NLAYERS])
                for _ in range(1000):
                    ryo, rh = sess.run([Yo, H],
                                       feed_dict={
                                           X: ry,
                                           pkeep: 1.0,
                                           Hin: rh,
                                           batchsize: 1
                                       })
                    rc = txt.sample_from_probabilities(
                        ryo, topn=10 if epoch <= 1 else 2)
                    print(chr(txt.convert_to_alphabet(rc)), end="")
                    ry = np.array([[rc]])
                txt.print_text_generation_footer()

            # save a checkpoint (every 500 batches)
            if step // 10 % _50_BATCHES == 0:
                saver.save(sess,
                           FLAGS.checkpoint_dir + '/rnn_train_' + timestamp,
                           global_step=step)

            # display progress bar
            progress.step(reset=is_display_step)

            # loop state around
            istate = ostate
            step += chars_per_batch
    finally:
        # Flush and release TensorBoard writers and the session even when
        # training is interrupted (e.g. Ctrl-C) or fails.
        summary_writer.close()
        validation_writer.close()
        sess.close()
        batchsize: BATCH_SIZE
    }
    _, y, ostate = sess.run([train_step, Y, H], feed_dict=feed_dict)

    # Store training TensorBoard data
    if step % _50_BATCHES == 0:
        feed_dict = {
            X: x,
            Y_: y_,
            Hin: istate,
            pkeep: 1.0,
            batchsize: BATCH_SIZE
        }
        y, l, bl, acc, smm = sess.run(
            [Y, seqloss, batchloss, accuracy, summaries], feed_dict=feed_dict)
        txt.print_learning_learned_comparison(x, y, l, scriptranges, bl, acc,
                                              size_of_epoch, step, epoch)
        summary_writer.add_summary(smm, step)

    # Validation is batched in order to run quicker
    if step % _50_BATCHES == 0 and len(validtext) > 0:
        VALI_SEQ_LEN = 1 * 1024
        bsize = len(validtext) // VALI_SEQ_LEN
        txt.print_validation_header(len(traintext), scriptranges)
        vali_x, vali_y, _ = next(
            txt.rnn_minibatch_sequencer(validtext, bsize, VALI_SEQ_LEN, 1))
        vali_nullstate = np.zeros([bsize, NUM_OF_GRUS * NUM_LAYERS])
        feed_dict = {
            X: vali_x,
            Y_: vali_y,
            Hin: vali_nullstate,
            pkeep: 1.0,
Ejemplo n.º 4
0
	def fit(self, data, epochs=1000, displayFreq=50, genFreq=150, saveFreq=5000, verbosity=2):
		"""Train the model on `data`, saving it if interrupted with Ctrl-C.

		Training resumes from `self.curEpoch` / `self.curBatch` and updates
		`self.step`, `self.curEpoch`, `self.curBatch` and `tfStuff.istate`
		after every minibatch.

		Args:
			data: object providing `codetext`, `valitext`, `bookranges` and
				`epoch_size` (all read below).
			epochs: passed to the sequencer as `nb_epochs`; also used to
				compute the last epoch number shown in the progress report.
			displayFreq: number of batches between two learning reports /
				validation runs (also the progress bar length).
			genFreq: number of batches between two generated text samples.
			saveFreq: number of batches between two checkpoints.
			verbosity: forwarded to `txt.print_learning_learned_comparison`;
				the validation header and stats are printed when it is >= 1.
		"""
		progress = txt.Progress(displayFreq, size=111+2, msg="Training on next "+str(displayFreq)+" batches")
		tfStuff = self.tfStuff
		valitext = data.valitext
		# todo: check if batchSize != data.batchSize or if seqLen != data.seqLen (if so, I think we need to raise an exception?)
		lastEpoch = self.curEpoch + epochs
		charsPerBatch = self.batchSize * self.seqLen
		stateSize = self.internalSize * self.nLayers

		# Validation setup (loop-invariant).
		# The validation text should be a single sequence but that's too slow (1s per 1024 chars!),
		# so we cut it up and batch the pieces (slightly inaccurate)
		# tested: validating with 5K sequences instead of 1K is only slightly more accurate, but a lot slower.
		VALI_seqLen = 1*1024  # Sequence length for validation. State will be wrong at the start of each sequence.
		# Number of whole validation sequences; 0 when the validation text is
		# empty or shorter than one sequence, in which case validation is
		# skipped instead of requesting an empty batch from the sequencer.
		valiBatchSize = len(valitext) // VALI_seqLen
		vali_nullstate = np.zeros([valiBatchSize, stateSize])

		isFirstStepInThisFitCall = True
		try:
			with tfStuff.graph.as_default():
				with tf.variable_scope(self.fullName, reuse=tf.AUTO_REUSE):

					sess = tfStuff.sess

					# training loop
					for x, y_, epoch, batch in txt.rnn_minibatch_sequencer(data.codetext, self.batchSize, self.seqLen, nb_epochs=epochs, startBatch=self.curBatch, startEpoch=self.curEpoch):

						# self.step counts characters; nSteps counts batches
						nSteps = self.step // charsPerBatch
						isDisplayStep = nSteps % displayFreq == 0
						# the first step of every fit() call always reports
						shouldReport = isDisplayStep or isFirstStepInThisFitCall

						# train on one minibatch
						feed_dict = {tfStuff.X: x, tfStuff.Y_: y_, tfStuff.Hin: tfStuff.istate, tfStuff.lr: self.learningRate, tfStuff.pkeep: self.dropoutPkeep, tfStuff.batchsize: self.batchSize}
						_, y, ostate = sess.run([tfStuff.train_step, tfStuff.Y, tfStuff.H], feed_dict=feed_dict)

						# log training data for Tensorboard and display a mini-batch of sequences (every displayFreq batches)
						if shouldReport:
							feed_dict = {tfStuff.X: x, tfStuff.Y_: y_, tfStuff.Hin: tfStuff.istate, tfStuff.pkeep: 1.0, tfStuff.batchsize: self.batchSize}  # no dropout for validation
							y, l, bl, acc, smm = sess.run([tfStuff.Y, tfStuff.seqloss, tfStuff.batchloss, tfStuff.accuracy, tfStuff.summaries], feed_dict=feed_dict)
							txt.print_learning_learned_comparison(x, y, l, data.bookranges, bl, acc, data.epoch_size, self.step, epoch, lastEpoch, verbosity=verbosity)
							self.tbStuff.summary_writer.add_summary(smm, self.step)

						# run a validation step (every displayFreq batches)
						if shouldReport and valiBatchSize > 0:
							if verbosity >= 1: txt.print_validation_header(len(data.codetext), data.bookranges)
							vali_x, vali_y, _, _ = next(txt.rnn_minibatch_sequencer(valitext, valiBatchSize, VALI_seqLen, 1))  # all data in 1 batch
							feed_dict = {tfStuff.X: vali_x, tfStuff.Y_: vali_y, tfStuff.Hin: vali_nullstate, tfStuff.pkeep: 1.0,  # no dropout for validation
										 tfStuff.batchsize: valiBatchSize}
							ls, acc, smm = sess.run([tfStuff.batchloss, tfStuff.accuracy, tfStuff.summaries], feed_dict=feed_dict)
							if verbosity >= 1: txt.print_validation_stats(ls, acc)
							# save validation data for Tensorboard
							self.tbStuff.validation_writer.add_summary(smm, self.step)

						# display a short text generated with the current weights and biases (every genFreq batches)
						if nSteps % genFreq == 0 or isFirstStepInThisFitCall:
							txt.print_text_generation_header()
							# seed with the letter "K" and a zero state, then feed
							# every sampled character back in as the next input
							ry = np.array([[txt.convert_from_alphabet(ord("K"))]])
							rh = np.zeros([1, stateSize])
							for _ in range(1000):
								ryo, rh = sess.run([tfStuff.Yo, tfStuff.H], feed_dict={tfStuff.X: ry, tfStuff.pkeep: 1.0, tfStuff.Hin: rh, tfStuff.batchsize: 1})
								rc = txt.sample_from_probabilities(ryo, topn=10 if epoch <= 1 else 2)
								print(chr(txt.convert_to_alphabet(rc)), end="")
								ry = np.array([[rc]])
							txt.print_text_generation_footer()

						if isFirstStepInThisFitCall:
							# advance the progress bar to where a resumed run
							# stands within the current display interval
							for _ in range(nSteps % displayFreq):
								progress.step()
							isFirstStepInThisFitCall = False

						# save a checkpoint (every saveFreq batches)
						if nSteps % saveFreq == 0:
							self.save(alreadyInGraph=True)

						# display progress bar
						progress.step(reset=isDisplayStep)

						# loop state around
						tfStuff.istate = ostate
						self.step += charsPerBatch
						self.curEpoch = epoch
						self.curBatch = batch
		except KeyboardInterrupt:
			print("\npressed ctrl-c, saving")
			self.save()