def main():
    """Generate a text sample from a trained character-level model.

    Restores the Model checkpoint named by --model_name, seeds a sliding
    character window with --seed, then repeatedly feeds the window through
    the classifier and appends the argmax-predicted character, writing the
    result to --output_file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--vocabulary_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, required=True)
    parser.add_argument('--seed', type=str, default="Once upon a time, ")
    parser.add_argument('--sample_length', type=int, default=1500)
    parser.add_argument('--log_frequency', type=int, default=100)
    args = parser.parse_args()

    model_name = args.model_name
    vocabulary_file = args.vocabulary_file
    output_file = args.output_file
    # FIX: the original called args.seed.decode('utf-8'), which only works on
    # Python 2 byte strings; argparse already yields text (unicode) on Python 3.
    seed = args.seed
    sample_length = args.sample_length
    log_frequency = args.log_frequency

    model = Model(model_name)
    model.restore()
    classifier = model.get_classifier()

    vocabulary = Vocabulary()
    vocabulary.retrieve(vocabulary_file)

    # FIX: use a context manager so the output file is closed even if
    # sampling raises; the original closed it manually at the end only.
    with codecs.open(output_file, 'w', 'utf_8') as sample_file:
        # Sliding window of the model's sequence length, left-padded with
        # spaces so the seed occupies the right-hand end of the window.
        stack = deque([])
        for _ in range(0, model.sequence_length - len(seed)):
            stack.append(u' ')
        for char in seed:
            if char not in vocabulary.vocabulary:
                # Unknown seed characters are replaced with a space rather
                # than aborting, so arbitrary seeds still work.
                print(char, "is not in vocabulary file")
                char = u' '
            stack.append(char)
            sample_file.write(char)

        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            saver = tf.train.Saver(tf.global_variables())
            ckpt = tf.train.get_checkpoint_state(model_name)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            for i in range(0, sample_length):
                # Encode the current window with the vocabulary's binary
                # (one-hot-style) encoding; batch dimension of 1.
                vector = [vocabulary.binary_vocabulary[char] for char in stack]
                vector = np.array([vector])
                prediction = sess.run(classifier, feed_dict={model.x: vector})
                predicted_char = vocabulary.char_lookup[np.argmax(prediction)]

                # Slide the window forward by the predicted character.
                stack.popleft()
                stack.append(predicted_char)
                sample_file.write(predicted_char)

                if i % log_frequency == 0:
                    # // keeps the integer-percentage behaviour the original
                    # Python 2 '/' produced.
                    print("Progress: {}%".format((i * 100) // sample_length))

    print("Sample saved in {}".format(output_file))
def main():
    """Train a character-level model from the command line.

    Builds a Model over the Batch's vocabulary, optimizes a mean-squared
    cost with Adam until the dataset has been passed over --epoch times,
    logs loss/accuracy every --log_frequency iterations, and saves a
    checkpoint under the --model_name directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--training_file', type=str, required=True)
    parser.add_argument('--vocabulary_file', type=str, required=True)
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--sequence_length', type=int, default=50)
    parser.add_argument('--log_frequency', type=int, default=100)
    # BUG FIX: was type=int with a float default — any user-supplied value
    # such as "--learning_rate 0.002" would be rejected/truncated by argparse.
    parser.add_argument('--learning_rate', type=float, default=0.002)
    parser.add_argument('--units_number', type=int, default=128)
    parser.add_argument('--layers_number', type=int, default=2)
    args = parser.parse_args()

    training_file = args.training_file
    vocabulary_file = args.vocabulary_file
    model_name = args.model_name
    epoch = args.epoch
    batch_size = args.batch_size
    sequence_length = args.sequence_length
    log_frequency = args.log_frequency
    learning_rate = args.learning_rate

    batch = Batch(training_file, vocabulary_file, batch_size, sequence_length)
    # Input and output width are both the vocabulary size (one encoding slot
    # per character class).
    input_number = batch.vocabulary.size
    classes_number = batch.vocabulary.size
    units_number = args.units_number
    layers_number = args.layers_number

    print("Start training with epoch: {}, batch_size: {}, log_frequency: {},"
          "learning_rate: {}".format(epoch, batch_size, log_frequency,
                                     learning_rate))

    if not os.path.exists(model_name):
        os.makedirs(model_name)

    model = Model(model_name)
    model.build(input_number, sequence_length, layers_number, units_number,
                classes_number)
    classifier = model.get_classifier()

    # NOTE(review): squared error over the class outputs; cross-entropy is
    # the more usual choice for classification — kept as-is to preserve
    # training behaviour.
    cost = tf.reduce_mean(tf.square(classifier - model.y))
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate).minimize(cost)

    expected_prediction = tf.equal(tf.argmax(classifier, 1),
                                   tf.argmax(model.y, 1))
    accuracy = tf.reduce_mean(tf.cast(expected_prediction, tf.float32))

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        iteration = 0
        # Batch tracks how many full passes over the dataset have completed;
        # train until --epoch passes are done.
        while batch.dataset_full_passes < epoch:
            iteration += 1
            batch_x, batch_y = batch.get_next_batch()
            batch_x = batch_x.reshape((batch_size, sequence_length,
                                       input_number))

            sess.run(optimizer, feed_dict={model.x: batch_x,
                                           model.y: batch_y})
            if iteration % log_frequency == 0:
                # Evaluate loss and accuracy in a single run to avoid a
                # redundant second forward pass (original ran them separately).
                loss, acc = sess.run([cost, accuracy],
                                     feed_dict={model.x: batch_x,
                                                model.y: batch_y})
                print("Iteration {}, batch loss: {:.6f}, "
                      "training accuracy: {:.5f}".format(
                          iteration * batch_size, loss, acc))
        batch.clean()
        print("Optimization done")

        saver = tf.train.Saver(tf.global_variables())
        checkpoint_path = "{}/{}.ckpt".format(model_name, model_name)
        saver.save(sess, checkpoint_path, global_step=iteration * batch_size)
        print("Model saved in {}".format(model_name))