Example #1
# NOTE: FLAGS and the helper functions used below (inference, path_prob,
# read_data_for_eval, sparse2dense, index2base, simple_assembly,
# simple_assembly_qs, qs, write_output) are defined elsewhere in the
# surrounding module.
import os
import sys
import time

import numpy as np
import tensorflow as tf  # TF 1.x API


def evaluation():
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    logits, _ = inference(x, seq_length, training=training)
    if FLAGS.extension == 'fastq':
        prob = path_prob(logits)
    predict = tf.nn.ctc_greedy_decoder(
        tf.transpose(logits, perm=[1, 0, 2]), seq_length, merge_repeated=True)
    # Alternative: beam-search decoding; set merge_repeated to False for it.
    # It is 5-10 times slower than the greedy decoder.
    # predict = tf.nn.ctc_beam_search_decoder(
    #     tf.transpose(logits, perm=[1, 0, 2]), seq_length, merge_repeated=False)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(os.path.join(FLAGS.input,os.path.pardir))
        # Make the output folders.
        for sub in ('', 'segments', 'result', 'meta'):
            os.makedirs(os.path.join(FLAGS.output, sub), exist_ok=True)

        for name in file_list:
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir,name)
            eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           FLAGS.segment_len, FLAGS.jump,
                                           FLAGS.smooth_window, FLAGS.skip_step,
                                           FLAGS.normalize)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()
            # np.float was removed in NumPy 1.24; it was an alias for float64.
            qs_list = np.empty((0, 1), dtype=np.float64)
            qs_string = None
            for i in range(0,reads_n,FLAGS.batch_size):
                batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size, shuffle=False)
                # Pad the last (possibly short) batch up to batch_size.
                batch_x = np.pad(batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant')
                seq_len = np.pad(seq_len, (0, FLAGS.batch_size - len(seq_len)), mode='constant')
                feed_dict = {x: batch_x, seq_length: seq_len, training: False}
                if FLAGS.extension == 'fastq':
                    predict_val, logits_prob = sess.run([predict, prob], feed_dict=feed_dict)
                else:
                    predict_val = sess.run(predict, feed_dict=feed_dict)
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    # Trim entries introduced by padding the last batch.
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            print("Segment reads base calling finished, begin to assembly. %5.2f seconds"%(time.time()-start_time))
            basecall_time=time.time()-start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list, FLAGS.alphabet)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, FLAGS.alphabet)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            # threshold=np.nan is rejected by newer NumPy; use sys.maxsize to disable truncation.
            np.set_printoptions(threshold=sys.maxsize)
            assembly_time = time.time() - start_time
            print("Assembly finished; writing output. %5.2f seconds" % (time.time() - start_time))
            list_of_time = [start_time, reading_time, basecall_time, assembly_time]
            write_output(bpreads, c_bpread, list_of_time, file_pre, suffix=FLAGS.extension, q_score=qs_string)
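
The helpers sparse2dense and index2base used above convert the output of tf.nn.ctc_greedy_decoder back into readable sequences: the first turns the SparseTensorValue result into one dense label sequence per batch row, the second maps integer labels to nucleotide letters. Neither is shown on this page, so the following is a minimal sketch consistent with how the code above indexes their results (predict_read[0], unique[0]); it is an illustration, not necessarily the original implementation.

import numpy as np

def sparse2dense(ctc_output):
    # ctc_output is the tuple returned by sess.run on tf.nn.ctc_greedy_decoder:
    # (list of SparseTensorValue candidates, negative log probabilities).
    all_reads, all_unique = [], []
    for st in ctc_output[0]:
        # Batch rows that emitted at least one label, with their label counts.
        unique, counts = np.unique(st.indices[:, 0], return_counts=True)
        reads, pos = [], 0
        for c in counts:
            reads.append(st.values[pos:pos + c])  # dense labels for one read
            pos += c
        all_reads.append(reads)
        all_unique.append(unique)
    return all_reads, all_unique

def index2base(read, alphabet='ACGT'):
    # Map integer labels (0..3) back to nucleotide characters.
    return ''.join(alphabet[i] for i in read)
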
Example #2
# NOTE: FLAGS and the helper functions used below (multi_pbars, read_config,
# inference, read_data_for_eval, decoding_queue, sparse2dense, index2base,
# simple_assembly, simple_assembly_qs, qs, write_output) are defined elsewhere
# in the surrounding module.
import os
import threading
import time
from collections import defaultdict

import numpy as np
import tensorflow as tf  # TF 1.x API


def evaluation():
    pbars = multi_pbars(["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.config_path, 'model.json')
    model_configure = read_config(config_path)

    logits, ratio = inference(x,
                              seq_length,
                              training=training,
                              full_sequence_len=FLAGS.segment_len,
                              configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    # A FIFO queue decouples running the network (producer thread) from CTC
    # decoding (consumer): each enqueued element carries the logits together
    # with the source file name, the batch index, and the sequence lengths.
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape, logits_index.shape, seq_length.shape]
    )
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        # Make the output folders.
        for sub in ('', 'segments', 'result', 'meta'):
            os.makedirs(os.path.join(FLAGS.output, sub), exist_ok=True)
        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    # Pad the last (possibly short) batch up to batch_size.
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant')
                    seq_len = np.pad(
                        seq_len, (0, FLAGS.batch_size - len(seq_len)), mode='constant')
                    feed_dict = {
                        x: batch_x,
                        # The network downsamples by ratio; scale sequence lengths accordingly.
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    pbars.update(0, progress=i + FLAGS.batch_size)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn)
        worker.daemon = True  # Thread.setDaemon() is deprecated; set the attribute instead.
        worker.start()

        val = defaultdict(dict)  # Results can arrive out of order, hence a dict keyed by file name.
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            # RNA and DNA modes currently load evaluation data identically.
            eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()

            N = len(range(0, reads_n, FLAGS.batch_size))  # batches expected for this file
            while True:
                # Queue sizes (logits_queue_size, decode_queue_size) can be
                # polled here to monitor pipeline backpressure if desired.
                decode_ops = [decoded_fname_op, decode_idx_op, decode_predict_op, decode_prob_op]
                decoded_fname, i, predict_val, logits_prob = sess.run(decode_ops, feed_dict={training: False})
                decoded_fname = decoded_fname.decode("UTF-8")
                # Results may arrive out of order and for other files; key them
                # by file name and batch index.
                val[decoded_fname][i] = (predict_val, logits_prob)
                pbars.update(1, progress=len(val[name]) * FLAGS.batch_size)
                pbars.update_bar()
                if len(val[name]) == N:
                    break

            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()
            # np.float was removed in NumPy 1.24; it was an alias for float64.
            qs_list = np.empty((0, 1), dtype=np.float64)
            qs_string = None
            for i in range(0, reads_n, FLAGS.batch_size):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    # Trim entries introduced by padding the last batch.
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory

            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [start_time, reading_time,
                            basecall_time, assembly_time]
            write_output(bpreads, c_bpread, list_of_time, file_pre, concise=FLAGS.concise,
                         suffix=FLAGS.extension, q_score=qs_string, global_setting=FLAGS)
    pbars.end()
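
Example #2 differs from Example #1 mainly in that it pipelines the work: a daemon thread feeds signal batches through the network and enqueues the logits into a tf.FIFOQueue, while the main thread dequeues decoded results, so inference and CTC decoding overlap. The decoding_queue helper is not shown on this page. Below is a minimal, self-contained sketch of that producer/consumer pattern under the TF 1.x API; the shapes and the random input are illustrative stand-ins, not part of the original code.

import threading
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API

BATCH, TIME, CLASSES = 4, 50, 5  # illustrative shapes

logits_in = tf.placeholder(tf.float32, shape=[BATCH, TIME, CLASSES])
queue = tf.FIFOQueue(capacity=10, dtypes=[tf.float32], shapes=[logits_in.shape])
enqueue_op = queue.enqueue([logits_in])
close_op = queue.close()

# Consumer side of the graph: dequeue logits and decode them greedily.
q_logits = queue.dequeue()
seq_len = tf.fill([BATCH], TIME)
decoded, _ = tf.nn.ctc_greedy_decoder(
    tf.transpose(q_logits, perm=[1, 0, 2]), seq_len, merge_repeated=True)

def producer(sess, n_batches):
    # Stand-in for the network: enqueue random logits, then close the queue.
    for _ in range(n_batches):
        fake = np.random.randn(BATCH, TIME, CLASSES).astype(np.float32)
        sess.run(enqueue_op, feed_dict={logits_in: fake})
    sess.run(close_op)

with tf.Session() as sess:
    n = 3
    t = threading.Thread(target=producer, args=(sess, n))
    t.daemon = True
    t.start()
    for _ in range(n):
        sparse_read = sess.run(decoded)[0]  # SparseTensorValue of label indices
        print(sparse_read.indices.shape, sparse_read.values.shape)
    t.join()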