def train():
    """Build the CTC training graph and run the training loop.

    Reads the data set from ``FLAGS.data_dir``, builds the graph through the
    module helpers ``inference``, ``loss``, ``train_step`` and ``prediction``,
    then runs ``FLAGS.max_steps`` optimisation steps.  Every 10th step the
    error op is evaluated on the batch that was just trained on, progress is
    printed and a checkpoint is written to ``log/cnn/model.ckpt``.  After the
    loop a final checkpoint is written to ``log/cnn/final.ckpt``.

    Returns:
        None.
    """
    # Only the training split is used by this function.
    train_ds, _ = read_raw_data_sets(
        FLAGS.data_dir, FLAGS.sequence_len, valid_reads_num=10000)

    x = tf.placeholder(
        tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    # The sparse label tensor is fed through its three components.
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)

    logits = inference(x)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.device('/gpu:0'):
        sess = tf.Session()
        try:
            sess.run(init)
            # To resume instead of starting fresh, restore here:
            # saver.restore(sess, tf.train.latest_checkpoint('log/cnn'))
            for i in range(FLAGS.max_steps):
                batch_x, seq_len, batch_y = train_ds.next_batch(
                    FLAGS.batch_size)
                indxs, values, shape = batch_y
                # The same batch feeds both the training step and the
                # periodic error evaluation, so build the dict once.
                feed_dict = {
                    x: batch_x,
                    seq_length: seq_len,
                    y_indexs: indxs,
                    y_values: values,
                    y_shape: shape,
                }
                loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
                if i % 10 == 0:
                    error_val = sess.run(error, feed_dict=feed_dict)
                    print("Epoch %d, batch number %d, loss: %5.2f edit_distance: %5.2f"
                          % (train_ds.epochs_completed,
                             train_ds.index_in_epoch, loss_val, error_val))
                    saver.save(sess, "log/cnn/model.ckpt", i)
            saver.save(sess, "log/cnn/final.ckpt")
        finally:
            # Release the session's resources even if training fails.
            sess.close()
def train(hparam):
    """Build the CTC training graph and train using settings from ``hparam``.

    Every 10th step an extra batch is drawn from the same training data set
    to evaluate the error op; progress is printed, a checkpoint is written
    and the merged summaries are logged.  A final checkpoint is written after
    the loop.

    Args:
        hparam: Settings object.  The attributes read here are
            ``batch_size``, ``sequence_len``, ``step_rate``, ``log_dir``,
            ``model_name``, ``retrain``, ``data_dir``, ``cache_dir``,
            ``kmer`` and ``max_steps``.

    Returns:
        None.
    """
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable(
        'global_step', trainable=False, shape=(), dtype=tf.int32,
        initializer=tf.zeros_initializer())
    x = tf.placeholder(
        tf.float32, shape=[hparam.batch_size, hparam.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[hparam.batch_size])
    # The sparse label tensor is fed through its three components.
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)

    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss, hparam.step_rate, global_step=global_step)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    model_dir = hparam.log_dir + hparam.model_name
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    summary_writer = None
    try:
        save_model(hparam.log_dir, hparam.model_name)
        # Explicit comparison kept so that only False/0 selects a fresh
        # initialisation; any other value restores the latest checkpoint.
        if hparam.retrain == False:
            sess.run(init)
            print("Model init finished, begin loading data. \n")
        else:
            saver.restore(sess, tf.train.latest_checkpoint(model_dir))
            print("Model loaded finished, begin loading data. \n")
        summary_writer = tf.summary.FileWriter(
            model_dir + '/summary/', sess.graph)

        train_ds = read_raw_data_sets(
            hparam.data_dir, hparam.cache_dir, hparam.sequence_len,
            k_mer=hparam.kmer)
        start = time.time()
        for i in range(hparam.max_steps):
            batch_x, seq_len, batch_y = train_ds.next_batch(hparam.batch_size)
            indxs, values, shape = batch_y
            # Lengths are divided by the ratio returned by inference()
            # (presumably the network's down-sampling factor -- confirm).
            feed_dict = {
                x: batch_x,
                seq_length: seq_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True,
            }
            loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
            if i % 10 == 0:
                global_step_val = tf.train.global_step(sess, global_step)
                # The evaluation batch comes from the training set and is
                # run with training=True, exactly like a training batch.
                valid_x, valid_len, valid_y = train_ds.next_batch(
                    hparam.batch_size)
                indxs, values, shape = valid_y
                feed_dict = {
                    x: valid_x,
                    seq_length: valid_len / ratio,
                    y_indexs: indxs,
                    y_values: values,
                    y_shape: shape,
                    training: True,
                }
                error_val = sess.run(error, feed_dict=feed_dict)
                end = time.time()
                # Single-argument print() behaves the same on Python 2 and 3.
                print("Step %d/%d Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f Elapsed Time/step: %5.3f"
                      % (i, hparam.max_steps, train_ds.epochs_completed,
                         train_ds.index_in_epoch, loss_val, error_val,
                         (end - start) / (i + 1)))
                saver.save(sess, model_dir + '/model.ckpt',
                           global_step=global_step_val)
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(
                    summary_str, global_step=global_step_val)
                summary_writer.flush()

        global_step_val = tf.train.global_step(sess, global_step)
        print("Model %s saved." % (model_dir))
        print("Reads number %d" % (train_ds.reads_n))
        saver.save(sess, model_dir + '/final.ckpt',
                   global_step=global_step_val)
    finally:
        # Release the event file and session even if training fails.
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
def train():
    """Build the CTC training graph and train with a held-out validation set.

    Settings come from the module-level ``FLAGS``.  Every 10th step a batch
    from the validation split is used to evaluate the error op; progress is
    printed, a checkpoint is written (numbered by the loop index) and the
    merged summaries are logged.  A final checkpoint is written after the
    loop.

    Returns:
        None.
    """
    training = tf.placeholder(tf.bool)
    x = tf.placeholder(
        tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    # The sparse label tensor is fed through its three components.
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)

    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    model_dir = FLAGS.log_dir + FLAGS.model_name
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    summary_writer = None
    try:
        save_model()
        # Explicit comparison kept so that only False/0 selects a fresh
        # initialisation; any other value restores the latest checkpoint.
        if FLAGS.retrain == False:
            sess.run(init)
            print("Model init finished, begin loading data. \n")
        else:
            saver.restore(sess, tf.train.latest_checkpoint(model_dir))
            print("Model loaded finished, begin loading data. \n")
        summary_writer = tf.summary.FileWriter(
            model_dir + '/summary/', sess.graph)

        train_ds, valid_ds = read_raw_data_sets(
            FLAGS.data_dir, FLAGS.sequence_len, valid_reads_num=10000,
            k_mer=FLAGS.k_mer)
        for i in range(FLAGS.max_steps):
            batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = batch_y
            # Lengths are divided by the ratio returned by inference()
            # (presumably the network's down-sampling factor -- confirm).
            feed_dict = {
                x: batch_x,
                seq_length: seq_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True,
            }
            loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
            if i % 10 == 0:
                valid_x, valid_len, valid_y = valid_ds.next_batch(
                    FLAGS.batch_size)
                indxs, values, shape = valid_y
                # NOTE(review): the validation batch is run with
                # training=True -- confirm this is intended.
                feed_dict = {
                    x: valid_x,
                    seq_length: valid_len / ratio,
                    y_indexs: indxs,
                    y_values: values,
                    y_shape: shape,
                    training: True,
                }
                error_val = sess.run(error, feed_dict=feed_dict)
                # Single-argument print() behaves the same on Python 2 and 3.
                print("Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f"
                      % (train_ds.epochs_completed, train_ds.index_in_epoch,
                         loss_val, error_val))
                saver.save(sess, model_dir + '/model.ckpt', i)
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, i)
                summary_writer.flush()
        saver.save(sess, model_dir + '/final.ckpt')
    finally:
        # Release the event file and session even if training fails.
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
def train():
    """Build the CTC training graph and train, reporting through ``LOG``.

    Settings come from the module-level ``FLAGS``.  Every 10th step an extra
    batch is drawn from the same training data set to evaluate the error op;
    progress is printed, a checkpoint is written and the merged summaries
    are logged.  After the loop the global step is read back and the final
    checkpoint is saved under that step.

    Returns:
        None.
    """
    LOG.info('Building model')
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable(
        'global_step', trainable=False, shape=(), dtype=tf.int32,
        initializer=tf.zeros_initializer())
    x = tf.placeholder(
        tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    # The sparse label tensor is fed through its three components.
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)

    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss, global_step=global_step)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    summary_writer = None
    try:
        save_model()
        # Explicit comparison kept so that only False/0 selects a fresh
        # initialisation; any other value restores the latest checkpoint.
        if FLAGS.retrain == False:
            LOG.info('Initializing model')
            sess.run(init)
        else:
            LOG.info('Restoring model')
            saver.restore(
                sess,
                tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        summary_writer = tf.summary.FileWriter(
            FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)

        LOG.info('Reading data')
        train_ds = read_raw_data_sets(FLAGS.data_dir,
                                      FLAGS.cache_dir,
                                      FLAGS.sequence_len,
                                      k_mer=FLAGS.k_mer,
                                      max_files=FLAGS.max_files,
                                      max_reads=FLAGS.max_samples,
                                      log=LOG.info)
        num_samples = train_ds.reads_n
        LOG.info('Number of training samples: %d' % num_samples)

        LOG.info('Starting training')
        start = time.time()
        num_steps = FLAGS.max_steps
        for i in range(num_steps):
            batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = batch_y
            # Lengths are divided by the ratio returned by inference()
            # (presumably the network's down-sampling factor -- confirm).
            feed_dict = {
                x: batch_x,
                seq_length: seq_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True,
            }
            loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
            if i % 10 == 0:
                global_step_val = tf.train.global_step(sess, global_step)
                # The evaluation batch comes from the training set and is
                # run with training=True, exactly like a training batch.
                valid_x, valid_len, valid_y = train_ds.next_batch(
                    FLAGS.batch_size)
                indxs, values, shape = valid_y
                feed_dict = {
                    x: valid_x,
                    seq_length: valid_len / ratio,
                    y_indexs: indxs,
                    y_values: values,
                    y_shape: shape,
                    training: True,
                }
                error_val = sess.run(error, feed_dict=feed_dict)
                end = time.time()
                msg = "Step %d/%d Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f Elapsed Time/step: %5.3f"
                msg %= (i, num_steps, train_ds.epochs_completed,
                        train_ds.index_in_epoch, loss_val, error_val,
                        (end - start) / (i + 1))
                print(msg)
                saver.save(sess,
                           FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                           global_step=global_step_val)
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(
                    summary_str, global_step=global_step_val)
                summary_writer.flush()

        # Read the step counter BEFORE the final save: otherwise the
        # checkpoint is tagged with the value from the last periodic
        # evaluation, and the name is undefined when no step ran at all.
        global_step_val = tf.train.global_step(sess, global_step)
        # NOTE(review): this path is joined with os.path.join while the
        # periodic checkpoints use plain concatenation; the two only agree
        # when FLAGS.log_dir ends with a path separator -- confirm.
        chk_path = os.path.join(FLAGS.log_dir, FLAGS.model_name, 'final.ckpt')
        LOG.info('Saving final model to "%s"' % chk_path)
        saver.save(sess, chk_path, global_step=global_step_val)
        LOG.info('Number of processed reads: %d' % train_ds.reads_n)
    finally:
        # Release the event file and session even if training fails.
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
def prediction(logits, seq_length, label):
    """Build the op computing the edit-distance error of decoded logits.

    The logits are transposed with ``perm=[1, 0, 2]`` because
    ctc_beam_search_decoder requires input shape
    [max_time, batch_size, num_classes] (so the incoming logits are
    presumably batch-major -- confirm against ``inference``).

    Args:
        logits: Network output passed to the beam-search decoder.
        seq_length: Sequence lengths passed to the decoder.
        label: Sparse tensor of reference labels; its ``values`` are counted
            to normalise the error.

    Returns:
        Scalar tensor: the summed, un-normalised edit distance between the
        top decoded path and ``label``, divided by the number of label
        values.  The value is also registered as the ``Error_rate`` summary.
    """
    logits = tf.transpose(logits, perm=[1, 0, 2])
    decoded_paths = tf.nn.ctc_beam_search_decoder(
        logits, seq_length, merge_repeated=False)[0]
    # tf.cast replaces the deprecated tf.to_int32 / tf.to_float helpers.
    predict = tf.cast(decoded_paths[0], tf.int32)
    total_distance = tf.reduce_sum(
        tf.edit_distance(predict, label, normalize=False))
    error = total_distance / tf.cast(tf.size(label.values), tf.float32)
    tf.summary.scalar('Error_rate', error)
    return error


# Copy the train function here
# NOTE(review): the original layout of the statements below was lost; they
# are kept as module-level code following prediction() -- confirm placement.
train_ds, valid_ds = read_raw_data_sets(FLAGS.data_dir,
                                        FLAGS.sequence_len,
                                        valid_reads_num=100)
with tf.device('/gpu:0'):
    training = tf.placeholder(tf.bool)
    x = tf.placeholder(
        tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    # NOTE(review): other call sites unpack ``logits, ratio`` from
    # inference(); here the result is used as-is -- confirm which
    # inference() signature this snippet targets.
    logits = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss)
    error = prediction(logits, seq_length, y)
def train():
    """Build the CTC training graph and train using an HDF5 record cache.

    Settings come from the module-level ``FLAGS``.  The training data is read
    through an HDF5 file located at ``FLAGS.cache_dir`` (any existing file
    there is deleted first) or, when that flag is ``None``, at a fresh
    temporary location.  Every 10th step an extra batch is drawn from the
    same training data set to evaluate the error op; progress is printed, a
    checkpoint is written and the merged summaries are logged.  After a
    successful run the final checkpoint is saved and the HDF5 file removed.

    Returns:
        None.
    """
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable(
        'global_step', trainable=False, shape=(), dtype=tf.int32,
        initializer=tf.zeros_initializer())
    x = tf.placeholder(
        tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    # The sparse label tensor is fed through its three components.
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)

    logits, ratio = inference(x, seq_length, training)
    ctc_loss = loss(logits, seq_length, y)
    opt = train_step(ctc_loss, global_step=global_step)
    error = prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    model_dir = FLAGS.log_dir + FLAGS.model_name
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    summary_writer = None
    hdf5_record = None
    try:
        # save_model()
        # Explicit comparison kept so that only False/0 selects a fresh
        # initialisation; any other value restores the latest checkpoint.
        if FLAGS.retrain == False:
            sess.run(init)
            print("Model init finished, begin loading data. \n")
        else:
            saver.restore(sess, tf.train.latest_checkpoint(model_dir))
            print("Model loaded finished, begin loading data. \n")
        summary_writer = tf.summary.FileWriter(
            model_dir + '/summary/', sess.graph)

        h5py_file_path = FLAGS.cache_dir
        if h5py_file_path is None:
            h5py_file_path = tempfile.mkdtemp() + '/temp_record.hdf5'
        else:
            cache_file = os.path.abspath(h5py_file_path)
            try:
                os.remove(cache_file)
            except OSError:
                # Best effort: there may be no stale cache file to delete.
                pass
            cache_parent = os.path.dirname(cache_file)
            if not os.path.isdir(cache_parent):
                # makedirs also creates missing intermediate directories,
                # where os.mkdir would fail.
                os.makedirs(cache_parent)
        hdf5_record = h5py.File(h5py_file_path, "a")

        train_ds = read_raw_data_sets(
            FLAGS.data_dir, hdf5_record, FLAGS.sequence_len, FLAGS.k_mer,
            FLAGS.alphabet, FLAGS.jump, FLAGS.smooth_window, FLAGS.skip_step,
            FLAGS.normalize)
        start = time.time()
        for i in range(FLAGS.max_steps):
            batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = batch_y
            # Lengths are divided by the ratio returned by inference()
            # (presumably the network's down-sampling factor -- confirm).
            feed_dict = {
                x: batch_x,
                seq_length: seq_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True,
            }
            loss_val, _ = sess.run([ctc_loss, opt], feed_dict=feed_dict)
            if i % 10 == 0:
                global_step_val = tf.train.global_step(sess, global_step)
                # The evaluation batch comes from the training set and is
                # run with training=True, exactly like a training batch.
                valid_x, valid_len, valid_y = train_ds.next_batch(
                    FLAGS.batch_size)
                indxs, values, shape = valid_y
                feed_dict = {
                    x: valid_x,
                    seq_length: valid_len / ratio,
                    y_indexs: indxs,
                    y_values: values,
                    y_shape: shape,
                    training: True,
                }
                error_val = sess.run(error, feed_dict=feed_dict)
                end = time.time()
                # Single-argument print() behaves the same on Python 2 and 3.
                print("Step %d/%d Epoch %d, batch number %d, loss: %5.3f edit_distance: %5.3f Elapsed Time/step: %5.3f"
                      % (i, FLAGS.max_steps, train_ds.epochs_completed,
                         train_ds.index_in_epoch, loss_val, error_val,
                         (end - start) / (i + 1)))
                saver.save(sess, model_dir + '/model.ckpt',
                           global_step=global_step_val)
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(
                    summary_str, global_step=global_step_val)
                summary_writer.flush()

        global_step_val = tf.train.global_step(sess, global_step)
        print("Model %s saved." % (model_dir))
        print("Reads number %d" % (train_ds.reads_n))
        saver.save(sess, model_dir + '/final.ckpt',
                   global_step=global_step_val)
    finally:
        # Close every handle even if training fails part-way through.
        if hdf5_record is not None:
            hdf5_record.close()
        if summary_writer is not None:
            summary_writer.close()
        sess.close()
    # As before, the record file is only deleted after a successful run.
    os.remove(os.path.abspath(h5py_file_path))