Example #1
def train():
    train_ds, valid_ds = read_raw_data_sets(FLAGS.data_dir,
                                            FLAGS.sequence_len,
                                            valid_reads_num=10000)
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)

    logits = inference(x)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss)
    error = prediction(logits, seq_length, y)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.device('/gpu:0'):
        sess = tf.Session()
        sess.run(init)
        # saver.restore(sess, tf.train.latest_checkpoint('log/cnn'))
        for i in range(FLAGS.max_steps):
            batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = batch_y
            loss_val, _ = sess.run([ctc_loss, opt],
                                   feed_dict={x: batch_x,
                                              seq_length: seq_len,
                                              y_indexs: indxs,
                                              y_values: values,
                                              y_shape: shape})
            if i % 10 == 0:
                error_val = sess.run(error,
                                     feed_dict={x: batch_x,
                                                seq_length: seq_len,
                                                y_indexs: indxs,
                                                y_values: values,
                                                y_shape: shape})
                print("Epoch %d, batch number %d, loss: %5.2f edit_distance: %5.2f"
                      % (train_ds.epochs_completed, train_ds.index_in_epoch,
                         loss_val, error_val))
                saver.save(sess, "log/cnn/model.ckpt", global_step=i)
        saver.save(sess, "log/cnn/final.ckpt")
Example #2
def train(hparam):
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32, shape=[hparam.batch_size, hparam.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[hparam.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss, hparam.step_rate, global_step=global_step)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    save_model(hparam.log_dir, hparam.model_name)
    if not hparam.retrain:
        sess.run(init)
        print("Model initialized, begin loading data.\n")
    else:
        saver.restore(sess,
                      tf.train.latest_checkpoint(hparam.log_dir + hparam.model_name))
        print("Model restored, begin loading data.\n")
    summary_writer = tf.summary.FileWriter(
        hparam.log_dir + hparam.model_name + '/summary/', sess.graph)

    train_ds = read_raw_data_sets(hparam.data_dir,
                                  hparam.cache_dir,
                                  hparam.sequence_len,
                                  k_mer=hparam.kmer)
    start = time.time()
    for i in range(hparam.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(hparam.batch_size)
        indxs, values, shape = batch_y
        # Integer division keeps the fed lengths integral under Python 3.
        feed_dict = {x: batch_x, seq_length: seq_len // ratio,
                     y_indexs: indxs, y_values: values, y_shape: shape,
                     training: True}
        loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = train_ds.next_batch(hparam.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {x: valid_x, seq_length: valid_len // ratio,
                         y_indexs: indxs, y_values: values, y_shape: shape,
                         training: True}
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print("Step %d/%d Epoch %d, batch number %d, loss: %5.3f "
                  "edit_distance: %5.3f Elapsed Time/step: %5.3f"
                  % (i, hparam.max_steps, train_ds.epochs_completed,
                     train_ds.index_in_epoch, loss_val, error_val,
                     (end - start) / (i + 1)))
            saver.save(sess,
                       hparam.log_dir + hparam.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (hparam.log_dir + hparam.model_name))
    print("Reads number %d" % train_ds.reads_n)
    saver.save(sess,
               hparam.log_dir + hparam.model_name + '/final.ckpt',
               global_step=global_step_val)
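
None of these examples include train_step itself. A plausible minimal version that matches how it is called above, assuming a plain Adam optimizer (the project's actual optimizer and learning-rate schedule may differ):

def train_step(loss_tensor, step_rate=1e-3, global_step=None):
    # minimize() applies the gradients and, when global_step is given,
    # increments it once per training step.
    return tf.train.AdamOptimizer(step_rate).minimize(loss_tensor,
                                                      global_step=global_step)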
Example #3
def train():
    training = tf.placeholder(tf.bool)
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    save_model()
    if not FLAGS.retrain:
        sess.run(init)
        print("Model initialized, begin loading data.\n")
    else:
        saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model restored, begin loading data.\n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)

    train_ds, valid_ds = read_raw_data_sets(FLAGS.data_dir,
                                            FLAGS.sequence_len,
                                            valid_reads_num=10000,
                                            k_mer=FLAGS.k_mer)

    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            seq_length: seq_len // ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
        if i % 10 == 0:
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len // ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            print "Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f"\
            %(train_ds.epochs_completed,train_ds.index_in_epoch,loss_val,error_val)
            saver.save(sess, FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       i)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, i)
            summary_writer.flush()

    saver.save(sess, FLAGS.log_dir + FLAGS.model_name + '/final.ckpt')
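
The loss helper is likewise not shown. A minimal sketch built on tf.nn.ctc_loss, assuming inference emits batch-major logits (the transpose mirrors the one in prediction, Example #5 below); the real helper may add regularization or summaries beyond this:

def loss(logits, seq_len, label):
    # tf.nn.ctc_loss expects time-major input [max_time, batch, classes].
    logits = tf.transpose(logits, perm=[1, 0, 2])
    ctc = tf.reduce_mean(tf.nn.ctc_loss(label, logits, seq_len))
    tf.summary.scalar('loss', ctc)
    return ctc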
Example #4
def train():
    LOG.info('Building model')
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss, global_step=global_step)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    save_model()
    if not FLAGS.retrain:
        LOG.info('Initializing model')
        sess.run(init)
    else:
        LOG.info('Restoring model')
        saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)

    LOG.info('Reading data')
    train_ds = read_raw_data_sets(FLAGS.data_dir,
                                  FLAGS.cache_dir,
                                  FLAGS.sequence_len,
                                  k_mer=FLAGS.k_mer,
                                  max_files=FLAGS.max_files,
                                  max_reads=FLAGS.max_samples,
                                  log=LOG.info)

    num_samples = train_ds.reads_n
    LOG.info('Number of training samples: %d' % num_samples)

    LOG.info('Starting training')
    start = time.time()
    num_steps = FLAGS.max_steps

    for i in range(num_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            seq_length: seq_len // ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len // ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            msg = "Step %d/%d Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f Elapsed Time/step: %5.3f"
            msg %= (i, num_steps, train_ds.epochs_completed,
                    train_ds.index_in_epoch, loss_val, error_val,
                    (end - start) / (i + 1))
            print(msg)
            saver.save(sess,
                       FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()

    global_step_val = tf.train.global_step(sess, global_step)
    chk_path = os.path.join(FLAGS.log_dir, FLAGS.model_name, 'final.ckpt')
    LOG.info('Saving final model to "%s"' % chk_path)
    saver.save(sess, chk_path, global_step=global_step_val)

    LOG.info('Number of processed reads: %d' % train_ds.reads_n)
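
Because the checkpoints above are saved with global_step, the step number is appended to the file name (e.g. final.ckpt-20000). A minimal restore sketch for later evaluation, assuming the same graph has been rebuilt in the current process:

# latest_checkpoint resolves the newest prefix from the directory's
# 'checkpoint' index file.
ckpt = tf.train.latest_checkpoint(os.path.join(FLAGS.log_dir, FLAGS.model_name))
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, ckpt)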
Example #5
def prediction(logits, seq_length, label):
    # ctc_beam_search_decoder requires input shape [max_time, batch_size, num_classes]
    logits = tf.transpose(logits, perm=[1, 0, 2])
    predict = tf.to_int32(
        tf.nn.ctc_beam_search_decoder(logits, seq_length,
                                      merge_repeated=False)[0][0])
    error = tf.reduce_sum(tf.edit_distance(
        predict, label, normalize=False)) / tf.to_float(tf.size(label.values))
    tf.summary.scalar('Error_rate', error)
    return error


"""Copy the train function here"""
train_ds, valid_ds = read_raw_data_sets(FLAGS.data_dir,
                                        FLAGS.sequence_len,
                                        valid_reads_num=100)
with tf.device('/gpu:0'):
    training = tf.placeholder(tf.bool)
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    logits = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss)
    error = prediction(logits, seq_length, y)
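
prediction only reports an aggregate edit distance. To inspect the decoded reads themselves, the SparseTensor returned by the beam search can be densified; a sketch, assuming access to the predict tensor from prediction above (the -1 padding value and the base mapping are assumptions):

dense_pred = tf.sparse_tensor_to_dense(predict, default_value=-1)
# pred_val = sess.run(dense_pred, feed_dict=feed_dict)
# Each row is one read; drop the -1 padding, then map the remaining
# integers to bases (e.g. 0..3 -> A/C/G/T).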
Example #6
def train():
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss, global_step=global_step)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    #    save_model()
    if not FLAGS.retrain:
        sess.run(init)
        print("Model initialized, begin loading data.\n")
    else:
        saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model restored, begin loading data.\n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)

    h5py_file_path = FLAGS.cache_dir
    if h5py_file_path is None:
        h5py_file_path = tempfile.mkdtemp() + '/temp_record.hdf5'
    else:
        try:
            os.remove(os.path.abspath(h5py_file_path))
        except OSError:
            # The cache file may not exist yet; that's fine.
            pass
        if not os.path.isdir(os.path.dirname(os.path.abspath(h5py_file_path))):
            os.mkdir(os.path.dirname(os.path.abspath(h5py_file_path)))

    hdf5_record = h5py.File(h5py_file_path, "a")

    train_ds = read_raw_data_sets(FLAGS.data_dir, hdf5_record,
                                  FLAGS.sequence_len, FLAGS.k_mer,
                                  FLAGS.alphabet, FLAGS.jump,
                                  FLAGS.smooth_window, FLAGS.skip_step,
                                  FLAGS.normalize)
    start = time.time()
    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            seq_length: seq_len // ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
            seq_length: valid_len // ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print "Step %d/%d Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f Elapsed Time/step: %5.3f"\
            %(i,FLAGS.max_steps,train_ds.epochs_completed,train_ds.index_in_epoch,loss_val,error_val,(end-start)/(i+1))
            saver.save(sess,
                       FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print "Model %s saved." % (FLAGS.log_dir + FLAGS.model_name)
    print "Reads number %d" % (train_ds.reads_n)
    saver.save(sess,
               FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
               global_step=global_step_val)
    hdf5_record.close()
    os.remove(os.path.abspath(h5py_file_path))
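
Note that the HDF5 cache is only removed here on a clean exit. A try/finally guard, sketched below, deletes it even when training is interrupted mid-run:

hdf5_record = h5py.File(h5py_file_path, "a")
try:
    # ... build train_ds and run the training loop exactly as above ...
    pass
finally:
    # Close and delete the cache whether or not training raised.
    hdf5_record.close()
    os.remove(os.path.abspath(h5py_file_path))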