def process_output(self, graph_output, input_path):
    """Assemble the predicted reads of one input file and save them as FASTA.

    Every read of every batch in ``graph_output`` is converted with
    ``SignalLabel.index2base``, the converted reads are merged with
    ``simple_assembly`` and the arg-max (axis 0) of the assembly result is
    converted to the consensus sequence.  The record is written to
    ``<self.inference_output_dir>/<input basename without extension>.fasta``.

    Args:
        graph_output: Iterable of batches, each batch an iterable of reads
            accepted by ``SignalLabel.index2base``.
        input_path: Path of the input file; only its base name (without the
            extension) is used, as record name and as output file stem.
    """
    stem, _ = os.path.splitext(os.path.basename(input_path))
    target = os.path.join(self.inference_output_dir, stem + ".fasta")

    # Flatten the batches into one list of base-space reads, in batch order.
    segment_reads = [
        SignalLabel.index2base(read)
        for batch in graph_output
        for read in batch
    ]

    votes = simple_assembly(segment_reads)
    consensus_read = SignalLabel.index2base(np.argmax(votes, axis=0))

    record = ">{}\n{}\n".format(stem, consensus_read)
    with open(target, 'w+') as fasta_f:
        fasta_f.write(record)
def evaluation():
    """Basecall the ``.signal`` file(s) given by ``FLAGS.input``.

    The function builds the inference graph, restores the latest checkpoint
    found in ``FLAGS.model`` and then runs two cooperating loops:

    * a daemon worker thread (``worker_fn``) reads the segments of every
      ``.signal`` file, packs them into batches of ``FLAGS.batch_size``
      segments (a batch may mix segments of consecutive files) and enqueues
      the logits together with the file name and index of every segment;
    * the calling thread fetches decoded batches from ``decoding_queue``,
      groups them per file name, and once all segments of the current file
      have arrived assembles the consensus and calls ``write_output``.

    Output directories ``segments``, ``result`` and ``meta`` are created
    under ``FLAGS.output`` when missing.  All configuration is read from the
    module-level ``FLAGS`` object; nothing is returned.
    """
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(
        x,
        seq_length,
        training=training,
        full_sequence_len=FLAGS.segment_len,
        configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    # One index and one file name per segment of the batch.
    logits_index = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    logits_fname = tf.placeholder(tf.string, shape=[FLAGS.batch_size])
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[
            logits.shape, logits_fname.shape, logits_index.shape,
            seq_length.shape
        ])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()

    ### Decoding logits into bases
    (decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op,
     decode_queue_size) = decoding_queue(logits_queue)
    saver = tf.train.Saver(var_list=tf.trainable_variables() +
                           tf.moving_average_variables())
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)

        # Create the output directory and its sub-directories when missing.
        output_dirs = [FLAGS.output] + [
            os.path.join(FLAGS.output, sub_dir)
            for sub_dir in ('segments', 'result', 'meta')
        ]
        for output_dir in output_dirs:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

        def worker_fn():
            """Feed segment batches of all files into the logits queue."""
            batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
            seq_len = np.asarray([])
            logits_idx = np.asarray([])
            logits_fn = np.asarray([])
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                i = 0
                while (eval_data.epochs_completed == 0):
                    # Only request as many segments as are needed to fill
                    # the partially filled batch carried over so far.
                    current_batch, current_seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size - len(batch_x), shuffle=False)
                    current_n = len(current_batch)
                    batch_x = np.concatenate((batch_x, current_batch), axis=0)
                    seq_len = np.concatenate((seq_len, current_seq_len),
                                             axis=0)
                    logits_idx = np.concatenate((logits_idx, [i] * current_n),
                                                axis=0)
                    logits_fn = np.concatenate((logits_fn, [name] * current_n),
                                               axis=0)
                    i += current_n
                    if len(batch_x) < FLAGS.batch_size:
                        # Batch not full yet: keep it and continue filling
                        # (possibly with segments of the next file).
                        pbars.update(0, progress=i)
                        pbars.update_bar()
                        continue
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: logits_idx,
                        logits_fname: logits_fn,
                    }
                    #Training: Set it to True for a temporary fix of the batch normalization problem: https://github.com/haotianteng/Chiron/commit/8fce3a3b4dac8e9027396bb8c9152b7b5af953ce
                    #TODO: change the training FLAG back to False after the new model has been trained.
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
                    seq_len = np.asarray([])
                    logits_idx = np.asarray([])
                    logits_fn = np.asarray([])
                    pbars.update(0, progress=i)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()

            ### All files have been processed; flush the last partial batch.
            batch_n = len(batch_x)
            if batch_n > 0:
                pad_width = FLAGS.batch_size - batch_n
                batch_x = np.pad(batch_x, ((0, pad_width), (0, 0)),
                                 mode='wrap')
                seq_len = np.pad(seq_len, ((0, pad_width)), mode='wrap')
                # Padding entries get index -1 and an empty file name; the
                # consumer loop below skips empty file names.
                logits_idx = np.pad(logits_idx, (0, pad_width),
                                    mode='constant',
                                    constant_values=-1)
                logits_fn = np.pad(logits_fn, (0, pad_width),
                                   mode='constant',
                                   constant_values='')
                sess.run(logits_enqueue,
                         feed_dict={
                             x: batch_x,
                             seq_length:
                             np.round(seq_len / ratio).astype(np.int32),
                             training: False,
                             logits_index: logits_idx,
                             logits_fname: logits_fn,
                         })
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn, args=())
        # ``Thread.setDaemon`` is deprecated; set the attribute instead.
        worker.daemon = True
        worker.start()

        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            # The original code had separate 'rna' and non-'rna' branches
            # here that issued exactly the same call, so they are merged.
            eval_data = read_data_for_eval(input_path,
                                           FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            if 'total_count' not in val[name].keys():
                val[name]['total_count'] = 0
            if 'index_list' not in val[name].keys():
                val[name]['index_list'] = []
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                # Check before dequeuing: all segments of this file may have
                # arrived already while an earlier file was being collected.
                if val[name]['total_count'] == reads_n:
                    pbars.update(1, progress=val[name]['total_count'])
                    break
                decode_ops = [
                    decoded_fname_op, decode_idx_op, decode_predict_op,
                    decode_prob_op
                ]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = np.asarray(
                    [x.decode("UTF-8") for x in decoded_fname])
                ##Have difficulties integrate it into the tensorflow graph, as the number of file names in a batch is variable.
                ##And for loop can't be implemented as the eager execution is disabled due to the use of queue.
                uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                    return_index=True)
                for fn_idx, fn in enumerate(uniq_fname):
                    # NOTE(review): this replaces the dequeued index by the
                    # position of the file's first segment inside the batch;
                    # confirm that such keys cannot collide between batches.
                    i = uniq_fn_idx[fn_idx]
                    if fn != '':
                        occurance = np.where(decoded_fname == fn)[0]
                        start = occurance[0]
                        end = occurance[-1] + 1
                        # Segments of one file must be contiguous in a batch.
                        assert (len(occurance) == end - start)
                        if 'total_count' not in val[fn].keys():
                            val[fn]['total_count'] = 0
                        if 'index_list' not in val[fn].keys():
                            val[fn]['index_list'] = []
                        val[fn]['total_count'] += (end - start)
                        val[fn]['index_list'].append(i)
                        sliced_sparse = slice_ctc_decoding_result(
                            predict_val, start, end)
                        val[fn][i] = (sliced_sparse,
                                      logits_prob[decoded_fname == fn])
                pbars.update(1, progress=val[name]['total_count'])
                pbars.update_bar()
            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()

            # ``np.float`` was only an alias of the builtin ``float`` and has
            # been removed from NumPy, so the builtin is used directly.
            qs_list = np.empty((0, 1), dtype=float)
            qs_string = None
            for i in np.sort(val[name]['index_list']):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                             qs_list,
                                                             js_ratio,
                                                             kernal=kernal)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
    pbars.end()
def evaluation():
    """Basecall the ``.signal`` file(s) given by ``FLAGS.input`` (tqdm version).

    Builds the inference graph, restores the latest checkpoint found in
    ``FLAGS.model`` and starts a daemon worker thread that enqueues one
    logits batch per ``FLAGS.batch_size`` segments, tagged with the file
    name and the index of the first segment of the batch.  The calling
    thread fetches decoded batches, stores them per file name, and once all
    batches of the current file are present assembles the consensus and
    calls ``write_output``.

    Every decoding step is run with ``FULL_TRACE`` and its timeline is
    written to ``timeline_02_step_<index>.json`` in the working directory.
    All configuration is read from the module-level ``FLAGS`` object;
    nothing is returned.
    """
    import sys  # local import: only needed for the print-options fix below

    logger = logging.getLogger(__name__)
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(
        x,
        seq_length,
        training=training,
        full_sequence_len=FLAGS.segment_len,
        configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    tqdm.monitor_interval = 0
    # One scalar index and one scalar file name per enqueued batch.
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[
            logits.shape, logits_fname.shape, logits_index.shape,
            seq_length.shape
        ])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()

    ### Decoding logits into bases
    (decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op,
     decode_queue_size) = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))

        # Create the output directory and its sub-directories when missing.
        output_dirs = [FLAGS.output] + [
            os.path.join(FLAGS.output, sub_dir)
            for sub_dir in ('segments', 'result', 'meta')
        ]
        for output_dir in output_dirs:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

        def worker_fn():
            """Enqueue the logits of every batch of every ``.signal`` file."""
            for name in tqdm(file_list, desc="Logits inferencing.",
                             position=0):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                for i in trange(0,
                                reads_n,
                                FLAGS.batch_size,
                                desc="Logits inferencing",
                                position=1):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    # Zero-pad the last, shorter batch to the fixed size.
                    batch_x = np.pad(batch_x,
                                     ((0, FLAGS.batch_size - len(batch_x)),
                                      (0, 0)),
                                     mode='constant')
                    seq_len = np.pad(seq_len,
                                     ((0, FLAGS.batch_size - len(seq_len))),
                                     mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: seq_len,
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
            sess.run(logits_queue_close)

        def run_listener(write_lock):
            # This function is used to solve the error when tqdm is used inside thread
            # https://github.com/tqdm/tqdm/issues/323
            tqdm.set_lock(write_lock)
            worker_fn()

        write_lock = threading.Lock()
        worker = threading.Thread(target=run_listener, args=(write_lock, ))
        # ``Thread.setDaemon`` is deprecated; set the attribute instead.
        worker.daemon = True
        worker.start()

        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for name in tqdm(file_list, desc="CTC decoding.", position=2):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump,
                                               reverse=True)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()
            # Number of batches the worker enqueues for this file.
            N = len(range(0, reads_n, FLAGS.batch_size))
            with tqdm(total=reads_n, desc="ctc decoding", position=3) as pbar:
                while True:
                    options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    l_sz, d_sz = sess.run(
                        [logits_queue_size, decode_queue_size],
                        options=options,
                        run_metadata=run_metadata)
                    pbar.set_postfix(logits_q=l_sz,
                                     decoded_q=d_sz,
                                     refresh=False)
                    decode_ops = [
                        decoded_fname_op, decode_idx_op, decode_predict_op,
                        decode_prob_op
                    ]
                    decoded_fname, i, predict_val, logits_prob = sess.run(
                        decode_ops,
                        feed_dict={training: False},
                        options=options,
                        run_metadata=run_metadata)
                    decoded_fname = decoded_fname.decode("UTF-8")
                    val[decoded_fname][i] = (predict_val, logits_prob)
                    # Dump the trace of this decoding step.
                    fetched_timeline = timeline.Timeline(
                        run_metadata.step_stats)
                    chrome_trace = (
                        fetched_timeline.generate_chrome_trace_format())
                    with open('timeline_02_step_%d.json' % i, 'w') as f:
                        f.write(chrome_trace)
                    if decoded_fname == name:
                        decoded_cnt = len(val[name])
                        # Advance by the number of segments of this batch
                        # (the last batch of a file may be shorter).
                        pbar.update(
                            min(reads_n, decoded_cnt * FLAGS.batch_size) -
                            (decoded_cnt - 1) * FLAGS.batch_size)
                        if decoded_cnt == N:
                            break

            # ``np.float`` was only an alias of the builtin ``float`` and has
            # been removed from NumPy, so the builtin is used directly.
            qs_list = np.empty((0, 1), dtype=float)
            qs_string = None
            for i in trange(0,
                            reads_n,
                            FLAGS.batch_size,
                            desc="Output",
                            position=4):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    # Drop the entries that stem from batch padding.
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            # ``threshold=np.nan`` is rejected with a ValueError by current
            # NumPy; ``sys.maxsize`` is the supported way to disable
            # truncation of printed arrays.
            np.set_printoptions(threshold=sys.maxsize)
            assembly_time = time.time() - start_time
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string)
def evaluation():
    """Basecall every ``.signal``/``.fast5`` file listed by the evaluation graph.

    The graph, session, progress bars and file list come from
    ``build_eval_graph`` (the returned object is used through its ``sess``,
    ``pbars``, ``file_list``, ``file_dir`` and ``*_op`` attributes).  For
    each file the function fetches decoded batches until the number of
    collected segments equals ``eval_data.reads_n``, assembles the consensus
    and calls ``write_output``.  Decoded batches may contain segments of
    several files; they are stored per file name in ``val``.

    All configuration is read from the module-level ``FLAGS`` object;
    nothing is returned.
    """
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    net = build_eval_graph(model_configure)
    val = defaultdict(
        dict)  # We could read vals out of order, that's why it's a dict
    for f_i, name in enumerate(net.file_list):
        start_time = time.time()
        if (not name.endswith('.signal')) and (not name.endswith('.fast5')):
            continue
        file_pre = os.path.splitext(name)[0]
        input_path = os.path.join(net.file_dir, name)
        ###Other mode (like methylation) may use different read method.
        eval_data = read_data_for_eval(input_path,
                                       FLAGS.start,
                                       seg_length=FLAGS.segment_len,
                                       step=FLAGS.jump,
                                       reverse_fast5=FLAGS.reverse_fast5)
        reads_n = eval_data.reads_n
        net.pbars.update(1, total=reads_n, progress=0)
        net.pbars.update_bar()
        reading_time = time.time() - start_time
        reads = list()
        if 'total_count' not in val[name].keys():
            val[name]['total_count'] = 0
        if 'index_list' not in val[name].keys():
            val[name]['index_list'] = []
        while True:
            l_sz, d_sz = net.sess.run(
                [net.logits_queue_size, net.decode_queue_size])
            # Check before dequeuing: all segments of this file may have
            # arrived already while an earlier file was being collected.
            if val[name]['total_count'] == reads_n:
                net.pbars.update(1, progress=val[name]['total_count'])
                break
            decode_ops = [
                net.decoded_fname_op, net.decode_idx_op,
                net.decode_predict_op, net.decode_prob_op
            ]
            decoded_fname, i, predict_val, logits_prob = net.sess.run(
                decode_ops, feed_dict={net.training: False})
            decoded_fname = np.asarray(
                [x.decode("UTF-8") for x in decoded_fname])
            ##Have difficulties integrate it into the tensorflow graph, as the number of file names in a batch is uncertain.
            ##And for loop can't be implemented as the eager execution is disabled due to the use of queue.
            uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                return_index=True)
            for fn_idx, fn in enumerate(uniq_fname):
                # NOTE(review): this replaces the dequeued index by the
                # position of the file's first segment inside the batch;
                # confirm that such keys cannot collide between batches.
                i = uniq_fn_idx[fn_idx]
                if fn != '':
                    occurance = np.where(decoded_fname == fn)[0]
                    start = occurance[0]
                    end = occurance[-1] + 1
                    # Segments of one file must be contiguous in a batch.
                    assert (len(occurance) == end - start)
                    if 'total_count' not in val[fn].keys():
                        val[fn]['total_count'] = 0
                    if 'index_list' not in val[fn].keys():
                        val[fn]['index_list'] = []
                    val[fn]['total_count'] += (end - start)
                    val[fn]['index_list'].append(i)
                    sliced_sparse = slice_ctc_decoding_result(
                        predict_val, start, end)
                    val[fn][i] = (sliced_sparse,
                                  logits_prob[decoded_fname == fn])
            net.pbars.update(1, progress=val[name]['total_count'])
            net.pbars.update_bar()
        net.pbars.update(3, progress=f_i + 1)
        net.pbars.update_bar()

        # ``np.float`` was only an alias of the builtin ``float`` and has
        # been removed from NumPy, so the builtin is used directly.
        qs_list = np.empty((0, 1), dtype=float)
        qs_string = None
        for i in np.sort(val[name]['index_list']):
            predict_val, logits_prob = val[name][i]
            predict_read, unique = sparse2dense(predict_val)
            predict_read = predict_read[0]
            unique = unique[0]
            if FLAGS.extension == 'fastq':
                logits_prob = logits_prob[unique]
                qs_list = np.concatenate((qs_list, logits_prob))
            reads += predict_read
        val.pop(name)  # Release the memory
        basecall_time = time.time() - start_time
        bpreads = [index2base(read) for read in reads]
        js_ratio = FLAGS.jump / FLAGS.segment_len
        kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
        if FLAGS.extension == 'fastq':
            consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                         qs_list,
                                                         js_ratio,
                                                         kernal=kernal)
            qs_string = qs(consensus, qs_consensus)
        else:
            consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
        c_bpread = index2base(np.argmax(consensus, axis=0))
        assembly_time = time.time() - start_time
        list_of_time = [start_time, reading_time, basecall_time, assembly_time]
        write_output(bpreads,
                     c_bpread,
                     list_of_time,
                     file_pre,
                     concise=FLAGS.concise,
                     suffix=FLAGS.extension,
                     q_score=qs_string,
                     global_setting=FLAGS)
    net.pbars.end()
if i + FLAGS.batch_size > reads_n: predict_read = predict_read[:reads_n - i] if FLAGS.extension == 'fastq': logits_prob = logits_prob[:reads_n - i] if FLAGS.extension == 'fastq': qs_list = np.concatenate((qs_list, logits_prob)) reads += predict_read val.pop(name) # Release the memory basecall_time = time.time() - start_time bpreads = [index2base(read) for read in reads] if FLAGS.extension == 'fastq': consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list) qs_string = qs(consensus, qs_consensus) else: consensus = simple_assembly(bpreads) c_bpread = index2base(np.argmax(consensus, axis=0)) assembly_time = time.time() - start_time list_of_time = [start_time, reading_time, basecall_time, assembly_time] write_output(bpreads, c_bpread, list_of_time, file_pre, concise=FLAGS.concise, suffix=FLAGS.extension, q_score=qs_string) pbars.end() def decoding_queue(logits_queue, num_threads=6): q_logits, q_name, q_index, seq_length = logits_queue.dequeue() if FLAGS.extension == 'fastq': prob = path_prob(q_logits) else: prob = tf.constant(0.0) # We just need to have the right type, because of the queues
def evaluation():
    """Basecall the ``.signal`` file(s) given by ``FLAGS.input``.

    Builds the inference graph, restores the latest checkpoint found in
    ``FLAGS.model`` and starts a daemon worker thread that enqueues one
    logits batch per ``FLAGS.batch_size`` segments, tagged with the file
    name and the index of the first segment of the batch.  The calling
    thread fetches decoded batches, stores them per file name, and once all
    batches of the current file are present assembles the consensus and
    calls ``write_output``.  Progress is reported through ``multi_pbars``.

    All configuration is read from the module-level ``FLAGS`` object;
    nothing is returned.
    """
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(
        x,
        seq_length,
        training=training,
        full_sequence_len=FLAGS.segment_len,
        configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    # One scalar index and one scalar file name per enqueued batch.
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[
            logits.shape, logits_fname.shape, logits_index.shape,
            seq_length.shape
        ])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()

    ### Decoding logits into bases
    (decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op,
     decode_queue_size) = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)

        # Create the output directory and its sub-directories when missing.
        output_dirs = [FLAGS.output] + [
            os.path.join(FLAGS.output, sub_dir)
            for sub_dir in ('segments', 'result', 'meta')
        ]
        for output_dir in output_dirs:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

        def worker_fn():
            """Enqueue the logits of every batch of every ``.signal`` file."""
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    # Zero-pad the last, shorter batch to the fixed size.
                    batch_x = np.pad(batch_x,
                                     ((0, FLAGS.batch_size - len(batch_x)),
                                      (0, 0)),
                                     mode='constant')
                    seq_len = np.pad(seq_len,
                                     ((0, FLAGS.batch_size - len(seq_len))),
                                     mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    pbars.update(0, progress=i + FLAGS.batch_size)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn, args=())
        # ``Thread.setDaemon`` is deprecated; set the attribute instead.
        worker.daemon = True
        worker.start()

        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            # The original code had separate 'rna' and non-'rna' branches
            # here that issued exactly the same call, so they are merged.
            eval_data = read_data_for_eval(input_path,
                                           FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            # Number of batches the worker enqueues for this file.
            N = len(range(0, reads_n, FLAGS.batch_size))
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                decode_ops = [
                    decoded_fname_op, decode_idx_op, decode_predict_op,
                    decode_prob_op
                ]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = decoded_fname.decode("UTF-8")
                val[decoded_fname][i] = (predict_val, logits_prob)
                pbars.update(1, progress=len(val[name]) * FLAGS.batch_size)
                pbars.update_bar()
                if len(val[name]) == N:
                    break
            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()

            # ``np.float`` was only an alias of the builtin ``float`` and has
            # been removed from NumPy, so the builtin is used directly.
            qs_list = np.empty((0, 1), dtype=float)
            qs_string = None
            for i in range(0, reads_n, FLAGS.batch_size):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    # Drop the entries that stem from batch padding.
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(
                    bpreads, qs_list, js_ratio)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
    pbars.end()
def evaluation():
    """Basecall the ``.signal`` file(s) given by ``FLAGS.input`` (single thread).

    Builds the inference graph, restores the latest checkpoint found in
    ``FLAGS.model`` and processes the files one after another.  For each
    file a single loop alternates between enqueuing the next logits batch
    and fetching a decoded batch: a decoded batch is fetched whenever the
    decode queue is non-empty or all segments have been enqueued, otherwise
    the next batch is enqueued.  When all batches are decoded the consensus
    is assembled and ``write_output`` is called.

    All configuration is read from the module-level ``FLAGS`` object;
    nothing is returned.
    """
    import sys  # local import: only needed for the print-options fix below

    logger = logging.getLogger(__name__)
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    logits, ratio = chiron_model.inference(
        x,
        seq_length,
        training=training,
        full_sequence_len=FLAGS.segment_len)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=FLAGS.batch_size * 100,
        dtypes=[tf.float32, tf.int32, tf.int32],
        shapes=[logits.shape, logits_index.shape, seq_length.shape])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_index, seq_length))

    ### Decoding logits into bases
    (decode_predict_op, decode_prob_op, decode_idx_op,
     decode_queue_size) = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))

        # Create the output directory and its sub-directories when missing.
        output_dirs = [FLAGS.output] + [
            os.path.join(FLAGS.output, sub_dir)
            for sub_dir in ('segments', 'result', 'meta')
        ]
        for output_dir in output_dirs:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

        for name in tqdm(file_list, desc="basecalling fast5s"):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump,
                                               reverse=True)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()
            # Number of batches needed to cover all segments of this file.
            N = len(range(0, reads_n, FLAGS.batch_size))
            i_logits = 0
            decoded_cnt = 0
            val = {}  # We could read vals out of order, that's why it's a dict
            with tqdm(total=reads_n, desc="signal processing") as pbar:
                while decoded_cnt < N:
                    l_sz, d_sz = sess.run(
                        [logits_queue_size, decode_queue_size])
                    # Flow control
                    # Either we have something beam decoded, or we've pushed all data into the queue
                    pbar.set_postfix(logits_q=l_sz, decoded_q=d_sz)
                    if d_sz > 0 or i_logits >= reads_n:
                        i, predict_val, logits_prob = sess.run(
                            [decode_idx_op, decode_predict_op, decode_prob_op],
                            feed_dict={training: False})
                        val[i] = (predict_val, logits_prob)
                        decoded_cnt += 1
                        # Advance by the number of segments of this batch
                        # (the last batch of a file may be shorter).
                        pbar.update(
                            min(reads_n, decoded_cnt * FLAGS.batch_size) -
                            (decoded_cnt - 1) * FLAGS.batch_size)
                    else:
                        batch_x, seq_len, _ = eval_data.next_batch(
                            FLAGS.batch_size, shuffle=False, sig_norm=False)
                        # Zero-pad the last, shorter batch to the fixed size.
                        batch_x = np.pad(
                            batch_x,
                            ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                            mode='constant')
                        seq_len = np.pad(
                            seq_len,
                            ((0, FLAGS.batch_size - len(seq_len))),
                            mode='constant')
                        feed_dict = {
                            x: batch_x,
                            seq_length: seq_len / ratio,
                            training: False,
                            logits_index: i_logits
                        }
                        sess.run(logits_enqueue, feed_dict=feed_dict)
                        i_logits += FLAGS.batch_size

            # ``np.float`` was only an alias of the builtin ``float`` and has
            # been removed from NumPy, so the builtin is used directly.
            qs_list = np.empty((0, 1), dtype=float)
            qs_string = None
            for i in trange(0,
                            reads_n,
                            FLAGS.batch_size,
                            desc="further decoding"):
                predict_val, logits_prob = val[i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    # Drop the entries that stem from batch padding.
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            tqdm.write(
                "[%s] Segment reads base calling finished, begin to assembly. %5.2f seconds"
                % (name, time.time() - start_time))
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            # ``threshold=np.nan`` is rejected with a ValueError by current
            # NumPy; ``sys.maxsize`` is the supported way to disable
            # truncation of printed arrays.
            np.set_printoptions(threshold=sys.maxsize)
            assembly_time = time.time() - start_time
            tqdm.write("[%s] Assembly finished, begin output. %5.2f seconds" %
                       (name, time.time() - start_time))
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string)