import os
import sys
import time

import tensorflow as tf

# `FLAGS`, `model`, `read_tfrecord`, and `read_cache_dataset` are defined
# elsewhere in this module / package.


def generate_train_valid_datasets():
    """Load the training and validation datasets, from cache or from TFRecord."""
    if FLAGS.read_cache:
        train_ds = read_cache_dataset(FLAGS.train_cache)
        valid_ds = read_cache_dataset(FLAGS.valid_cache)
        if train_ds.event.shape[1] != FLAGS.sequence_len:
            raise ValueError(
                "The event length of the cached training dataset %d is "
                "inconsistent with the given sequence_len %d"
                % (train_ds.event.shape[1], FLAGS.sequence_len))
        if valid_ds.event.shape[1] != FLAGS.sequence_len:
            raise ValueError(
                "The event length of the cached validation dataset %d is "
                "inconsistent with the given sequence_len %d"
                % (valid_ds.event.shape[1], FLAGS.sequence_len))
        return train_ds, valid_ds
    sys.stdout.write("Begin reading training dataset.\n")
    train_ds = read_tfrecord(FLAGS.data_dir,
                             FLAGS.tfrecord,
                             FLAGS.train_cache,
                             FLAGS.sequence_len,
                             k_mer=FLAGS.k_mer,
                             max_segments_num=FLAGS.segments_num)
    sys.stdout.write("Begin reading validation dataset.\n")
    if FLAGS.validation is not None:
        valid_ds = read_tfrecord(FLAGS.data_dir,
                                 FLAGS.validation,
                                 FLAGS.valid_cache,
                                 FLAGS.sequence_len,
                                 k_mer=FLAGS.k_mer,
                                 max_segments_num=FLAGS.segments_num)
    else:
        # No separate validation set was given: fall back to the training data.
        valid_ds = train_ds
    return train_ds, valid_ds
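
# Usage sketch (illustrative, not part of the original pipeline): with the
# flags populated, the two datasets are loaded once and then iterated with
# next_batch(), which is assumed to yield
# (signal_batch, sequence_lengths, (indices, values, shape)):
#
#   train_ds, valid_ds = generate_train_valid_datasets()
#   batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
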
def train():
    """Build the CTC model graph and run the training loop."""
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    # Labels are fed as the three components of a sparse tensor (see the
    # helper sketch below for the expected layout).
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name, 'model.json')
    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model JSON file not found in the model log directory.")
    else:
        config_file = FLAGS.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    FLAGS.sequence_len,
                                    configure=config)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(FLAGS.step_rate,
                          FLAGS.max_steps,
                          global_step=global_step)
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    if not FLAGS.retrain:
        sess.run(init)
        print("Model init finished, begin loading data.\n")
    else:
        saver.restore(
            sess,
            tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model loading finished, begin loading data.\n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)
    train_ds = read_tfrecord(FLAGS.data_dir,
                             FLAGS.tfrecord,
                             FLAGS.cache_file,
                             FLAGS.sequence_len,
                             k_mer=FLAGS.k_mer,
                             max_segments_num=FLAGS.segments_num)
    start = time.time()
    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            # The model downsamples the signal by `ratio`; use integer
            # division so the lengths stay valid for the int32 placeholder.
            seq_length: seq_len // ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            # Monitoring pass: draw another batch and report the edit distance.
            valid_x, valid_len, valid_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len // ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: False  # run the monitoring pass in inference mode
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d Epoch %d, batch number %d, loss: %5.3f "
                "edit_distance: %5.3f Elapsed Time/step: %5.3f"
                % (i, FLAGS.max_steps, train_ds.epochs_completed,
                   train_ds.index_in_epoch, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    saver.save(sess,
               FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
               global_step=global_step_val)
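
# Illustrative helper (an assumption, not part of the original module): the
# y_indexs / y_values / y_shape placeholders in train() take CTC labels as a
# sparse triple. next_batch() is assumed to already return labels in this
# layout; the sketch below only documents how such a triple is built from a
# list of per-read integer label sequences.
def _dense_labels_to_sparse_triple(label_seqs):
    """Convert [[labels...], ...] into (indices, values, shape) suitable for
    feeding a tf.SparseTensor: one (row, col) index per label entry."""
    import numpy as np
    indices = []
    values = []
    for row, seq in enumerate(label_seqs):
        for col, label in enumerate(seq):
            indices.append((row, col))
            values.append(label)
    max_len = max((len(seq) for seq in label_seqs), default=0)
    shape = (len(label_seqs), max_len)
    return (np.asarray(indices, dtype=np.int64),
            np.asarray(values, dtype=np.int32),
            np.asarray(shape, dtype=np.int64))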