def __init__(self, configure):
    self.pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    self.x = tf.placeholder(tf.float32,
                            shape=[FLAGS.batch_size, FLAGS.segment_len])
    self.seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    self.training = tf.placeholder(tf.bool)
    self.logits, self.ratio = chiron_model.inference(
        self.x,
        self.seq_length,
        training=self.training,
        full_sequence_len=FLAGS.segment_len,
        configure=configure)
    self.config = tf.ConfigProto(allow_soft_placement=True,
                                 intra_op_parallelism_threads=FLAGS.threads,
                                 inter_op_parallelism_threads=FLAGS.threads)
    self.config.gpu_options.allow_growth = True
    self.logits_index = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    self.logits_fname = tf.placeholder(tf.string, shape=[FLAGS.batch_size])
    self.logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[self.logits.shape, self.logits_fname.shape,
                self.logits_index.shape, self.seq_length.shape])
    self.logits_queue_size = self.logits_queue.size()
    self.logits_enqueue = self.logits_queue.enqueue(
        (self.logits, self.logits_fname, self.logits_index, self.seq_length))
    self.logits_queue_close = self.logits_queue.close()
    ### Decoding logits into bases
    self.decode_predict_op, self.decode_prob_op, self.decoded_fname_op, \
        self.decode_idx_op, self.decode_queue_size = decoding_queue(self.logits_queue)
    self.saver = tf.train.Saver(
        var_list=tf.trainable_variables() + tf.moving_average_variables())
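# A minimal, self-contained sketch of the producer/consumer queue pattern the
# constructor above wires up (illustrative only; `toy_q` and the other names
# are hypothetical): a feeder thread enqueues tensors while the main thread
# dequeues them, and close() lets the consumer drain the queue and stop cleanly.
import tensorflow as tf

toy_q = tf.FIFOQueue(capacity=4, dtypes=[tf.float32], shapes=[[2]])
toy_in = tf.placeholder(tf.float32, shape=[2])
toy_enqueue = toy_q.enqueue([toy_in])
toy_dequeue = toy_q.dequeue()
toy_close = toy_q.close()
with tf.Session() as sess:
    sess.run(toy_enqueue, feed_dict={toy_in: [1.0, 2.0]})  # producer side
    print(sess.run(toy_dequeue))                           # consumer side -> [1. 2.]
    sess.run(toy_close)  # once closed and empty, dequeue raises OutOfRangeError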
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                            seq_length,
                                            merge_repeated=False,
                                            beam_width=FLAGS.beam)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    sess = tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(config=config))
    if os.path.isdir(FLAGS.input):
        file_list = os.listdir(FLAGS.input)
        file_dir = FLAGS.input
    else:
        file_list = [os.path.basename(FLAGS.input)]
        file_dir = os.path.abspath(os.path.join(FLAGS.input, os.path.pardir))
    for file in file_list:
        file_path = os.path.join(file_dir, file)
        eval_data = read_data_for_eval(file_path,
                                       FLAGS.start,
                                       seg_length=FLAGS.segment_len,
                                       step=FLAGS.jump)
        reads_n = eval_data.reads_n
        for i in range(0, reads_n, FLAGS.batch_size):
            batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size,
                                                       shuffle=False,
                                                       sig_norm=False)
            batch_x = np.pad(batch_x,
                             ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                             mode='constant')
            seq_len = np.pad(seq_len, ((0, FLAGS.batch_size - len(seq_len))),
                             mode='constant')
            feed_dict = {
                x: batch_x,
                seq_length: np.round(seq_len / ratio).astype(np.int32),
                training: False,
            }
            logits_val, predict_val = sess.run([logits, predict],
                                               feed_dict=feed_dict)
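# The np.pad calls above grow a short final batch up to FLAGS.batch_size so the
# fixed-shape placeholders can still be fed. A minimal numpy illustration with
# toy sizes (batch size 4, segment length 5):
import numpy as np

batch = np.ones((3, 5))                       # only 3 of the 4 slots filled
padded = np.pad(batch, ((0, 4 - len(batch)), (0, 0)), mode='constant')
assert padded.shape == (4, 5)                 # one all-zero row appended
lens = np.pad(np.array([5, 5, 5]), ((0, 1)), mode='constant')
assert lens.tolist() == [5, 5, 5, 0]          # padded entries get length 0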
def extract(root_folder, output_folder, raw_folder=None):
    global logger
    error_bars = multi_pbars([""] * 5)
    run_record = Counter()
    batch_i = 1
    if not os.path.isdir(root_folder):
        raise IOError('Input directory not found.')
    batch_folder = make_batch_folder(output_folder, batch_i)
    for dir_n, _, file_list in tf.gfile.Walk(root_folder):
        for file_n in file_list:
            if file_n.endswith('fast5'):
                file_prefix = file_n.split('.')[0]
                # output_file = output_folder + os.path.splitext(file_n)[0]
                file_n = os.path.join(dir_n, file_n)
                state, (raw_data, raw_data_array), (offset, digitisation, range_s) = extract_file(file_n)
                run_record[state] += 1
                if run_record[SUCCEED_TAG] > batch_i * FLAGS.batch:
                    batch_i += 1
                    batch_folder = make_batch_folder(output_folder, batch_i)
                common_errors = run_record.most_common(FLAGS.n_errors)
                total_errors = sum(run_record.values())
                for i in np.arange(min(FLAGS.n_errors, len(common_errors))):
                    error_bars.update(i,
                                      title=common_errors[i][0],
                                      progress=common_errors[i][1],
                                      total=total_errors)
                error_bars.refresh()
                if state == SUCCEED_TAG:
                    if FLAGS.unit:
                        raw_data = reunit(raw_data, offset, digitisation, range_s)
                    with open(os.path.join(batch_folder, file_prefix + '.signal'), 'w+') as f:
                        f.write('\n'.join([str(x) for x in raw_data]))
                    with open(os.path.join(batch_folder, file_prefix + '.label'), 'w+') as f:
                        for label in raw_data_array:
                            f.write(' '.join([str(x) for x in label]))
                            f.write('\n')
                    logger.info("%s file transferred.\n" % (file_n))
                else:
                    logger.error("FAIL on %s file, because of error %s.\n" % (file_n, state))
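# A sketch of the scaling `reunit` presumably performs, assuming the standard
# ONT fast5 conversion from raw DAC values to picoamperes (an assumption about
# this helper, not a verified copy of its implementation):
def reunit_sketch(raw_data, offset, digitisation, range_s):
    # current_pA = (raw + offset) * range / digitisation
    return [(x + offset) * range_s / digitisation for x in raw_data]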
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    logits_fname = tf.placeholder(tf.string, shape=[FLAGS.batch_size])
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape,
                logits_index.shape, seq_length.shape])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, \
        decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver(var_list=tf.trainable_variables() +
                           tf.moving_average_variables())
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
            seq_len = np.asarray([])
            logits_idx = np.asarray([])
            logits_fn = np.asarray([])
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                i = 0
                while eval_data.epochs_completed == 0:
                    current_batch, current_seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size - len(batch_x), shuffle=False)
                    current_n = len(current_batch)
                    batch_x = np.concatenate((batch_x, current_batch), axis=0)
                    seq_len = np.concatenate((seq_len, current_seq_len), axis=0)
                    logits_idx = np.concatenate((logits_idx, [i] * current_n),
                                                axis=0)
                    logits_fn = np.concatenate((logits_fn, [name] * current_n),
                                               axis=0)
                    i += current_n
                    if len(batch_x) < FLAGS.batch_size:
                        pbars.update(0, progress=i)
                        pbars.update_bar()
                        continue
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: logits_idx,
                        logits_fname: logits_fn,
                    }
                    # Training: set it to True for a temporary fix of the batch
                    # normalization problem:
                    # https://github.com/haotianteng/Chiron/commit/8fce3a3b4dac8e9027396bb8c9152b7b5af953ce
                    # TODO: change the training flag back to False after the new
                    # model has been trained.
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
                    seq_len = np.asarray([])
                    logits_idx = np.asarray([])
                    logits_fn = np.asarray([])
                    pbars.update(0, progress=i)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            ### All files have been processed.
            batch_n = len(batch_x)
            if batch_n > 0:
                pad_width = FLAGS.batch_size - batch_n
                batch_x = np.pad(batch_x, ((0, pad_width), (0, 0)), mode='wrap')
                seq_len = np.pad(seq_len, ((0, pad_width)), mode='wrap')
                logits_idx = np.pad(logits_idx, (0, pad_width),
                                    mode='constant', constant_values=-1)
                logits_fn = np.pad(logits_fn, (0, pad_width),
                                   mode='constant', constant_values='')
                sess.run(logits_enqueue,
                         feed_dict={
                             x: batch_x,
                             seq_length: np.round(seq_len / ratio).astype(np.int32),
                             training: False,
                             logits_index: logits_idx,
                             logits_fname: logits_fn,
                         })
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn, args=())
        worker.setDaemon(True)
        worker.start()

        val = defaultdict(dict)  # We could read vals out of order; that's why it's a dict.
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            if 'total_count' not in val[name].keys():
                val[name]['total_count'] = 0
            if 'index_list' not in val[name].keys():
                val[name]['index_list'] = []
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                if val[name]['total_count'] == reads_n:
                    pbars.update(1, progress=val[name]['total_count'])
                    break
                decode_ops = [decoded_fname_op, decode_idx_op,
                              decode_predict_op, decode_prob_op]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = np.asarray(
                    [x.decode("UTF-8") for x in decoded_fname])
                ## This is hard to integrate into the TensorFlow graph, as the
                ## number of file names in a batch is variable, and an in-graph
                ## for loop can't be used because eager execution is disabled
                ## due to the use of queues.
                uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                    return_index=True)
                for fn_idx, fn in enumerate(uniq_fname):
                    i = uniq_fn_idx[fn_idx]
                    if fn != '':
                        occurance = np.where(decoded_fname == fn)[0]
                        start = occurance[0]
                        end = occurance[-1] + 1
                        assert len(occurance) == end - start
                        if 'total_count' not in val[fn].keys():
                            val[fn]['total_count'] = 0
                        if 'index_list' not in val[fn].keys():
                            val[fn]['index_list'] = []
                        val[fn]['total_count'] += (end - start)
                        val[fn]['index_list'].append(i)
                        sliced_sparse = slice_ctc_decoding_result(
                            predict_val, start, end)
                        val[fn][i] = (sliced_sparse,
                                      logits_prob[decoded_fname == fn])
                pbars.update(1, progress=val[name]['total_count'])
                pbars.update_bar()
            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in np.sort(val[name]['index_list']):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory.
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                             qs_list,
                                                             js_ratio,
                                                             kernal=kernal)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [start_time, reading_time,
                            basecall_time, assembly_time]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
        pbars.end()
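# A sketch of the index-to-base mapping that index2base is assumed to apply
# (the 0/1/2/3 -> A/C/G/T ordering is an assumption here, not taken from this
# file):
def index2base_sketch(read):
    base = ['A', 'C', 'G', 'T']
    return ''.join(base[x] for x in read if 0 <= x < 4)

# e.g. index2base_sketch([0, 2, 3, 1]) -> 'AGTC'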
def read_tfrecord(data_dir,
                  tfrecord,
                  h5py_file_path=None,
                  seq_length=300,
                  k_mer=1,
                  max_segments_num=None,
                  skip_start=10):
    ### Read from raw data
    count_bar = progress.multi_pbars("Extract tfrecords")
    if max_segments_num is None:
        max_segments_num = FLAGS.max_segments_number
    count_bar.update(0, progress=0, total=max_segments_num)
    if h5py_file_path is None:
        h5py_file_path = tempfile.mkdtemp() + '/temp_record.hdf5'
    else:
        try:
            os.remove(os.path.abspath(h5py_file_path))
        except:
            pass
        if not os.path.isdir(os.path.dirname(os.path.abspath(h5py_file_path))):
            os.mkdir(os.path.dirname(os.path.abspath(h5py_file_path)))
    with h5py.File(h5py_file_path, "a") as hdf5_record:
        event_h = hdf5_record.create_dataset('event/record',
                                             dtype='float32',
                                             shape=(0, seq_length),
                                             maxshape=(None, seq_length))
        event_length_h = hdf5_record.create_dataset('event/length',
                                                    dtype='int32',
                                                    shape=(0,),
                                                    maxshape=(None,),
                                                    chunks=True)
        label_h = hdf5_record.create_dataset('label/record',
                                             dtype='int32',
                                             shape=(0, 0),
                                             maxshape=(None, seq_length))
        label_length_h = hdf5_record.create_dataset('label/length',
                                                    dtype='int32',
                                                    shape=(0,),
                                                    maxshape=(None,))
        event = biglist(data_handle=event_h, max_len=FLAGS.MAXLEN)
        event_length = biglist(data_handle=event_length_h, max_len=FLAGS.MAXLEN)
        label = biglist(data_handle=label_h, max_len=FLAGS.MAXLEN)
        label_length = biglist(data_handle=label_length_h, max_len=FLAGS.MAXLEN)
        count = 0
        file_count = 0
        tfrecords_filename = data_dir + tfrecord
        record_iterator = tf.python_io.tf_record_iterator(path=tfrecords_filename)
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)
            raw_data_string = (example.features.feature['raw_data']
                               .bytes_list.value[0])
            features_string = (example.features.feature['features']
                               .bytes_list.value[0])
            fn_string = (example.features.feature['fname'].bytes_list.value[0])
            raw_data = np.frombuffer(raw_data_string, dtype=SIGNAL_DTYPE)
            features_data = np.frombuffer(features_string, dtype='S8')
            # Group the flat array into sub-arrays of size 3.
            group_size = 3
            features_data = [
                features_data[n:n + group_size]
                for n in range(0, len(features_data), group_size)
            ]
            f_signal = read_signal_tfrecord(raw_data)
            if len(f_signal) == 0:
                continue
            # try:
            f_label = read_label_tfrecord(features_data,
                                          skip_start=skip_start,
                                          window_n=int((k_mer - 1) / 2))
            # except:
            #     sys.stdout.write("Reading the label failed. Skipped.")
            #     continue
            tmp_event, tmp_event_length, tmp_label, tmp_label_length = read_raw(
                f_signal, f_label, seq_length)
            event += tmp_event
            event_length += tmp_event_length
            label += tmp_label
            label_length += tmp_label_length
            del tmp_event
            del tmp_event_length
            del tmp_label
            del tmp_label_length
            count = len(event)
            if file_count % 10 == 0:
                if max_segments_num is not None:
                    count_bar.update(0, progress=count, total=max_segments_num)
                    count_bar.update_bar()
                    if len(event) > max_segments_num:
                        event.resize(max_segments_num)
                        label.resize(max_segments_num)
                        event_length.resize(max_segments_num)
                        label_length.resize(max_segments_num)
                        break
                else:
                    count_bar.update(0, progress=count, total=count)
                    count_bar.update_bar()
            file_count += 1
        if event.cache:
            train = read_cache_dataset(h5py_file_path)
        else:
            train = DataSet(event=event,
                            event_length=event_length,
                            label=label,
                            label_length=label_length)
    count_bar.end()
    return train
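# A self-contained round trip showing the tf.train.Example layout the loop
# above expects: 'raw_data', 'features' and 'fname' stored as bytes features.
# Toy values only; treating SIGNAL_DTYPE as a 16-bit integer is an assumption.
import numpy as np
import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    'raw_data': tf.train.Feature(bytes_list=tf.train.BytesList(
        value=[np.arange(4, dtype=np.int16).tobytes()])),
    'features': tf.train.Feature(bytes_list=tf.train.BytesList(
        value=[np.array([b'0', b'4', b'A'], dtype='S8').tobytes()])),
    'fname': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'read_0'])),
}))
parsed = tf.train.Example()
parsed.ParseFromString(example.SerializeToString())
raw = np.frombuffer(parsed.features.feature['raw_data'].bytes_list.value[0],
                    dtype=np.int16)  # -> array([0, 1, 2, 3], dtype=int16)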
def read_raw_data_sets(data_dir,
                       h5py_file_path=None,
                       seq_length=300,
                       k_mer=1,
                       max_segments_num=FLAGS.max_segments_number,
                       skip_start=10):
    ### Read from raw data
    count_bar = progress.multi_pbars("Extract tfrecords")
    if max_segments_num is None:
        max_segments_num = FLAGS.max_segments_number
    count_bar.update(0, progress=0, total=max_segments_num)
    if h5py_file_path is None:
        h5py_file_path = tempfile.mkdtemp() + '/temp_record.hdf5'
    else:
        try:
            os.remove(os.path.abspath(h5py_file_path))
        except:
            pass
        if not os.path.isdir(os.path.dirname(os.path.abspath(h5py_file_path))):
            os.mkdir(os.path.dirname(os.path.abspath(h5py_file_path)))
    with h5py.File(h5py_file_path, "a") as hdf5_record:
        event_h = hdf5_record.create_dataset('event/record',
                                             dtype='float32',
                                             shape=(0, seq_length),
                                             maxshape=(None, seq_length))
        event_length_h = hdf5_record.create_dataset('event/length',
                                                    dtype='int32',
                                                    shape=(0,),
                                                    maxshape=(None,),
                                                    chunks=True)
        label_h = hdf5_record.create_dataset('label/record',
                                             dtype='int32',
                                             shape=(0, 0),
                                             maxshape=(None, seq_length))
        label_length_h = hdf5_record.create_dataset('label/length',
                                                    dtype='int32',
                                                    shape=(0,),
                                                    maxshape=(None,))
        event = biglist(data_handle=event_h, max_len=FLAGS.MAXLEN)
        event_length = biglist(data_handle=event_length_h, max_len=FLAGS.MAXLEN)
        label = biglist(data_handle=label_h, max_len=FLAGS.MAXLEN)
        label_length = biglist(data_handle=label_length_h, max_len=FLAGS.MAXLEN)
        count = 0
        file_count = 0
        for root, dirs, files in os.walk(data_dir, topdown=False):
            for name in files:
                if name.endswith(".signal"):
                    file_pre = os.path.splitext(name)[0]
                    signal_f = os.path.join(root, name)
                    f_signal = read_signal(signal_f, normalize=FLAGS.sig_norm)
                    label_f = os.path.join(root, file_pre + '.label')
                    if len(f_signal) == 0:
                        continue
                    try:
                        f_label = read_label(label_f,
                                             skip_start=skip_start,
                                             window_n=int((k_mer - 1) / 2))
                    except:
                        sys.stdout.write("Reading the label %s failed. Skipped." % (name))
                        continue
                    try:
                        tmp_event, tmp_event_length, tmp_label, tmp_label_length = read_raw(
                            f_signal, f_label, seq_length)
                    except Exception as e:
                        print("Extracting the label from %s failed: a label "
                              "position exceeds the max signal length." % (label_f))
                        raise e
                    event += tmp_event
                    event_length += tmp_event_length
                    label += tmp_label
                    label_length += tmp_label_length
                    del tmp_event
                    del tmp_event_length
                    del tmp_label
                    del tmp_label_length
                    count = len(event)
                    if file_count % 10 == 0:
                        if max_segments_num is not None:
                            count_bar.update(0,
                                             progress=count,
                                             total=max_segments_num)
                            count_bar.update_bar()
                            if len(event) > max_segments_num:
                                event.resize(max_segments_num)
                                label.resize(max_segments_num)
                                event_length.resize(max_segments_num)
                                label_length.resize(max_segments_num)
                                break
                        else:
                            count_bar.update(0, progress=count, total=count)
                            count_bar.update_bar()
                    file_count += 1
        if event.cache:
            event.save_rest()
            event_length.save_rest()
            label.save_rest()
            label_length.save_rest()
            train = read_cache_dataset(h5py_file_path)
        else:
            event.save()
            event_length.save()
            label.save()
            label_length.save()
            train = read_cache_dataset(h5py_file_path)
    count_bar.end()
    return train
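# A sketch of the segmentation read_raw is assumed to perform: slicing a signal
# into fixed-length windows and recording each window's true length. The names
# and behaviour here are illustrative assumptions, not the actual read_raw.
import numpy as np

def segment_signal_sketch(signal, seq_length):
    events, lengths = [], []
    for start in range(0, len(signal), seq_length):
        chunk = signal[start:start + seq_length]
        lengths.append(len(chunk))
        # Zero-pad the tail chunk so every row has shape (seq_length,).
        events.append(np.pad(chunk, (0, seq_length - len(chunk)),
                             mode='constant'))
    return np.asarray(events), np.asarray(lengths)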
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape,
                logits_index.shape, seq_length.shape])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, \
        decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                        mode='constant')
                    seq_len = np.pad(
                        seq_len, ((0, FLAGS.batch_size - len(seq_len))),
                        mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: seq_len,
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
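# A sketch of what the decoding_queue helper could look like (hypothetical; the
# real implementation lives elsewhere in the project): dequeue one batch of
# logits plus its metadata, attach a CTC beam search decoder, and hand back the
# ops the consumer loop runs. The size op of a dedicated decoded-results queue
# is mimicked here with the source queue's own size.
def decoding_queue_sketch(logits_queue):
    q_logits, q_fname, q_index, q_seq_len = logits_queue.dequeue()
    decoded, log_prob = tf.nn.ctc_beam_search_decoder(
        tf.transpose(q_logits, perm=[1, 0, 2]),
        q_seq_len,
        merge_repeated=False,
        beam_width=FLAGS.beam)
    # decoded[0] is a SparseTensor holding the best path for each segment;
    # running the returned ops together consumes exactly one queue element.
    return decoded[0], log_prob, q_fname, q_index, logits_queue.size()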
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape,
                logits_index.shape, seq_length.shape])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, \
        decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                        mode='constant')
                    seq_len = np.pad(
                        seq_len, ((0, FLAGS.batch_size - len(seq_len))),
                        mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    pbars.update(0, progress=i + FLAGS.batch_size)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn, args=())
        worker.setDaemon(True)
        worker.start()

        val = defaultdict(dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            N = len(range(0, reads_n, FLAGS.batch_size))
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                decode_ops = [decoded_fname_op, decode_idx_op,
                              decode_predict_op, decode_prob_op]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = decoded_fname.decode("UTF-8")
                val[decoded_fname][i] = (predict_val, logits_prob)
                pbars.update(1, progress=len(val[name]) * FLAGS.batch_size)
                pbars.update_bar()
                if len(val[name]) == N:
                    break
            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in range(0, reads_n, FLAGS.batch_size):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list,
                                                             js_ratio)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [start_time, reading_time,
                            basecall_time, assembly_time]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
        pbars.end()
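# A sketch of the sparse-to-dense conversion sparse2dense is assumed to do: a
# CTC decoder result arrives as (indices, values, dense_shape) triples, and
# each batch row is rebuilt as a plain Python list. Names are illustrative; the
# real helper also returns the `unique` indices used for quality scores, which
# is omitted here.
import numpy as np

def sparse2dense_sketch(predict_val):
    predict_read = []
    for sparse in predict_val:  # one SparseTensorValue per top path
        rows = [[] for _ in range(sparse.dense_shape[0])]
        for (row, _col), v in zip(sparse.indices, sparse.values):
            rows[row].append(v)
        predict_read.append(rows)
    return predict_read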
def do_inference():
    """Tests PredictionService with concurrent requests.

    Raises:
        IOError: An error occurred processing the test data set.
    """
    if FLAGS.mode == 'dna':
        CONF = DNA_CONF()
    elif FLAGS.mode == 'rna':
        CONF = RNA_CONF()
    else:
        raise ValueError("Mode has to be either rna or dna.")
    make_dirs(FLAGS.output)
    FLAGS.segment_len = CONF.SEGMENT_LEN
    FLAGS.jump = CONF.JUMP
    FLAGS.start = CONF.START
    pbars = multi_pbars(["Request Submit:", "Request finished"])
    channel = grpc.insecure_channel(FLAGS.server)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'chiron'
    request.model_spec.signature_name = 'predicted_sequences'
    collector = _Result_Collection(concurrency=FLAGS.concurrency)
    file_list = gen_file_list(FLAGS.input)
    batch_iterator = data_iterator(file_list)

    def submit_fn():
        for batch_x, seq_len, i, f, N, reads_n in batch_iterator:
            seq_len = np.reshape(seq_len, (seq_len.shape[0], 1))
            combined_input = np.concatenate((batch_x, seq_len),
                                            axis=1).astype(np.float32)
            request.inputs['combined_inputs'].CopyFrom(
                tf.contrib.util.make_tensor_proto(
                    combined_input,
                    shape=[FLAGS.batch_size, CONF.SEGMENT_LEN + 1]))
            collector.throttle()
            result_future = stub.Predict.future(request, 100.0)  # 100-second timeout
            result_future.add_done_callback(_post_process(collector, i, f, N))
            pbars.update(0, total=reads_n, progress=(i + 1) * FLAGS.batch_size)
            pbars.update_bar()

    submitter = threading.Thread(target=submit_fn, args=())
    submitter.setDaemon(True)
    submitter.start()
    pbars.update(1, total=len(file_list))
    pbars.update_bar()
    while not collector.all_done():
        if len(collector._done) > 0:
            qs_string = None
            f_p = collector._done[0]
            reads, probs = collector.pop_out(f_p)
            bpreads = [index2base(read) for read in reads]
            consensus, qs_consensus = simple_assembly_qs(bpreads, probs)
            qs_string = qs(consensus, qs_consensus)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            file_pre = os.path.basename(os.path.splitext(f_p)[0])
            write_output(bpreads,
                         c_bpread,
                         [np.NaN] * 4,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
            pbars.update(1, progress=pbars.progress[1] + 1)
            pbars.update_bar()
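# A sketch of a done-callback factory in the style of _post_process (which is
# defined elsewhere); the callback body and any collector bookkeeping shown
# here are hypothetical, illustrating only the gRPC future pattern used above.
def _post_process_sketch(collector, i, f, N):
    def _callback(result_future):
        exception = result_future.exception()
        if exception:
            print(exception)  # a real implementation would log and/or retry
        else:
            response = result_future.result()  # a PredictResponse protobuf
            # e.g. tf.contrib.util.make_ndarray(response.outputs['some_key'])
            # would recover the decoded batch, which is then filed under
            # (f, i) so batches can be reassembled per input file.
            pass
    return _callback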
    default=30,
    help="Beam width used in the beam search decoder; if set to 0, a greedy "
         "decoder is used instead. Recommended width: 100. A larger beam "
         "width gives better decoding results but requires a longer decoding "
         "time.")
parser.add_argument(
    '--concise',
    action='store_true',
    help="Output the result concisely; the meta and segments files will not "
         "be output.")
parser.add_argument('--mode',
                    default='dna',
                    help="Output mode, can be chosen from dna or rna.")
FLAGS = parser.parse_args(sys.argv[1:])
pbars = multi_pbars(
    ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
training = tf.placeholder(tf.bool)
config_path = os.path.join(FLAGS.model, 'model.json')
model_configure = chiron_model.read_config(config_path)
logits, ratio = chiron_model.inference(x,
                                       seq_length,
                                       training=training,
                                       full_sequence_len=FLAGS.segment_len,
                                       configure=model_configure)
probs = tf.nn.softmax(logits)
predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                        seq_length,
                                        merge_repeated=False,
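# For reference, tf.nn.ctc_beam_search_decoder returns a pair
# (decoded, log_probabilities): `decoded` is a list of SparseTensor, one per
# top path, and `log_probabilities` has shape [batch_size, top_paths]. Once the
# decoder call above is completed, a minimal way to view the best path densely
# (illustrative only; the project instead post-processes the sparse result
# through sparse2dense) is:
decoded, log_prob = predict
best_path = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)
# Each row of `best_path` is one segment's label-index sequence padded with -1;
# sess.run(best_path, feed_dict=...) yields the dense array.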