Example #1
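generate_train_valid_datasets() assembles the training and validation inputs: with read_cache set it reloads previously written cache files (verifying that the cached event length matches sequence_len), otherwise it parses the tfrecord files, falling back to the training set when no validation file is given. It assumes module-level imports of sys plus the FLAGS object and the read_cache_dataset/read_tfrecord helpers.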
def generate_train_valid_datasets():
    if FLAGS.read_cache:
        train_ds = read_cache_dataset(FLAGS.train_cache)
        valid_ds = read_cache_dataset(FLAGS.valid_cache)
        if train_ds.event.shape[1] != FLAGS.sequence_len:
            raise ValueError(
                "The event length %d of the cached training dataset is inconsistent with the given sequence_len %d"
                % (train_ds.event.shape[1], FLAGS.sequence_len))
        if valid_ds.event.shape[1] != FLAGS.sequence_len:
            raise ValueError(
                "The event length %d of the cached validation dataset is inconsistent with the given sequence_len %d"
                % (valid_ds.event.shape[1], FLAGS.sequence_len))
        return train_ds, valid_ds
    sys.stdout.write("Begin reading training dataset.\n")
    train_ds = read_tfrecord(FLAGS.data_dir, 
                             FLAGS.tfrecord, 
                             FLAGS.train_cache,
                             FLAGS.sequence_len, 
                             k_mer=FLAGS.k_mer,
                             max_segments_num=FLAGS.segments_num)
    sys.stdout.write("Begin reading validation dataset.\n")
    if FLAGS.validation is not None:
        valid_ds = read_tfrecord(FLAGS.data_dir, 
                                 FLAGS.validation,
                                 FLAGS.valid_cache,
                                 FLAGS.sequence_len, 
                                 k_mer=FLAGS.k_mer,
                                 max_segments_num=FLAGS.segments_num)
    else:
        valid_ds = train_ds
    return train_ds, valid_ds
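Both examples read their hyperparameters from a module-level FLAGS object defined elsewhere in the source file. Below is a minimal sketch of how those flags might be declared with TF1's tf.app.flags; the flag names mirror the ones the code reads, but every default is hypothetical.

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
# All defaults below are placeholders; only the flag names come from the snippets.
tf.app.flags.DEFINE_boolean('read_cache', False, 'Reload previously cached datasets.')
tf.app.flags.DEFINE_string('data_dir', '/tmp/data', 'Directory holding the tfrecord files.')
tf.app.flags.DEFINE_string('tfrecord', 'train.tfrecords', 'Training tfrecord file.')
tf.app.flags.DEFINE_string('validation', None, 'Validation tfrecord file, if any.')
tf.app.flags.DEFINE_string('train_cache', '/tmp/train.cache', 'Cache file for the training set.')
tf.app.flags.DEFINE_string('valid_cache', '/tmp/valid.cache', 'Cache file for the validation set.')
tf.app.flags.DEFINE_integer('sequence_len', 512, 'Length of each signal segment.')
tf.app.flags.DEFINE_integer('k_mer', 1, 'k-mer size of the output labels.')
tf.app.flags.DEFINE_integer('segments_num', None, 'Cap on the number of segments to read.')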
Example #2
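train() builds the TF1 graph for CTC training: placeholders for the raw signal batch, the per-read lengths, and the sparse label tensor; the model's inference, loss, and optimizer ops; then a session loop that alternates optimization steps with periodic error evaluation, checkpointing, and summary writes. It additionally assumes the os, time, and model modules at module level.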
def train():
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
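    # CTC labels arrive as a sparse (indices, values, dense_shape) triple,
    # so three placeholders are composed into a single SparseTensor.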
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name,
                                  'model.json')
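    # Retraining reuses the configuration stored next to the checkpoints;
    # a fresh run reads the file passed via FLAGS.configure.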
    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = FLAGS.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    FLAGS.sequence_len,
                                    configure=config)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(FLAGS.step_rate,
                          FLAGS.max_steps,
                          global_step=global_step)
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    if not FLAGS.retrain:
        sess.run(init)
        print("Model init finished, begin loading data.\n")
    else:
        saver.restore(
            sess,
            tf.train.latest_checkpoint(
                os.path.join(FLAGS.log_dir, FLAGS.model_name)))
        print("Model restore finished, begin loading data.\n")
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.log_dir, FLAGS.model_name, 'summary'), sess.graph)
    model.save_model(default_config, config)
    train_ds = read_tfrecord(FLAGS.data_dir,
                             FLAGS.tfrecord,
                             FLAGS.cache_file,
                             FLAGS.sequence_len,
                             k_mer=FLAGS.k_mer,
                             max_segments_num=FLAGS.segments_num)
    start = time.time()
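    # Main loop: one optimization step per iteration, with evaluation,
    # checkpointing, and summary writes every 10 steps.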
    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            # Integer division: the model downsamples the signal by `ratio`,
            # and the placeholder expects int32 lengths.
            seq_length: seq_len // ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            # No separate validation set in this snippet: the error is
            # estimated on a fresh batch drawn from the training data.
            valid_x, valid_len, valid_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len // ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f Elapsed Time/step: %5.3f"
                % (i, FLAGS.max_steps, train_ds.epochs_completed,
                   train_ds.index_in_epoch, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       os.path.join(FLAGS.log_dir, FLAGS.model_name,
                                    'model.ckpt'),
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    saver.save(sess,
               os.path.join(FLAGS.log_dir, FLAGS.model_name, 'final.ckpt'),
               global_step=global_step_val)
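For reference, next_batch() must return the labels as the (indices, values, shape) triple that the three placeholders above expect. Below is a minimal sketch of how such a triple can be built from a list of per-read label sequences; the helper name is hypothetical and not part of the original code.

import numpy as np

def dense_to_sparse_tuple(sequences):
    # Hypothetical helper: convert a list of label sequences into the
    # (indices, values, dense_shape) triple fed to y_indexs/y_values/y_shape.
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend([(n, t) for t in range(len(seq))])
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64).reshape(-1, 2)
    values = np.asarray(values, dtype=np.int32)
    max_len = int(indices[:, 1].max()) + 1 if len(values) else 0
    shape = np.asarray([len(sequences), max_len], dtype=np.int64)
    return indices, values, shape

# Example: two reads with labels [0, 1, 2] and [3, 2].
idx, vals, shp = dense_to_sparse_tuple([[0, 1, 2], [3, 2]])
# shp == [2, 3]; vals == [0, 1, 2, 3, 2]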