# Assumed module-level imports for these evaluation routines (not shown in the
# excerpt): os, time, numpy as np, tensorflow as tf, a module-level FLAGS object,
# plus the project's own helpers (inference, read_data_for_eval, sparse2dense,
# index2base, simple_assembly, simple_assembly_qs, path_prob, qs, write_output).
def evaluation():
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    logits, _ = inference(x, seq_length, training=training)
    if FLAGS.extension == 'fastq':
        prob = path_prob(logits)
    predict = tf.nn.ctc_greedy_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                       seq_length,
                                       merge_repeated=True)
    # Alternative: beam search decoding. For the beam search decoder, set
    # merge_repeated=False; it is roughly 5-10x slower than the greedy decoder.
    # predict = tf.nn.ctc_beam_search_decoder(
    #     tf.transpose(logits, perm=[1, 0, 2]), seq_length, merge_repeated=False)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(os.path.join(FLAGS.input, os.path.pardir))
        # Make the output folders.
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))
        for name in file_list:
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           FLAGS.segment_len, FLAGS.jump,
                                           FLAGS.smooth_window, FLAGS.skip_step,
                                           FLAGS.normalize)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in range(0, reads_n, FLAGS.batch_size):
                batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size, shuffle=False)
                # Pad the last (possibly short) batch up to batch_size.
                batch_x = np.pad(batch_x,
                                 ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                                 mode='constant')
                seq_len = np.pad(seq_len,
                                 ((0, FLAGS.batch_size - len(seq_len))),
                                 mode='constant')
                feed_dict = {x: batch_x, seq_length: seq_len, training: False}
                if FLAGS.extension == 'fastq':
                    predict_val, logits_prob = sess.run([predict, prob], feed_dict=feed_dict)
                else:
                    predict_val = sess.run(predict, feed_dict=feed_dict)
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    # Drop the padded entries of the final batch.
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            print("Segment reads base calling finished, begin to assembly. %5.2f seconds"
                  % (time.time() - start_time))
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list, FLAGS.alphabet)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, FLAGS.alphabet)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            np.set_printoptions(threshold=np.nan)
            assembly_time = time.time() - start_time
            print("Assembly finished, begin output. %5.2f seconds"
                  % (time.time() - start_time))
            list_of_time = [start_time, reading_time, basecall_time, assembly_time]
            write_output(bpreads, c_bpread, list_of_time, file_pre,
                         suffix=FLAGS.extension, q_score=qs_string)
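# The decoding loops above assume a helper `sparse2dense` that converts the
# SparseTensor output of tf.nn.ctc_greedy_decoder into dense per-segment label
# lists, together with the batch-row indices that produced output. A minimal
# illustrative sketch (not necessarily the repository's exact implementation):
def sparse2dense_sketch(predict_val):
    """Return ([per-path list of label arrays], [per-path populated row indices])."""
    decoded_paths = predict_val[0]  # list of SparseTensorValue, one per decoded path
    dense_reads, uniq_rows = [], []
    for st in decoded_paths:
        rows, counts = np.unique(st.indices[:, 0], return_counts=True)
        path_reads, pos = [], 0
        for c in counts:
            path_reads.append(st.values[pos:pos + c])  # labels of one segment read
            pos += c
        dense_reads.append(path_reads)
        uniq_rows.append(rows)
    return dense_reads, uniq_rows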
# This variant decouples the network forward pass from CTC decoding via a
# FIFOQueue: a worker thread enqueues logits while the main thread dequeues
# decoded reads. Additionally assumes: import threading and
# from collections import defaultdict.
def evaluation():
    pbars = multi_pbars(["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.config_path, 'model.json')
    model_configure = read_config(config_path)
    logits, ratio = inference(x,
                              seq_length,
                              training=training,
                              full_sequence_len=FLAGS.segment_len,
                              configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape, logits_index.shape, seq_length.shape]
    )
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = \
        decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            # Producer: run the network batch by batch and enqueue the logits.
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(batch_x,
                                     ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                                     mode='constant')
                    seq_len = np.pad(seq_len,
                                     ((0, FLAGS.batch_size - len(seq_len))),
                                     mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    pbars.update(0, progress=i + FLAGS.batch_size)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn, args=())
        worker.setDaemon(True)
        worker.start()

        val = defaultdict(dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            # Note: both branches currently load the data in the same way.
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            else:
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            N = len(range(0, reads_n, FLAGS.batch_size))
            while True:
                # Queue sizes are fetched for monitoring; the values are unused here.
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                decode_ops = [decoded_fname_op, decode_idx_op, decode_predict_op, decode_prob_op]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = decoded_fname.decode("UTF-8")
                val[decoded_fname][i] = (predict_val, logits_prob)
                pbars.update(1, progress=len(val[name]) * FLAGS.batch_size)
                pbars.update_bar()
                if len(val[name]) == N:
                    break
            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()

            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in range(0, reads_n, FLAGS.batch_size):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]
                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [start_time, reading_time, basecall_time, assembly_time]
            write_output(bpreads, c_bpread, list_of_time, file_pre,
                         concise=FLAGS.concise, suffix=FLAGS.extension,
                         q_score=qs_string, global_setting=FLAGS)
    pbars.end()
def evaluation():
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    logits, _ = inference(x, seq_length, training=training)
    predict = tf.nn.ctc_greedy_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                       seq_length,
                                       merge_repeated=True)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(os.path.join(FLAGS.input, os.path.pardir))
        for name in file_list:
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()
            for i in range(0, reads_n, FLAGS.batch_size):
                batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size, shuffle=False)
                batch_x = np.pad(batch_x,
                                 ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                                 mode='constant')
                seq_len = np.pad(seq_len,
                                 ((0, FLAGS.batch_size - len(seq_len))),
                                 mode='constant')
                feed_dict = {x: batch_x, seq_length: seq_len, training: False}
                predict_val = sess.run(predict, feed_dict=feed_dict)
                predict_read = sparse2dense(predict_val)[0]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                reads += predict_read
            print("Segment reads base calling finished, begin to assembly. %5.2f seconds"
                  % (time.time() - start_time))
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            print("Assembly finished, begin output. %5.2f seconds"
                  % (time.time() - start_time))
            result_folder = os.path.join(FLAGS.output, 'result')
            seg_folder = os.path.join(FLAGS.output, 'segments')
            meta_folder = os.path.join(FLAGS.output, 'meta')
            if not os.path.exists(FLAGS.output):
                os.makedirs(FLAGS.output)
            if not os.path.exists(seg_folder):
                os.makedirs(seg_folder)
            if not os.path.exists(result_folder):
                os.makedirs(result_folder)
            if not os.path.exists(meta_folder):
                os.makedirs(meta_folder)
            path_con = os.path.join(result_folder, file_pre + '.fasta')
            path_reads = os.path.join(seg_folder, file_pre + '.fasta')
            path_meta = os.path.join(meta_folder, file_pre + '.meta')
            with open(path_reads, 'w+') as out_f, open(path_con, 'w+') as out_con:
                for indx, read in enumerate(bpreads):
                    out_f.write(file_pre + str(indx) + '\n')
                    out_f.write(read + '\n')
                out_con.write("{}\n{}".format(file_pre, c_bpread))
            with open(path_meta, 'w+') as out_meta:
                # Convert the cumulative timestamps into per-stage durations.
                total_time = time.time() - start_time
                output_time = total_time - assembly_time
                assembly_time -= basecall_time
                basecall_time -= reading_time
                total_len = len(c_bpread)
                total_time = time.time() - start_time
                out_meta.write("# Reading Basecalling assembly output total rate(bp/s)\n")
                out_meta.write("%5.3f %5.3f %5.3f %5.3f %5.3f %5.3f\n" % (
                    reading_time, basecall_time, assembly_time, output_time,
                    total_time, total_len / total_time))
                out_meta.write("# read_len batch_size segment_len jump start_pos\n")
                out_meta.write("%d %d %d %d %d\n" % (
                    total_len, FLAGS.batch_size, FLAGS.segment_len,
                    FLAGS.jump, FLAGS.start))
                out_meta.write("# input_name model_name\n")
                out_meta.write("%s %s\n" % (FLAGS.input, FLAGS.model))
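# `index2base` above is assumed to map integer labels back to nucleotides
# (0-3 -> A/C/G/T). An illustrative sketch, not necessarily the repository's
# exact implementation:
def index2base_sketch(read):
    """Translate a sequence of integer labels into a base string."""
    alphabet = ['A', 'C', 'G', 'T']
    return ''.join(alphabet[int(x)] for x in read)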
# U-net variant. Additionally assumes: import pickle and a Keras-style
# `models` module (e.g. from keras import models), plus the project's loss /
# metric helpers referenced in custom_objects below.
def evaluation(signal_input, args):
    print("@ Loading U-net model ...")
    if args.model_param != "":
        params = load_modelParam(args.model_param)
    else:
        print("! Unable to load the model parameters, please check!")
        exit()
    model_name = get_unet_model_name(params, args)
    unet_model = models.load_model(
        "./experiment/model/weights/" + model_name + ".h5",
        custom_objects={'dice_coef_loss': dice_coef_loss,
                        'dice_coef': dice_coef,
                        'bce_dice_loss': bce_dice_loss,
                        'categorical_focal_loss_fixed': categorical_focal_loss(gamma=2., alpha=.25),
                        'ce_dice_loss': ce_dice_loss})
    if args.norm != "":
        print("@ Perform data normalization ... ")
        print("- Loading from %s" % (args.norm))
        pickle_in = open(args.norm, "rb")
        stat_dict = pickle.load(pickle_in)
        print("- Training data statistics m=%f, s=%f" % (stat_dict["m"], stat_dict["s"]))
    #############################################################################
    print("@ Loading signal files ...")
    if os.path.isdir(signal_input):
        file_list = os.listdir(signal_input)
        file_dir = signal_input
    else:
        file_list = [os.path.basename(signal_input)]
        file_dir = os.path.abspath(os.path.join(signal_input, os.path.pardir))
    # Make the output sub-folders.
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    if not os.path.exists(os.path.join(args.output, 'segments')):
        os.makedirs(os.path.join(args.output, 'segments'))
    if not os.path.exists(os.path.join(args.output, 'result')):
        os.makedirs(os.path.join(args.output, 'result'))
    if not os.path.exists(os.path.join(args.output, 'meta')):
        os.makedirs(os.path.join(args.output, 'meta'))
    # Start processing the files in the file list.
    for name in file_list:
        start_time = time.time()
        if not name.endswith('.signal'):
            continue
        file_pre = os.path.splitext(name)[0]
        print("- Processing read %s" % (file_pre))
        input_path = os.path.join(file_dir, name)
        # Read the signal file; data normalization is handled below.
        eval_data = read_data_for_eval(input_path, args.start, args.jump, args.segment_len)
        reads_n = eval_data.reads_n
        reading_time = time.time() - start_time
        reads = list()
        signals = np.empty((0, args.segment_len), dtype=np.float)
        # Base-call the loaded signals batch by batch.
        for i in range(0, reads_n, args.batch_size):
            # Get the input signals; different basecallers could be called here.
            X, seq_len, _, _, _, _ = eval_data.next_batch(args.batch_size, shuffle=False)
            X = X.reshape(X.shape[0], X.shape[1], 1).astype("float32")
            # Normalize the data with the training-set statistics.
            if args.norm != "":
                X = (X - stat_dict["m"]) / (stat_dict["s"])
            output = unet_basecaller(unet_model, X)
            reads += output
        # Assemble the segment reads.
        print("Segment reads base calling finished, begin to assembly. %5.2f seconds"
              % (time.time() - start_time))
        basecall_time = time.time() - start_time
        # Simple assembly of the segment reads.
        # print(reads)
        consensus = simple_assembly(reads)
        c_bpread = index2base_0(np.argmax(consensus, axis=0))
        assembly_time = time.time() - start_time
        print("Assembly finished, begin output. %5.2f seconds" % (time.time() - start_time))
        # Write the results to the output folder.
        list_of_time = [start_time, reading_time, basecall_time, assembly_time]
        write_output(args, reads, "".join(c_bpread), list_of_time, file_pre)
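# A hypothetical invocation of the U-net evaluation above. The attribute names
# mirror the `args.*` fields the function reads; the concrete values and paths
# are placeholders for illustration, not settings taken from the repository,
# and the helper functions may require further `args` fields not shown here.
from argparse import Namespace

demo_args = Namespace(
    model_param='experiment/model/params.json',  # hypothetical parameter file
    norm='',                                     # empty string skips normalization
    output='output',
    start=0,
    jump=30,
    segment_len=512,
    batch_size=400,
)
evaluation('example_reads/', demo_args)          # folder of .signal files (placeholder)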