def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                            seq_length,
                                            merge_repeated=False,
                                            beam_width=FLAGS.beam)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    sess = tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(config=config))
    if os.path.isdir(FLAGS.input):
        file_list = os.listdir(FLAGS.input)
        file_dir = FLAGS.input
    else:
        file_list = [os.path.basename(FLAGS.input)]
        file_dir = os.path.abspath(os.path.join(FLAGS.input, os.path.pardir))
    for file in file_list:
        file_path = os.path.join(file_dir, file)
        eval_data = read_data_for_eval(file_path,
                                       FLAGS.start,
                                       seg_length=FLAGS.segment_len,
                                       step=FLAGS.jump)
        reads_n = eval_data.reads_n
        for i in range(0, reads_n, FLAGS.batch_size):
            batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size,
                                                       shuffle=False,
                                                       sig_norm=False)
            batch_x = np.pad(batch_x,
                             ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                             mode='constant')
            seq_len = np.pad(seq_len, ((0, FLAGS.batch_size - len(seq_len))),
                             mode='constant')
            feed_dict = {
                x: batch_x,
                seq_length: np.round(seq_len / ratio).astype(np.int32),
                training: False,
            }
            logits_val, predict_val = sess.run([logits, predict],
                                               feed_dict=feed_dict)
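
# A minimal, self-contained sketch of the batch-padding step used above: the
# final batch of a file is usually smaller than FLAGS.batch_size, so it is
# padded with zero rows to keep the placeholder shapes fixed. The names here
# are illustrative, not part of the Chiron API.
def _pad_batch_example():
    import numpy as np
    batch_size = 4
    batch_x = np.ones((3, 5), dtype=np.float32)   # only 3 of 4 rows filled
    seq_len = np.array([5, 5, 5], dtype=np.int32)
    batch_x = np.pad(batch_x, ((0, batch_size - len(batch_x)), (0, 0)),
                     mode='constant')             # -> shape (4, 5), last row all zeros
    seq_len = np.pad(seq_len, (0, batch_size - len(seq_len)), mode='constant')
    return batch_x, seq_len                       # zero-length rows decode to nothing
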
def output_list(x, seq_length):
    training = tf.constant(False, dtype=tf.bool, name='Training')
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = read_config(config_path)
    logits, ratio = inference(x,
                              seq_length,
                              training=training,
                              full_sequence_len=FLAGS.segment_len,
                              configure=model_configure)
    ratio = tf.constant(ratio, dtype=tf.float32, shape=[])
    seq_length_r = tf.cast(
        tf.round(tf.cast(seq_length, dtype=tf.float32) / ratio), tf.int32)
    prob_logits = path_prob(logits)
    predict, log_prob = tf.nn.ctc_beam_search_decoder(
        tf.transpose(logits, perm=[1, 0, 2]),
        seq_length_r,
        merge_repeated=True,
        beam_width=FLAGS.beam_width)
    return predict[0], logits, prob_logits, log_prob
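
# A hedged usage sketch for output_list: build the placeholders, fetch the
# returned ops, and run them in a session. The placeholder shapes mirror the
# graph-construction code elsewhere in this file; this is not a documented
# entry point, and without restoring a checkpoint the decoded output is
# meaningless — the sketch only shows the plumbing.
def _output_list_usage_sketch():
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    best_path, logits, prob_logits, log_prob = output_list(x, seq_length)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # best_path is a tf.SparseTensor holding the top beam for each read.
        decoded = sess.run(best_path,
                           feed_dict={
                               x: np.zeros((FLAGS.batch_size,
                                            FLAGS.segment_len)),
                               seq_length: np.full(FLAGS.batch_size,
                                                   FLAGS.segment_len)
                           })
        return decoded
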
def input_output_list():
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    model_configure = chiron_model.read_config(FLAGS.config_path)
    logits, _ = inference(x,
                          seq_length,
                          training=training,
                          full_sequence_len=FLAGS.segment_len,
                          configure=model_configure)
    predict = tf.nn.ctc_greedy_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                       seq_length,
                                       merge_repeated=True)
    input_dict = {'x': x, 'seq_length': seq_length, 'training': training}
    output_dict = {
        'decoded_indices': predict[0][0].indices,
        'decoded_values': predict[0][0].values,
        'neg_sum_logits': predict[1]
    }
    return input_dict, output_dict
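
# The greedy decoder returns its result as a SparseTensor, which is exported
# above as separate indices/values arrays. A minimal numpy sketch (names are
# illustrative) of regrouping those flat arrays back into one label list per
# read of the batch:
def _regroup_sparse_decode(indices, values, batch_size):
    import numpy as np
    reads = [[] for _ in range(batch_size)]
    for (row, _col), label in zip(np.asarray(indices), np.asarray(values)):
        reads[int(row)].append(int(label))   # row = read index in the batch
    return reads
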
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    logits_fname = tf.placeholder(tf.string, shape=[FLAGS.batch_size])
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape, logits_index.shape,
                seq_length.shape])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(
        logits_queue)
    saver = tf.train.Saver(var_list=tf.trainable_variables() +
                           tf.moving_average_variables())
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
            seq_len = np.asarray([])
            logits_idx = np.asarray([])
            logits_fn = np.asarray([])
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                i = 0
                while eval_data.epochs_completed == 0:
                    current_batch, current_seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size - len(batch_x), shuffle=False)
                    current_n = len(current_batch)
                    batch_x = np.concatenate((batch_x, current_batch), axis=0)
                    seq_len = np.concatenate((seq_len, current_seq_len),
                                             axis=0)
                    logits_idx = np.concatenate((logits_idx, [i] * current_n),
                                                axis=0)
                    logits_fn = np.concatenate((logits_fn, [name] * current_n),
                                               axis=0)
                    i += current_n
                    if len(batch_x) < FLAGS.batch_size:
                        pbars.update(0, progress=i)
                        pbars.update_bar()
                        continue
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: logits_idx,
                        logits_fname: logits_fn,
                    }
                    # Training: set it to True for a temporary fix of the batch
                    # normalization problem:
                    # https://github.com/haotianteng/Chiron/commit/8fce3a3b4dac8e9027396bb8c9152b7b5af953ce
                    # TODO: change the training FLAG back to False after the new
                    # model has been trained.
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
                    seq_len = np.asarray([])
                    logits_idx = np.asarray([])
                    logits_fn = np.asarray([])
                    pbars.update(0, progress=i)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            ### All files have been processed.
            batch_n = len(batch_x)
            if batch_n > 0:
                pad_width = FLAGS.batch_size - batch_n
                batch_x = np.pad(batch_x, ((0, pad_width), (0, 0)),
                                 mode='wrap')
                seq_len = np.pad(seq_len, ((0, pad_width)), mode='wrap')
                logits_idx = np.pad(logits_idx, (0, pad_width),
                                    mode='constant',
                                    constant_values=-1)
                logits_fn = np.pad(logits_fn, (0, pad_width),
                                   mode='constant',
                                   constant_values='')
                sess.run(logits_enqueue,
                         feed_dict={
                             x: batch_x,
                             seq_length: np.round(seq_len / ratio).astype(
                                 np.int32),
                             training: False,
                             logits_index: logits_idx,
                             logits_fname: logits_fn,
                         })
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn, args=())
        worker.setDaemon(True)
        worker.start()
        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            # NOTE: both branches are currently identical; other revisions of
            # this file pass reverse=True for the RNA mode.
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            if 'total_count' not in val[name].keys():
                val[name]['total_count'] = 0
            if 'index_list' not in val[name].keys():
                val[name]['index_list'] = []
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                if val[name]['total_count'] == reads_n:
                    pbars.update(1, progress=val[name]['total_count'])
                    break
                decode_ops = [
                    decoded_fname_op, decode_idx_op, decode_predict_op,
                    decode_prob_op
                ]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = np.asarray(
                    [x.decode("UTF-8") for x in decoded_fname])
                ## This is hard to integrate into the TensorFlow graph, as the
                ## number of file names in a batch is variable, and a for loop
                ## cannot be used because eager execution is disabled due to
                ## the use of queues.
                uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                    return_index=True)
                for fn_idx, fn in enumerate(uniq_fname):
                    i = uniq_fn_idx[fn_idx]
                    if fn != '':
                        occurance = np.where(decoded_fname == fn)[0]
                        start = occurance[0]
                        end = occurance[-1] + 1
                        assert len(occurance) == end - start
                        if 'total_count' not in val[fn].keys():
                            val[fn]['total_count'] = 0
                        if 'index_list' not in val[fn].keys():
                            val[fn]['index_list'] = []
                        val[fn]['total_count'] += (end - start)
                        val[fn]['index_list'].append(i)
                        sliced_sparse = slice_ctc_decoding_result(
                            predict_val, start, end)
                        val[fn][i] = (sliced_sparse,
                                      logits_prob[decoded_fname == fn])
                pbars.update(1, progress=val[name]['total_count'])
                pbars.update_bar()
            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in np.sort(val[name]['index_list']):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                             qs_list,
                                                             js_ratio,
                                                             kernal=kernal)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
    pbars.end()
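
# A minimal sketch of the producer/consumer pattern used above, assuming
# TensorFlow 1.x: a daemon thread enqueues into a tf.FIFOQueue while the main
# thread dequeues and processes. close() lets the consumer's dequeue raise
# tf.errors.OutOfRangeError once the queue drains, which is how the pipeline
# above knows the producer is done.
def _fifo_queue_sketch():
    import threading
    q = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[()])
    item = tf.placeholder(tf.int32, shape=())
    enqueue = q.enqueue(item)
    dequeue = q.dequeue()
    close = q.close()
    with tf.Session() as sess:
        def producer():
            for v in range(5):
                sess.run(enqueue, feed_dict={item: v})
            sess.run(close)          # no more items will arrive

        t = threading.Thread(target=producer)
        t.daemon = True
        t.start()
        results = []
        try:
            while True:
                results.append(sess.run(dequeue))
        except tf.errors.OutOfRangeError:
            pass                     # queue closed and empty
        return results               # [0, 1, 2, 3, 4]
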
def chiron_train():
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name, 'model.json')
    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model JSON file has not been found in model log directory")
    else:
        config_file = FLAGS.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    FLAGS.sequence_len,
                                    configure=config)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(FLAGS.step_rate,
                          FLAGS.max_steps,
                          global_step=global_step,
                          opt_name=config['opt_method'])
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    if FLAGS.retrain == False:
        sess.run(init)
        print("Model init finished, begin loading data. \n")
    else:
        saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model load finished, begin loading data. \n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)
    train_ds, valid_ds = generate_train_valid_datasets()
    start = time.time()
    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            seq_length: seq_len / ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f validate_edit_distance: %5.3f Elapsed Time/step: %5.3f"
                % (i, FLAGS.max_steps, train_ds.epochs_completed,
                   train_ds.index_in_epoch, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    saver.save(sess,
               FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
               global_step=global_step_val)
def train():
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name, 'model.json')
    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model JSON file has not been found in model log directory")
    else:
        config_file = FLAGS.configure
    config = model.read_config(config_file)
    print("Begin training using the following settings:")
    with open(os.path.join(FLAGS.log_dir, FLAGS.model_name, 'train_config'),
              'w+') as log_f:
        for pro in dir(FLAGS):
            if not pro.startswith('_'):
                print("%s:%s" % (pro, getattr(FLAGS, pro)))
                log_f.write("%s:%s\n" % (pro, getattr(FLAGS, pro)))
    net = compile_train_graph(config, FLAGS)
    sess = tf.Session(
        config=tf.ConfigProto(inter_op_parallelism_threads=FLAGS.threads,
                              intra_op_parallelism_threads=FLAGS.threads,
                              allow_soft_placement=True))
    if FLAGS.retrain == False:
        sess.run(net.init)
        print("Model init finished, begin loading data. \n")
    else:
        net.saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model load finished, begin loading data. \n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)
    train_ds, valid_ds = generate_train_valid_datasets(
        initial_offset=DEFAULT_OFFSET)
    start = time.time()
    resample_n = 0
    for i in range(FLAGS.max_steps):
        if FLAGS.resample_after_epoch == 0:
            pass
        elif train_ds.epochs_completed >= FLAGS.resample_after_epoch:
            train_ds, valid_ds = generate_train_valid_datasets(
                initial_offset=resample_n * FLAGS.offset_increment +
                DEFAULT_OFFSET)
            resample_n += 1  # advance the offset for the next resample
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            net.x: batch_x,
            net.seq_length: seq_len / net.ratio,
            net.y_indexs: indxs,
            net.y_values: values,
            net.y_shape: shape,
            net.training: True
        }
        loss_val, _ = sess.run([net.ctc_loss, net.step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, net.global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                net.x: valid_x,
                net.seq_length: valid_len / net.ratio,
                net.y_indexs: indxs,
                net.y_values: values,
                net.y_shape: shape,
                net.training: True
            }
            error_val = sess.run(net.error, feed_dict=feed_dict)
            # x_val, errors_val, y_predict, y = sess.run([x, errors, y_, y], feed_dict=feed_dict)
            # predict_seq, _ = sparse2dense([y_predict, 0])
            # true_seq, _ = sparse2dense([[y], 0])
            end = time.time()
            print(
                "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f validate_edit_distance: %5.3f Elapsed Time/step: %5.3f"
                % (i, FLAGS.max_steps, train_ds.epochs_completed,
                   train_ds.index_in_epoch, loss_val, error_val,
                   (end - start) / (i + 1)))
            net.saver.save(sess,
                           FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                           global_step=global_step_val)
            summary_str = sess.run(net.summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, net.global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    net.saver.save(sess,
                   FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
                   global_step=global_step_val)
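
# A hedged sketch of what model.loss presumably wraps: TF1's tf.nn.ctc_loss
# expects time-major logits [max_time, batch, num_classes] and sparse labels,
# which is why the decoders above also transpose logits to [1, 0, 2]. This is
# illustrative only; the real model.loss lives in chiron_model.
def _ctc_loss_sketch(logits, seq_length, sparse_labels):
    # logits: [batch, time, classes] as produced by inference() above.
    # seq_length: the per-read length of the *logits* time axis.
    loss = tf.nn.ctc_loss(labels=sparse_labels,
                          inputs=tf.transpose(logits, perm=[1, 0, 2]),
                          sequence_length=seq_length,
                          time_major=True)
    return tf.reduce_mean(loss)
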
def evaluation():
    logger = logging.getLogger(__name__)
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    tqdm.monitor_interval = 0
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape, logits_index.shape,
                seq_length.shape])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(
        logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            for name in tqdm(file_list, desc="Logits inferencing.", position=0):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                for i in trange(0, reads_n, FLAGS.batch_size,
                                desc="Logits inferencing", position=1):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x,
                        ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                        mode='constant')
                    seq_len = np.pad(seq_len,
                                     ((0, FLAGS.batch_size - len(seq_len))),
                                     mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: seq_len,
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
            sess.run(logits_queue_close)

        def run_listener(write_lock):
            # Work around an error when tqdm is used inside a thread:
            # https://github.com/tqdm/tqdm/issues/323
            tqdm.set_lock(write_lock)
            worker_fn()

        write_lock = threading.Lock()
        worker = threading.Thread(target=run_listener, args=(write_lock,))
        worker.setDaemon(True)
        worker.start()
        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for name in tqdm(file_list, desc="CTC decoding.", position=2):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump,
                                               reverse=True)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()
            N = len(range(0, reads_n, FLAGS.batch_size))
            with tqdm(total=reads_n, desc="ctc decoding", position=3) as pbar:
                while True:
                    options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    l_sz, d_sz = sess.run(
                        [logits_queue_size, decode_queue_size],
                        options=options,
                        run_metadata=run_metadata)
                    pbar.set_postfix(logits_q=l_sz,
                                     decoded_q=d_sz,
                                     refresh=False)
                    decode_ops = [
                        decoded_fname_op, decode_idx_op, decode_predict_op,
                        decode_prob_op
                    ]
                    decoded_fname, i, predict_val, logits_prob = sess.run(
                        decode_ops,
                        feed_dict={training: False},
                        options=options,
                        run_metadata=run_metadata)
                    decoded_fname = decoded_fname.decode("UTF-8")
                    val[decoded_fname][i] = (predict_val, logits_prob)
                    fetched_timeline = timeline.Timeline(
                        run_metadata.step_stats)
                    chrome_trace = fetched_timeline.generate_chrome_trace_format()
                    with open('timeline_02_step_%d.json' % i, 'w') as f:
                        f.write(chrome_trace)
                    if decoded_fname == name:
                        decoded_cnt = len(val[name])
                        pbar.update(
                            min(reads_n, decoded_cnt * FLAGS.batch_size) -
                            (decoded_cnt - 1) * FLAGS.batch_size)
                        if decoded_cnt == N:
                            break
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in trange(0, reads_n, FLAGS.batch_size, desc="Output",
                            position=4):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            # tqdm.write("[%s] Segment reads base calling finished, begin to assembly. %5.2f seconds" % (name, time.time() - start_time))
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            np.set_printoptions(threshold=np.nan)
            assembly_time = time.time() - start_time
            # tqdm.write("[%s] Assembly finished, begin output. %5.2f seconds" % (name, time.time() - start_time))
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string)
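
# An illustrative numpy sketch (not the real simple_assembly) of how a
# consensus can be taken over overlapping segment reads: each segment's bases
# vote into a 4 x L count matrix at the segment's global offset, and the
# per-column argmax gives the consensus call that index2base then translates.
def _toy_consensus(segment_reads, jump):
    import numpy as np
    base_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    total_len = jump * (len(segment_reads) - 1) + max(map(len, segment_reads))
    counts = np.zeros((4, total_len))
    for k, read in enumerate(segment_reads):
        for j, b in enumerate(read):
            counts[base_idx[b], k * jump + j] += 1   # vote at global offset
    return ''.join('ACGT'[i] for i in np.argmax(counts, axis=0))

# Example: _toy_consensus(['ACGT', 'CGTA', 'GTAC'], jump=1) -> 'ACGTAC'
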
def evaluation():
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    net = build_eval_graph(model_configure)
    val = defaultdict(
        dict)  # We could read vals out of order, that's why it's a dict
    for f_i, name in enumerate(net.file_list):
        start_time = time.time()
        if (not name.endswith('.signal')) and (not name.endswith('.fast5')):
            continue
        file_pre = os.path.splitext(name)[0]
        input_path = os.path.join(net.file_dir, name)
        ### Other modes (like methylation) may use a different read method.
        eval_data = read_data_for_eval(input_path,
                                       FLAGS.start,
                                       seg_length=FLAGS.segment_len,
                                       step=FLAGS.jump,
                                       reverse_fast5=FLAGS.reverse_fast5)
        reads_n = eval_data.reads_n
        net.pbars.update(1, total=reads_n, progress=0)
        net.pbars.update_bar()
        reading_time = time.time() - start_time
        reads = list()
        if 'total_count' not in val[name].keys():
            val[name]['total_count'] = 0
        if 'index_list' not in val[name].keys():
            val[name]['index_list'] = []
        while True:
            l_sz, d_sz = net.sess.run(
                [net.logits_queue_size, net.decode_queue_size])
            if val[name]['total_count'] == reads_n:
                net.pbars.update(1, progress=val[name]['total_count'])
                break
            decode_ops = [
                net.decoded_fname_op, net.decode_idx_op,
                net.decode_predict_op, net.decode_prob_op
            ]
            decoded_fname, i, predict_val, logits_prob = net.sess.run(
                decode_ops, feed_dict={net.training: False})
            decoded_fname = np.asarray(
                [x.decode("UTF-8") for x in decoded_fname])
            ## This is hard to integrate into the TensorFlow graph, as the
            ## number of file names in a batch is uncertain, and a for loop
            ## cannot be used because eager execution is disabled due to the
            ## use of queues.
            uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                return_index=True)
            for fn_idx, fn in enumerate(uniq_fname):
                i = uniq_fn_idx[fn_idx]
                if fn != '':
                    occurance = np.where(decoded_fname == fn)[0]
                    start = occurance[0]
                    end = occurance[-1] + 1
                    assert len(occurance) == end - start
                    if 'total_count' not in val[fn].keys():
                        val[fn]['total_count'] = 0
                    if 'index_list' not in val[fn].keys():
                        val[fn]['index_list'] = []
                    val[fn]['total_count'] += (end - start)
                    val[fn]['index_list'].append(i)
                    sliced_sparse = slice_ctc_decoding_result(
                        predict_val, start, end)
                    val[fn][i] = (sliced_sparse,
                                  logits_prob[decoded_fname == fn])
            net.pbars.update(1, progress=val[name]['total_count'])
            net.pbars.update_bar()
        net.pbars.update(3, progress=f_i + 1)
        net.pbars.update_bar()
        qs_list = np.empty((0, 1), dtype=np.float)
        qs_string = None
        for i in np.sort(val[name]['index_list']):
            predict_val, logits_prob = val[name][i]
            predict_read, unique = sparse2dense(predict_val)
            predict_read = predict_read[0]
            unique = unique[0]
            if FLAGS.extension == 'fastq':
                logits_prob = logits_prob[unique]
                qs_list = np.concatenate((qs_list, logits_prob))
            reads += predict_read
        val.pop(name)  # Release the memory
        basecall_time = time.time() - start_time
        bpreads = [index2base(read) for read in reads]
        js_ratio = FLAGS.jump / FLAGS.segment_len
        kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
        if FLAGS.extension == 'fastq':
            consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                         qs_list,
                                                         js_ratio,
                                                         kernal=kernal)
            qs_string = qs(consensus, qs_consensus)
        else:
            consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
        c_bpread = index2base(np.argmax(consensus, axis=0))
        assembly_time = time.time() - start_time
        list_of_time = [start_time, reading_time, basecall_time,
                        assembly_time]
        write_output(bpreads,
                     c_bpread,
                     list_of_time,
                     file_pre,
                     concise=FLAGS.concise,
                     suffix=FLAGS.extension,
                     q_score=qs_string,
                     global_setting=FLAGS)
    net.pbars.end()
def train(hparams):
    """Main training function. This will train a neural network with the given
    dataset.

    Args:
        hparams: hyperparameters for training the neural network:
            data_dir: String, path of the data (binary batch files) directory.
            log_dir: String, path to save the trained model.
            sequence_len: Int, length of the input signal.
            batch_size: Int.
            step_rate: Float, step rate of the optimizer.
            max_steps: Int, max training steps.
            kmer: Int, size of the DNA k-mer.
            model_name: String, model will be saved at log_dir/model_name.
            retrain: Boolean, if True, the model will be reloaded from
                log_dir/model_name.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        training = tf.placeholder(tf.bool)
        global_step = tf.get_variable('global_step',
                                      trainable=False,
                                      shape=(),
                                      dtype=tf.int32,
                                      initializer=tf.zeros_initializer())
        opt = model.train_opt(hparams.step_rate,
                              hparams.max_steps,
                              global_step=global_step)
        x, seq_length, train_labels = inputs(
            hparams.data_dir,
            int(hparams.batch_size * hparams.ngpus),
            for_valid=False)
        split_y = tf.split(train_labels, hparams.ngpus, axis=0)
        split_seq_length = tf.split(seq_length, hparams.ngpus, axis=0)
        split_x = tf.split(x, hparams.ngpus, axis=0)
        tower_grads = []
        default_config = os.path.join(hparams.log_dir, hparams.model_name,
                                      'model.json')
        if hparams.retrain:
            if os.path.isfile(default_config):
                config_file = default_config
            else:
                raise ValueError(
                    "Model JSON file has not been found in model log directory")
        else:
            config_file = hparams.configure
        config = model.read_config(config_file)
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(hparams.ngpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('gpu_tower', i)) as scope:
                        loss, error = tower_loss(
                            scope,
                            split_x[i],
                            split_seq_length[i],
                            split_y[i],
                            full_seq_len=hparams.sequence_len,
                            config=config)
                        tf.get_variable_scope().reuse_variables()
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)
                        grads = opt.compute_gradients(loss)
                        tower_grads.append(grads)
        grads = average_gradients(tower_grads)
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))
        apply_gradient_op = opt.apply_gradients(grads,
                                                global_step=global_step)
        var_averages = tf.train.ExponentialMovingAverage(
            decay=model.MOVING_AVERAGE_DECAY)
        var_averages_op = var_averages.apply(tf.trainable_variables())
        train_op = tf.group(apply_gradient_op, var_averages_op)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        summary = tf.summary.merge_all()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        model.save_model(default_config, config)
        if not hparams.retrain:
            sess.run(init)
            print("Model init finished, begin training. \n")
        else:
            saver.restore(
                sess,
                tf.train.latest_checkpoint(hparams.log_dir +
                                           hparams.model_name))
            print("Model load finished, begin training. \n")
        summary_writer = tf.summary.FileWriter(
            hparams.log_dir + hparams.model_name + '/summary/', sess.graph)
        _ = tf.train.start_queue_runners(sess=sess)
        start = time.time()
        for i in range(hparams.max_steps):
            feed_dict = {training: True}
            loss_val, _ = sess.run([loss, train_op], feed_dict=feed_dict)
            if i % 10 == 0:
                global_step_val = tf.train.global_step(sess, global_step)
                feed_dict = {training: True}
                error_val = sess.run(error, feed_dict=feed_dict)
                end = time.time()
                print(
                    "Step %d/%d , loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f"
                    % (i, hparams.max_steps, loss_val, error_val,
                       (end - start) / (i + 1)))
                saver.save(sess,
                           hparams.log_dir + hparams.model_name +
                           '/model.ckpt',
                           global_step=global_step_val)
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str,
                                           global_step=global_step_val)
                summary_writer.flush()
        global_step_val = tf.train.global_step(sess, global_step)
        print("Model %s saved." % (hparams.log_dir + hparams.model_name))
        saver.save(sess,
                   hparams.log_dir + hparams.model_name + '/final.ckpt',
                   global_step=global_step_val)
def evaluation(): pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"]) x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len]) seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size]) training = tf.placeholder(tf.bool) config_path = os.path.join(FLAGS.model,'model.json') model_configure = chiron_model.read_config(config_path) logits, ratio = chiron_model.inference( x, seq_length, training=training, full_sequence_len = FLAGS.segment_len, configure = model_configure) config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads, inter_op_parallelism_threads=FLAGS.threads) config.gpu_options.allow_growth = True logits_index = tf.placeholder(tf.int32, shape=()) logits_fname = tf.placeholder(tf.string, shape=()) logits_queue = tf.FIFOQueue( capacity=1000, dtypes=[tf.float32, tf.string, tf.int32, tf.int32], shapes=[logits.shape, logits_fname.shape, logits_index.shape, seq_length.shape] ) logits_queue_size = logits_queue.size() logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length)) logits_queue_close = logits_queue.close() ### Decoding logits into bases decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue) saver = tf.train.Saver() with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess: saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model)) if os.path.isdir(FLAGS.input): file_list = os.listdir(FLAGS.input) file_dir = FLAGS.input else: file_list = [os.path.basename(FLAGS.input)] file_dir = os.path.abspath( os.path.join(FLAGS.input, os.path.pardir)) file_n = len(file_list) pbars.update(2,total = file_n) pbars.update(3,total = file_n) if not os.path.exists(FLAGS.output): os.makedirs(FLAGS.output) if not os.path.exists(os.path.join(FLAGS.output, 'segments')): os.makedirs(os.path.join(FLAGS.output, 'segments')) if not os.path.exists(os.path.join(FLAGS.output, 'result')): os.makedirs(os.path.join(FLAGS.output, 'result')) if not os.path.exists(os.path.join(FLAGS.output, 'meta')): os.makedirs(os.path.join(FLAGS.output, 'meta')) def worker_fn(): for f_i, name in enumerate(file_list): if not name.endswith('.signal'): continue input_path = os.path.join(file_dir, name) eval_data = read_data_for_eval(input_path, FLAGS.start, seg_length=FLAGS.segment_len, step=FLAGS.jump) reads_n = eval_data.reads_n pbars.update(0,total = reads_n,progress = 0) pbars.update_bar() for i in range(0, reads_n, FLAGS.batch_size): batch_x, seq_len, _ = eval_data.next_batch( FLAGS.batch_size, shuffle=False, sig_norm=False) batch_x = np.pad( batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant') seq_len = np.pad( seq_len, ((0, FLAGS.batch_size - len(seq_len))), mode='constant') feed_dict = { x: batch_x, seq_length: seq_len, training: False, logits_index:i, logits_fname: name, }
def train(hparam):
    """Main training function. This will train a neural network with the given
    dataset.

    Args:
        hparam: hyperparameters for training the neural network:
            data_dir: String, path of the data (binary batch files) directory.
            log_dir: String, path to save the trained model.
            sequence_len: Int, length of the input signal.
            batch_size: Int.
            step_rate: Float, step rate of the optimizer.
            max_steps: Int, max training steps.
            kmer: Int, size of the DNA k-mer.
            model_name: String, model will be saved at log_dir/model_name.
            retrain: Boolean, if True, the model will be reloaded from
                log_dir/model_name.
    """
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x, seq_length, train_labels = inputs(hparam.data_dir,
                                         hparam.batch_size,
                                         hparam.sequence_len,
                                         for_valid=False)
    y = dense2sparse(train_labels)
    default_config = os.path.join(hparam.log_dir, hparam.model_name,
                                  'model.json')
    if hparam.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model JSON file has not been found in model log directory")
    else:
        config_file = hparam.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    hparam.sequence_len,
                                    configure=config,
                                    apply_ratio=True)
    seq_length = tf.cast(tf.ceil(tf.cast(seq_length, tf.float32) / ratio),
                         tf.int32)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(hparam.step_rate,
                          hparam.max_steps,
                          global_step=global_step)
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    model.save_model(default_config, config)
    if not hparam.retrain:
        sess.run(init)
        print("Model init finished, begin training. \n")
    else:
        saver.restore(
            sess,
            tf.train.latest_checkpoint(hparam.log_dir + hparam.model_name))
        print("Model load finished, begin training. \n")
    summary_writer = tf.summary.FileWriter(
        hparam.log_dir + hparam.model_name + '/summary/', sess.graph)
    _ = tf.train.start_queue_runners(sess=sess)
    start = time.time()
    for i in range(hparam.max_steps):
        feed_dict = {training: True}
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            feed_dict = {training: True}
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d , loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f"
                % (i, hparam.max_steps, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       hparam.log_dir + hparam.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (hparam.log_dir + hparam.model_name))
    saver.save(sess,
               hparam.log_dir + hparam.model_name + '/final.ckpt',
               global_step=global_step_val)
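
# Why seq_length is divided (or ceil-divided) by ratio in several places: the
# network downsamples the signal along the time axis, so a length measured in
# raw signal samples must be rescaled to the logits' time axis before CTC.
# A small worked example with illustrative numbers:
def _ratio_example():
    import numpy as np
    ratio = 2                                  # e.g. one stride-2 layer halves time
    seq_len = np.array([300, 211])             # valid raw samples per read
    logits_len = np.ceil(seq_len / ratio).astype(np.int32)
    return logits_len                          # array([150, 106]) logit time steps
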
def evaluation(): pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"]) x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len]) seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size]) training = tf.placeholder(tf.bool) config_path = os.path.join(FLAGS.model,'model.json') model_configure = chiron_model.read_config(config_path) logits, ratio = chiron_model.inference( x, seq_length, training=training, full_sequence_len = FLAGS.segment_len, configure = model_configure) config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads, inter_op_parallelism_threads=FLAGS.threads) config.gpu_options.allow_growth = True logits_index = tf.placeholder(tf.int32, shape=()) logits_fname = tf.placeholder(tf.string, shape=()) logits_queue = tf.FIFOQueue( capacity=1000, dtypes=[tf.float32, tf.string, tf.int32, tf.int32], shapes=[logits.shape,logits_fname.shape,logits_index.shape, seq_length.shape] ) logits_queue_size = logits_queue.size() logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length)) logits_queue_close = logits_queue.close() ### Decoding logits into bases decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue) saver = tf.train.Saver() with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess: saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model)) if os.path.isdir(FLAGS.input): file_list = os.listdir(FLAGS.input) file_dir = FLAGS.input else: file_list = [os.path.basename(FLAGS.input)] file_dir = os.path.abspath( os.path.join(FLAGS.input, os.path.pardir)) file_n = len(file_list) pbars.update(2,total = file_n) pbars.update(3,total = file_n) if not os.path.exists(FLAGS.output): os.makedirs(FLAGS.output) if not os.path.exists(os.path.join(FLAGS.output, 'segments')): os.makedirs(os.path.join(FLAGS.output, 'segments')) if not os.path.exists(os.path.join(FLAGS.output, 'result')): os.makedirs(os.path.join(FLAGS.output, 'result')) if not os.path.exists(os.path.join(FLAGS.output, 'meta')): os.makedirs(os.path.join(FLAGS.output, 'meta')) def worker_fn(): for f_i, name in enumerate(file_list): if not name.endswith('.signal'): continue input_path = os.path.join(file_dir, name) eval_data = read_data_for_eval(input_path, FLAGS.start, seg_length=FLAGS.segment_len, step=FLAGS.jump) reads_n = eval_data.reads_n pbars.update(0,total = reads_n,progress = 0) pbars.update_bar() for i in range(0, reads_n, FLAGS.batch_size): batch_x, seq_len, _ = eval_data.next_batch( FLAGS.batch_size, shuffle=False, sig_norm=False) batch_x = np.pad( batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant') seq_len = np.pad( seq_len, ((0, FLAGS.batch_size - len(seq_len))), mode='constant') feed_dict = { x: batch_x, seq_length: np.round(seq_len/ratio).astype(np.int32), training: False, logits_index:i, logits_fname: name, } sess.run(logits_enqueue,feed_dict=feed_dict) pbars.update(0,progress=i+FLAGS.batch_size) pbars.update_bar() pbars.update(2,progress = f_i+1) pbars.update_bar() sess.run(logits_queue_close) worker = threading.Thread(target=worker_fn,args=() ) worker.setDaemon(True) worker.start() val = defaultdict(dict) # We could read vals out of order, that's why it's a dict for f_i, name in enumerate(file_list): start_time = time.time() if not name.endswith('.signal'): continue file_pre = os.path.splitext(name)[0] input_path = os.path.join(file_dir, name) if FLAGS.mode == 'rna': eval_data = 
read_data_for_eval(input_path, FLAGS.start, seg_length=FLAGS.segment_len, step=FLAGS.jump) else: eval_data = read_data_for_eval(input_path, FLAGS.start, seg_length=FLAGS.segment_len, step=FLAGS.jump) reads_n = eval_data.reads_n pbars.update(1,total = reads_n,progress = 0) pbars.update_bar() reading_time = time.time() - start_time reads = list() N = len(range(0, reads_n, FLAGS.batch_size)) while True: l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size]) decode_ops = [decoded_fname_op, decode_idx_op, decode_predict_op, decode_prob_op] decoded_fname, i, predict_val, logits_prob = sess.run(decode_ops, feed_dict={training: False}) decoded_fname = decoded_fname.decode("UTF-8") val[decoded_fname][i] = (predict_val, logits_prob) pbars.update(1,progress = len(val[name])*FLAGS.batch_size) pbars.update_bar() if len(val[name]) == N: break pbars.update(3,progress = f_i+1) pbars.update_bar() qs_list = np.empty((0, 1), dtype=np.float) qs_string = None for i in range(0, reads_n, FLAGS.batch_size): predict_val, logits_prob = val[name][i] predict_read, unique = sparse2dense(predict_val) predict_read = predict_read[0] unique = unique[0] if FLAGS.extension == 'fastq': logits_prob = logits_prob[unique] if i + FLAGS.batch_size > reads_n: predict_read = predict_read[:reads_n - i] if FLAGS.extension == 'fastq': logits_prob = logits_prob[:reads_n - i] if FLAGS.extension == 'fastq': qs_list = np.concatenate((qs_list, logits_prob)) reads += predict_read val.pop(name) # Release the memory basecall_time = time.time() - start_time bpreads = [index2base(read) for read in reads] js_ratio = FLAGS.jump/FLAGS.segment_len if FLAGS.extension == 'fastq': consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list,js_ratio) qs_string = qs(consensus, qs_consensus) else: consensus = simple_assembly(bpreads,js_ratio) c_bpread = index2base(np.argmax(consensus, axis=0)) assembly_time = time.time() - start_time list_of_time = [start_time, reading_time, basecall_time, assembly_time] write_output(bpreads, c_bpread, list_of_time, file_pre, concise=FLAGS.concise, suffix=FLAGS.extension, q_score=qs_string,global_setting=FLAGS) pbars.end()
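
# index2base is used above but defined elsewhere in Chiron; a hedged sketch of
# the mapping it presumably performs, turning the CTC label indices emitted by
# the decoder into a base string:
def _index2base_sketch(read):
    base = ['A', 'C', 'G', 'T']   # label 4 would be the CTC blank, never emitted
    return ''.join(base[int(i)] for i in read)

# Example: _index2base_sketch([0, 2, 3, 1]) -> 'AGTC'
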
    action='store_true',
    help="Concisely output the result; the meta and segments files will not be output.")
parser.add_argument('--mode',
                    default='dna',
                    help="Output mode, can be chosen from dna or rna.")
FLAGS = parser.parse_args(sys.argv[1:])
pbars = multi_pbars(
    ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
training = tf.placeholder(tf.bool)
config_path = os.path.join(FLAGS.model, 'model.json')
model_configure = chiron_model.read_config(config_path)
logits, ratio = chiron_model.inference(x,
                                       seq_length,
                                       training=training,
                                       full_sequence_len=FLAGS.segment_len,
                                       configure=model_configure)
prorbs = tf.nn.softmax(logits)
predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                        seq_length,
                                        merge_repeated=False,
                                        beam_width=FLAGS.beam)
config = tf.ConfigProto(allow_soft_placement=True,
                        intra_op_parallelism_threads=FLAGS.threads,
                        inter_op_parallelism_threads=FLAGS.threads)
config.gpu_options.allow_growth = True