Example #1
def evaluation():
    x = tf.placeholder(tf.float32,shape = [FLAGS.batch_size,FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape = [FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    logits,_ = inference(x,seq_length,training = training)
    if FLAGS.extension =='fastq':
        prob = path_prob(logits)
    predict = tf.nn.ctc_greedy_decoder(tf.transpose(logits,perm=[1,0,2]),seq_length,merge_repeated = True)
#    predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits,perm=[1,0,2]),seq_length,merge_repeated = False)  # For ctc_beam_search_decoder, set merge_repeated to False; it is 5-10 times slower than the greedy decoder.
    config=tf.ConfigProto(allow_soft_placement=True,intra_op_parallelism_threads=FLAGS.threads,inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    with tf.Session(config = config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess,tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(os.path.join(FLAGS.input,os.path.pardir))
        #Make output folder.
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output,'segments')):
            os.makedirs(os.path.join(FLAGS.output,'segments'))
        if not os.path.exists(os.path.join(FLAGS.output,'result')):
            os.makedirs(os.path.join(FLAGS.output,'result'))
        if not os.path.exists(os.path.join(FLAGS.output,'meta')):
            os.makedirs(os.path.join(FLAGS.output,'meta'))

        for name in file_list:
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir,name)
            eval_data = read_data_for_eval(input_path,FLAGS.start,FLAGS.segment_len,FLAGS.jump,FLAGS.smooth_window,FLAGS.skip_step,FLAGS.normalize)
            reads_n = eval_data.reads_n
            reading_time=time.time()-start_time
            reads = list()
            qs_list = np.empty((0,1),dtype = np.float64)
            qs_string = None
            for i in range(0,reads_n,FLAGS.batch_size):
                batch_x,seq_len,_ = eval_data.next_batch(FLAGS.batch_size,shuffle = False)
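                # Pad the final, partial batch up to batch_size so it matches the fixed placeholder shapes.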
                batch_x=np.pad(batch_x,((0,FLAGS.batch_size-len(batch_x)),(0,0)),mode='constant')
                seq_len=np.pad(seq_len,((0,FLAGS.batch_size-len(seq_len))),mode='constant')
                feed_dict = {x:batch_x,seq_length:seq_len,training:False}
                if FLAGS.extension=='fastq':
                    predict_val,logits_prob= sess.run([predict,prob],feed_dict = feed_dict)
                else:
                    predict_val= sess.run(predict,feed_dict = feed_dict)
                predict_read,unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension=='fastq':
                    logits_prob = logits_prob[unique]
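                # For the last batch, drop predictions that correspond to the padded rows.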
                if i+FLAGS.batch_size>reads_n:
                    predict_read = predict_read[:reads_n-i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n-i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list,logits_prob))
                reads+=predict_read
            print("Segment reads base calling finished, begin to assembly. %5.2f seconds"%(time.time()-start_time))
            basecall_time=time.time()-start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus,qs_consensus = simple_assembly_qs(bpreads,qs_list,FLAGS.alphabet)
                qs_string = qs(consensus,qs_consensus)
            else:
                consensus = simple_assembly(bpreads,FLAGS.alphabet)
            c_bpread = index2base(np.argmax(consensus,axis = 0))
            np.set_printoptions(threshold=np.inf)
            assembly_time=time.time()-start_time
            print("Assembly finished, begin output. %5.2f seconds"%(time.time()-start_time))
            list_of_time = [start_time,reading_time,basecall_time,assembly_time]
            write_output(bpreads,c_bpread,list_of_time,file_pre,suffix = FLAGS.extension,q_score = qs_string)
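
The helper functions called above (`index2base`, `sparse2dense`, `path_prob`, `simple_assembly`, ...) are defined elsewhere in the repository. As a point of reference only, a minimal sketch of what `index2base` and `sparse2dense` are assumed to do follows; the alphabet order and the layout of the decoder's SparseTensorValue output are assumptions rather than the project's actual implementation.

import numpy as np

def index2base(read, alphabet='ACGT'):
    # Assumed mapping from integer class indices to bases: 0->A, 1->C, 2->G, 3->T.
    return ''.join(alphabet[int(x)] for x in read)

def sparse2dense(predict_val):
    # predict_val is the (decoded, log_prob) pair returned by ctc_greedy_decoder;
    # each element of decoded is a SparseTensorValue covering the whole batch.
    predict_reads, uniques = [], []
    for decoded in predict_val[0]:
        reads = [[] for _ in range(int(decoded.dense_shape[0]))]
        for (batch_i, _), value in zip(decoded.indices, decoded.values):
            reads[int(batch_i)].append(int(value))
        # Row indices of the batch that produced a non-empty read.
        uniques.append(np.array([i for i, r in enumerate(reads) if r]))
        predict_reads.append(reads)
    return predict_reads, uniques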
Example #2
def evaluation():
    pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.config_path,'model.json')
    model_configure = read_config(config_path)

    logits, ratio = inference(
                                    x, 
                                    seq_length, 
                                    training=training,
                                    full_sequence_len = FLAGS.segment_len,
                                    configure = model_configure)
    config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
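    # A FIFO queue decouples GPU inference (the worker thread defined below) from CTC decoding in the main thread.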
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape,logits_fname.shape,logits_index.shape, seq_length.shape]
    )
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2,total = file_n)
        pbars.update(3,total = file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))
        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0,total = reads_n,progress = 0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant')
                    seq_len = np.pad(
                        seq_len, ((0, FLAGS.batch_size - len(seq_len))), mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len/ratio).astype(np.int32),
                        training: False,
                        logits_index:i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue,feed_dict=feed_dict)
                    pbars.update(0,progress=i+FLAGS.batch_size)
                    pbars.update_bar()
                pbars.update(2,progress = f_i+1)
                pbars.update_bar()
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn,args=() )
        worker.daemon = True
        worker.start()

        val = defaultdict(dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            # Both 'rna' and DNA modes currently read the evaluation data with the same parameters.
            eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1,total = reads_n,progress = 0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()

            N = len(range(0, reads_n, FLAGS.batch_size))
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                decode_ops = [decoded_fname_op, decode_idx_op, decode_predict_op, decode_prob_op]
                decoded_fname, i, predict_val, logits_prob = sess.run(decode_ops, feed_dict={training: False})
                decoded_fname = decoded_fname.decode("UTF-8")
                val[decoded_fname][i] = (predict_val, logits_prob)               
                pbars.update(1,progress = len(val[name])*FLAGS.batch_size)
                pbars.update_bar()
                if len(val[name]) == N:
                    break

            pbars.update(3,progress = f_i+1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float64)
            qs_string = None
            for i in range(0, reads_n, FLAGS.batch_size):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory

            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [start_time, reading_time,
                            basecall_time, assembly_time]
            write_output(bpreads, c_bpread, list_of_time, file_pre, concise=FLAGS.concise, suffix=FLAGS.extension,
                         q_score=qs_string,global_setting=FLAGS)
    pbars.end()
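
`decoding_queue` is not defined in this excerpt. A minimal sketch of the interface it is assumed to provide is given below: it dequeues one batch of logits and builds the CTC decode ops on top of it. The real helper presumably runs the decoder behind a second queue so decoding overlaps with inference, and returns a genuine path probability for quality scores; both are simplified away here.

def decoding_queue(logits_queue):
    # Dequeue one batch of logits together with its file name, batch index and lengths.
    q_logits, q_fname, q_index, q_seq_len = logits_queue.dequeue()
    # Greedy CTC decoding; a beam-search decoder could be substituted at a speed cost.
    decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(
        tf.transpose(q_logits, perm=[1, 0, 2]), q_seq_len, merge_repeated=True)
    # neg_sum_logits stands in for the per-read probability used for fastq quality
    # scores; the decode-side queue size is a constant placeholder in this sketch.
    decode_queue_size = tf.constant(0, dtype=tf.int32)
    return (decoded, neg_sum_logits), neg_sum_logits, q_fname, q_index, decode_queue_size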
Example #3
def evaluation():
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    logits, _ = inference(x, seq_length, training=training)
    predict = tf.nn.ctc_greedy_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                       seq_length,
                                       merge_repeated=True)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        for name in file_list:
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            eval_data = read_data_for_eval(input_path,
                                           FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()
            for i in range(0, reads_n, FLAGS.batch_size):
                batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size,
                                                           shuffle=False)
                batch_x = np.pad(batch_x,
                                 ((0, FLAGS.batch_size - len(batch_x)),
                                  (0, 0)),
                                 mode='constant')
                seq_len = np.pad(seq_len,
                                 ((0, FLAGS.batch_size - len(seq_len))),
                                 mode='constant')
                feed_dict = {x: batch_x, seq_length: seq_len, training: False}
                predict_val = sess.run(predict, feed_dict=feed_dict)
                predict_read = sparse2dense(predict_val)[0]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                reads += predict_read
            print(
                "Segment basecalling finished, beginning assembly. %5.2f seconds"
                % (time.time() - start_time))
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            print("Assembly finished, begin output. %5.2f seconds" %
                  (time.time() - start_time))
            result_folder = os.path.join(FLAGS.output, 'result')
            seg_folder = os.path.join(FLAGS.output, 'segments')
            meta_folder = os.path.join(FLAGS.output, 'meta')
            if not os.path.exists(FLAGS.output):
                os.makedirs(FLAGS.output)
            if not os.path.exists(seg_folder):
                os.makedirs(seg_folder)
            if not os.path.exists(result_folder):
                os.makedirs(result_folder)
            if not os.path.exists(meta_folder):
                os.makedirs(meta_folder)
            path_con = os.path.join(result_folder, file_pre + '.fasta')
            path_reads = os.path.join(seg_folder, file_pre + '.fasta')
            path_meta = os.path.join(meta_folder, file_pre + '.meta')
            with open(path_reads, 'w+') as out_f, open(path_con,
                                                       'w+') as out_con:
                for indx, read in enumerate(bpreads):
                    out_f.write('>' + file_pre + str(indx) + '\n')
                    out_f.write(read + '\n')
                out_con.write(">{}\n{}\n".format(file_pre, c_bpread))
            with open(path_meta, 'w+') as out_meta:
                total_time = time.time() - start_time
                output_time = total_time - assembly_time
                assembly_time -= basecall_time
                basecall_time -= reading_time
                total_len = len(c_bpread)
                total_time = time.time() - start_time
                out_meta.write(
                    "# Reading Basecalling assembly output total rate(bp/s)\n")
                out_meta.write(
                    "%5.3f %5.3f %5.3f %5.3f %5.3f %5.3f\n" %
                    (reading_time, basecall_time, assembly_time, output_time,
                     total_time, total_len / total_time))
                out_meta.write(
                    "# read_len batch_size segment_len jump start_pos\n")
                out_meta.write("%d %d %d %d %d\n" %
                               (total_len, FLAGS.batch_size, FLAGS.segment_len,
                                FLAGS.jump, FLAGS.start))
                out_meta.write("# input_name model_name\n")
                out_meta.write("%s %s\n" % (FLAGS.input, FLAGS.model))
Example #4
def evaluation(signal_input, args):

	print("@ Loading U-net model ...")
	if args.model_param != "":
		params = load_modelParam(args.model_param)
	else:
		print("! Unable to load the model parameters, pls check!")
		exit()
	model_name = get_unet_model_name(params, args)
	unet_model = models.load_model("./experiment/model/weights/" + model_name+ ".h5", \
		custom_objects={'dice_coef_loss':dice_coef_loss, 'dice_coef':dice_coef, 'bce_dice_loss': bce_dice_loss, \
		'categorical_focal_loss_fixed':categorical_focal_loss(gamma=2., alpha=.25), \
		'ce_dice_loss': ce_dice_loss})

	if args.norm != "":
		print("@ Perform data normalization ... ")
		print("- Loading form %s" %(args.norm))

		with open(args.norm, "rb") as pickle_in:
			stat_dict = pickle.load(pickle_in)
		print("- Training Data statistics m=%f, s=%f" %(stat_dict["m"], stat_dict["s"]))

	#############################################################################
	print("@ loading signal files ...")
	if os.path.isdir(signal_input):
		file_list = os.listdir(signal_input)
		file_dir = signal_input
	else:
		file_list = [os.path.basename(signal_input)]
		file_dir = os.path.abspath(os.path.join(signal_input, os.path.pardir))

	## make the output subfolders
	if not os.path.exists(args.output):
		os.makedirs(args.output)
	if not os.path.exists(os.path.join(args.output, 'segments')):
		os.makedirs(os.path.join(args.output, 'segments'))
	if not os.path.exists(os.path.join(args.output, 'result')):
		os.makedirs(os.path.join(args.output, 'result'))
	if not os.path.exists(os.path.join(args.output, 'meta')):
		os.makedirs(os.path.join(args.output, 'meta'))

	# start processing files in the file list
	for name in file_list:
		start_time = time.time()
		if not name.endswith('.signal'):
			continue

		file_pre = os.path.splitext(name)[0]
		print("- Processing read %s" %(file_pre))
		input_path = os.path.join(file_dir, name)

		# read the signal file; take care with data normalization (applied below when args.norm is set)
		eval_data = read_data_for_eval(input_path, args.start, args.jump, args.segment_len)

		reads_n = eval_data.reads_n
		reading_time = time.time() - start_time

		reads = list()
		signals = np.empty((0, args.segment_len), dtype=np.float64)

		# doing the base-calling for the loaded signals
		for i in range(0, reads_n, args.batch_size):
			# get input signals
			X, seq_len, _, _, _, _ = eval_data.next_batch(args.batch_size, shuffle=False)

			# call different basecallers here.
			X = X.reshape(X.shape[0], X.shape[1], 1).astype("float32")
			# normalization of the data
			
			if args.norm != "":
				X = (X - stat_dict["m"])/(stat_dict["s"])

			output = unet_basecaller(unet_model, X)
			reads += output

		# assemble the read results
		print("Segment reads base calling finished, begin to assembly. %5.2f seconds" % (time.time() - start_time))
		basecall_time = time.time() - start_time

		# run the simple assembly method
		#print(reads)
		consensus = simple_assembly(reads)
		c_bpread = index2base_0(np.argmax(consensus, axis=0))

		assembly_time = time.time() - start_time
		print("Assembly finished, begin output. %5.2f seconds" % (time.time() - start_time))

		# write the output files to the output folder
		list_of_time = [start_time, reading_time, basecall_time, assembly_time]
		write_output(args, reads, "".join(c_bpread), list_of_time, file_pre)
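
`unet_basecaller` is not shown in Example #4. Because its output is fed directly into `simple_assembly`, it is assumed to return one base string per signal segment. The sketch below makes that assumption explicit by treating the model output as per-position probabilities over {blank, A, C, G, T} and collapsing repeats; the repository's actual decoding may differ.

import numpy as np

def unet_basecaller(unet_model, X, alphabet="NACGT"):
    # Assumption: the Keras model emits per-position class probabilities of
    # shape (batch, segment_len, 5) over {blank, A, C, G, T}.
    prob = unet_model.predict(X)
    calls = np.argmax(prob, axis=-1)
    reads = []
    for row in calls:
        bases, prev = [], -1
        for idx in row:
            # Collapse consecutive repeats and drop the blank class (index 0).
            if idx != prev and idx != 0:
                bases.append(alphabet[idx])
            prev = idx
        reads.append("".join(bases))
    return reads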