Example #1
 def __init__(self,configure):
     self.pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"])
     self.x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
     self.seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
     self.training = tf.placeholder(tf.bool)
     self.logits, self.ratio = chiron_model.inference(
                                     self.x, 
                                     self.seq_length, 
                                     training=self.training,
                                     full_sequence_len = FLAGS.segment_len,
                                     configure = configure)
     self.config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads,
                             inter_op_parallelism_threads=FLAGS.threads)
     self.config.gpu_options.allow_growth = True
     self.logits_index = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
     self.logits_fname = tf.placeholder(tf.string, shape=[FLAGS.batch_size])
     self.logits_queue = tf.FIFOQueue(
         capacity=1000,
         dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
         shapes=[self.logits.shape,self.logits_fname.shape,self.logits_index.shape, self.seq_length.shape]
     )
     self.logits_queue_size = self.logits_queue.size()
     self.logits_enqueue = self.logits_queue.enqueue((self.logits, self.logits_fname, self.logits_index, self.seq_length))
     self.logits_queue_close = self.logits_queue.close()
     ### Decoding logits into bases
     self.decode_predict_op, self.decode_prob_op, self.decoded_fname_op, self.decode_idx_op, self.decode_queue_size = decoding_queue(self.logits_queue)
     self.saver = tf.train.Saver(var_list=tf.trainable_variables()+tf.moving_average_variables())
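
A minimal standalone sketch of the enqueue/dequeue pattern this constructor wires up, assuming TensorFlow 1.x (tf.FIFOQueue and tf.placeholder are not available in TF 2.x); the shapes and feed values below are made up for illustration.

import numpy as np
import tensorflow as tf

batch_size, segment_len = 4, 8
x = tf.placeholder(tf.float32, shape=[batch_size, segment_len])
fname = tf.placeholder(tf.string, shape=[batch_size])

queue = tf.FIFOQueue(capacity=10,
                     dtypes=[tf.float32, tf.string],
                     shapes=[x.shape, fname.shape])
enqueue_op = queue.enqueue((x, fname))   # producer side
dequeue_op = queue.dequeue()             # consumer side
close_op = queue.close()

with tf.Session() as sess:
    sess.run(enqueue_op, feed_dict={x: np.zeros((batch_size, segment_len)),
                                    fname: [b'read_0'] * batch_size})
    signal_batch, names = sess.run(dequeue_op)
    sess.run(close_op)
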
Example #2
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits,
                                                         perm=[1, 0, 2]),
                                            seq_length,
                                            merge_repeated=False,
                                            beam_width=FLAGS.beam)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    sess = tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(config=config))
    if os.path.isdir(FLAGS.input):
        file_list = os.listdir(FLAGS.input)
        file_dir = FLAGS.input
    else:
        file_list = [os.path.basename(FLAGS.input)]
        file_dir = os.path.abspath(os.path.join(FLAGS.input, os.path.pardir))
    for file in file_list:
        file_path = os.path.join(file_dir, file)
        eval_data = read_data_for_eval(file_path,
                                       FLAGS.start,
                                       seg_length=FLAGS.segment_len,
                                       step=FLAGS.jump)
        reads_n = eval_data.reads_n
        for i in range(0, reads_n, FLAGS.batch_size):
            batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size,
                                                       shuffle=False,
                                                       sig_norm=False)
            batch_x = np.pad(batch_x,
                             ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                             mode='constant')
            seq_len = np.pad(seq_len, ((0, FLAGS.batch_size - len(seq_len))),
                             mode='constant')
            feed_dict = {
                x: batch_x,
                seq_length: np.round(seq_len / ratio).astype(np.int32),
                training: False,
            }
            logits_val, predict_val = sess.run([logits, predict],
                                               feed_dict=feed_dict)
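
The padding of the last partial batch above can be exercised on its own; a small NumPy sketch with a made-up batch size and segment length.

import numpy as np

batch_size, segment_len = 8, 16
# Suppose only 5 reads are left in the final batch.
batch_x = np.random.rand(5, segment_len)
seq_len = np.array([16, 16, 12, 16, 9])

# Zero-pad the rows up to the fixed batch size so the placeholder shapes match.
batch_x = np.pad(batch_x, ((0, batch_size - len(batch_x)), (0, 0)), mode='constant')
seq_len = np.pad(seq_len, (0, batch_size - len(seq_len)), mode='constant')

assert batch_x.shape == (batch_size, segment_len)
assert seq_len.shape == (batch_size,)
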
Example #3
def extract(root_folder,output_folder,raw_folder=None):
    global logger
    error_bars = multi_pbars([""]*5)
    run_record = Counter()
    batch_i = 1
    if not os.path.isdir(root_folder):
        raise IOError('Input directory not found.')
    batch_folder = make_batch_folder(output_folder,batch_i)
    for dir_n, _, file_list in tf.gfile.Walk(root_folder):
        for file_n in file_list:
            if file_n.endswith('fast5'):
                file_prefix = file_n.split('.')[0]
#            output_file = output_folder + os.path.splitext(file_n)[0]
                file_n = os.path.join(dir_n, file_n)
                state, (raw_data, raw_data_array), (offset, digitisation, range_s) = extract_file(file_n)
                run_record[state] += 1
                if run_record[SUCCEED_TAG] > batch_i * FLAGS.batch:
                    batch_i += 1
                    batch_folder = make_batch_folder(output_folder, batch_i)
                common_errors = run_record.most_common(FLAGS.n_errors)
                total_errors = sum(run_record.values())
                for i in np.arange(min(FLAGS.n_errors, len(common_errors))):
                    error_bars.update(i,
                                      title=common_errors[i][0],
                                      progress=common_errors[i][1],
                                      total=total_errors)
                error_bars.refresh()
                if state == SUCCEED_TAG:
                    if FLAGS.unit:
                        raw_data = reunit(raw_data, offset, digitisation, range_s)
                    with open(os.path.join(batch_folder, file_prefix + '.signal'), 'w+') as f:
                        f.write('\n'.join([str(x) for x in raw_data]))
                    with open(os.path.join(batch_folder, file_prefix + '.label'), 'w+') as f:
                        for label in raw_data_array:
                            f.write(' '.join([str(x) for x in label]))
                            f.write('\n')
                    logger.info("File %s transferred.   \n" % (file_n))
                else:
                    logger.error("Failed on file %s because of error %s.   \n" % (file_n, state))
Example #4
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    logits_fname = tf.placeholder(tf.string, shape=[FLAGS.batch_size])
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[
            logits.shape, logits_fname.shape, logits_index.shape,
            seq_length.shape
        ])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(
        logits_queue)
    saver = tf.train.Saver(var_list=tf.trainable_variables() +
                           tf.moving_average_variables())
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
            seq_len = np.asarray([])
            logits_idx = np.asarray([])
            logits_fn = np.asarray([])
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                i = 0
                while (eval_data.epochs_completed == 0):
                    current_batch, current_seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size - len(batch_x), shuffle=False)
                    current_n = len(current_batch)
                    batch_x = np.concatenate((batch_x, current_batch), axis=0)
                    seq_len = np.concatenate((seq_len, current_seq_len),
                                             axis=0)
                    logits_idx = np.concatenate((logits_idx, [i] * current_n),
                                                axis=0)
                    logits_fn = np.concatenate((logits_fn, [name] * current_n),
                                               axis=0)
                    i += current_n
                    if len(batch_x) < FLAGS.batch_size:
                        pbars.update(0, progress=i)
                        pbars.update_bar()
                        continue
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: logits_idx,
                        logits_fname: logits_fn,
                    }
                    # Training: set it to True for a temporary fix of the batch normalization problem: https://github.com/haotianteng/Chiron/commit/8fce3a3b4dac8e9027396bb8c9152b7b5af953ce
                    # TODO: change the training flag back to False after the new model has been trained.
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
                    seq_len = np.asarray([])
                    logits_idx = np.asarray([])
                    logits_fn = np.asarray([])
                    pbars.update(0, progress=i)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            ### All files have been processed.
            batch_n = len(batch_x)
            if batch_n > 0:
                pad_width = FLAGS.batch_size - batch_n
                batch_x = np.pad(batch_x, ((0, pad_width), (0, 0)),
                                 mode='wrap')
                seq_len = np.pad(seq_len, ((0, pad_width)), mode='wrap')
                logits_idx = np.pad(logits_idx, (0, pad_width),
                                    mode='constant',
                                    constant_values=-1)
                logits_fn = np.pad(logits_fn, (0, pad_width),
                                   mode='constant',
                                   constant_values='')
                sess.run(logits_enqueue,
                         feed_dict={
                             x: batch_x,
                             seq_length:
                             np.round(seq_len / ratio).astype(np.int32),
                             training: False,
                             logits_index: logits_idx,
                             logits_fname: logits_fn,
                         })
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn, args=())
        worker.setDaemon(True)
        worker.start()

        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            if 'total_count' not in val[name].keys():
                val[name]['total_count'] = 0
            if 'index_list' not in val[name].keys():
                val[name]['index_list'] = []
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                if val[name]['total_count'] == reads_n:
                    pbars.update(1, progress=val[name]['total_count'])
                    break
                decode_ops = [
                    decoded_fname_op, decode_idx_op, decode_predict_op,
                    decode_prob_op
                ]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = np.asarray(
                    [x.decode("UTF-8") for x in decoded_fname])
                ## It is difficult to integrate this into the TensorFlow graph, as the number of file names in a batch is variable.
                ## A Python for loop can't be used either, because eager execution is disabled due to the use of the queue.
                uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                    return_index=True)
                for fn_idx, fn in enumerate(uniq_fname):
                    i = uniq_fn_idx[fn_idx]
                    if fn != '':
                        occurrence = np.where(decoded_fname == fn)[0]
                        start = occurrence[0]
                        end = occurrence[-1] + 1
                        assert (len(occurrence) == end - start)
                        if 'total_count' not in val[fn].keys():
                            val[fn]['total_count'] = 0
                        if 'index_list' not in val[fn].keys():
                            val[fn]['index_list'] = []
                        val[fn]['total_count'] += (end - start)
                        val[fn]['index_list'].append(i)
                        sliced_sparse = slice_ctc_decoding_result(
                            predict_val, start, end)
                        val[fn][i] = (sliced_sparse,
                                      logits_prob[decoded_fname == fn])
                pbars.update(1, progress=val[name]['total_count'])
                pbars.update_bar()

            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in np.sort(val[name]['index_list']):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                             qs_list,
                                                             js_ratio,
                                                             kernal=kernal)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
    pbars.end()
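
In worker_fn the final partial batch is padded with mode='wrap' and the padded rows are marked with an index of -1 and an empty file name so the downstream decoder can discard them; a small NumPy sketch of that convention (all sizes and names are made up).

import numpy as np

batch_size = 6
batch_x = np.arange(8).reshape(4, 2).astype(np.float32)   # 4 real rows left over
logits_idx = np.array([0, 1, 2, 3])
logits_fn = np.array(['read_a'] * 4)

pad = batch_size - len(batch_x)
batch_x = np.pad(batch_x, ((0, pad), (0, 0)), mode='wrap')   # repeat real rows as filler
logits_idx = np.pad(logits_idx, (0, pad), mode='constant', constant_values=-1)
logits_fn = np.pad(logits_fn, (0, pad), mode='constant', constant_values='')

keep = logits_fn != ''        # consumer side: drop the wrapped filler rows
print(batch_x[keep].shape)    # (4, 2)
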
Example #5
def read_tfrecord(data_dir, 
                  tfrecord, 
                  h5py_file_path=None, 
                  seq_length=300, 
                  k_mer=1, 
                  max_segments_num=None,
                  skip_start = 10):
    ###Read from raw data
    count_bar = progress.multi_pbars("Extract tfrecords")
    if max_segments_num is None:
        max_segments_num = FLAGS.max_segments_number
        count_bar.update(0,progress = 0,total = max_segments_num)
    if h5py_file_path is None:
        h5py_file_path = tempfile.mkdtemp() + '/temp_record.hdf5'
    else:
        try:
            os.remove(os.path.abspath(h5py_file_path))
        except:
            pass
        if not os.path.isdir(os.path.dirname(os.path.abspath(h5py_file_path))):
            os.mkdir(os.path.dirname(os.path.abspath(h5py_file_path)))
    with h5py.File(h5py_file_path, "a") as hdf5_record:
        event_h = hdf5_record.create_dataset('event/record', dtype='float32', shape=(0, seq_length),
                                             maxshape=(None, seq_length))
        event_length_h = hdf5_record.create_dataset('event/length', dtype='int32', shape=(0,), maxshape=(None,),
                                                    chunks=True)
        label_h = hdf5_record.create_dataset('label/record', dtype='int32', shape=(0, 0), maxshape=(None, seq_length))
        label_length_h = hdf5_record.create_dataset('label/length', dtype='int32', shape=(0,), maxshape=(None,))
        event = biglist(data_handle=event_h, max_len=FLAGS.MAXLEN)
        event_length = biglist(data_handle=event_length_h, max_len=FLAGS.MAXLEN)
        label = biglist(data_handle=label_h, max_len=FLAGS.MAXLEN)
        label_length = biglist(data_handle=label_length_h, max_len=FLAGS.MAXLEN)
        count = 0
        file_count = 0

        tfrecords_filename = data_dir + tfrecord
        record_iterator = tf.python_io.tf_record_iterator(path=tfrecords_filename)

        for string_record in record_iterator:
            
            example = tf.train.Example()
            example.ParseFromString(string_record)
            
            raw_data_string = (example.features.feature['raw_data']
                                          .bytes_list
                                          .value[0])
            features_string = (example.features.feature['features']
                                        .bytes_list
                                        .value[0])
            fn_string = (example.features.feature['fname'].bytes_list.value[0])

            raw_data = np.frombuffer(raw_data_string, dtype=SIGNAL_DTYPE)
            
            features_data = np.frombuffer(features_string, dtype='S8')
            # grouping the whole array into sub-array with size = 3
            group_size = 3
            features_data = [features_data[n:n+group_size] for n in range(0, len(features_data), group_size)]
            f_signal = read_signal_tfrecord(raw_data)

            if len(f_signal) == 0:
                continue
            #try:
            f_label = read_label_tfrecord(features_data, skip_start=skip_start, window_n=int((k_mer - 1) / 2))
            #except:
            #    sys.stdout.write("Read the label fail.Skipped.")
            #    continue
            tmp_event, tmp_event_length, tmp_label, tmp_label_length = read_raw(f_signal, f_label, seq_length)
            event += tmp_event
            event_length += tmp_event_length
            label += tmp_label
            label_length += tmp_label_length
            del tmp_event
            del tmp_event_length
            del tmp_label
            del tmp_label_length
            count = len(event)
            if file_count % 10 == 0:
                if max_segments_num is not None:
                    count_bar.update(0,progress = count,total = max_segments_num)
                    count_bar.update_bar()
                    if len(event) > max_segments_num:
                        event.resize(max_segments_num)
                        label.resize(max_segments_num)
                        event_length.resize(max_segments_num)

                        label_length.resize(max_segments_num)
                        break
                else:
                    count_bar.update(0,progress = count,total = count)
                    count_bar.update_bar()
            file_count += 1

        if event.cache:
            train = read_cache_dataset(h5py_file_path)
        else:
            train = DataSet(event=event, event_length=event_length, label=label, label_length=label_length)
        count_bar.end()
        return train
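
A minimal round-trip sketch of the record parsing done above, assuming TensorFlow 1.x (tf.python_io) and a 'raw_data' feature stored as raw float32 bytes; the path, file name, and dtype stand in for the real FLAGS values and SIGNAL_DTYPE.

import numpy as np
import tensorflow as tf

record_path = '/tmp/example.tfrecords'   # hypothetical path
signal = np.arange(5, dtype=np.float32)

# Write one record with the raw signal serialized as bytes.
with tf.python_io.TFRecordWriter(record_path) as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        'raw_data': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[signal.tobytes()])),
        'fname': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[b'read_0.fast5'])),
    }))
    writer.write(example.SerializeToString())

# Read it back the same way the loop above does.
for string_record in tf.python_io.tf_record_iterator(path=record_path):
    example = tf.train.Example()
    example.ParseFromString(string_record)
    raw = np.frombuffer(
        example.features.feature['raw_data'].bytes_list.value[0],
        dtype=np.float32)
    print(raw)   # [0. 1. 2. 3. 4.]
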
Example #6
def read_raw_data_sets(data_dir,
                       h5py_file_path=None,
                       seq_length=300,
                       k_mer=1,
                       max_segments_num=FLAGS.max_segments_number,
                       skip_start=10):
    ###Read from raw data
    count_bar = progress.multi_pbars("Extract tfrecords")
    if max_segments_num is None:
        max_segments_num = FLAGS.max_segments_number
        count_bar.update(0, progress=0, total=max_segments_num)
    if h5py_file_path is None:
        h5py_file_path = tempfile.mkdtemp() + '/temp_record.hdf5'
    else:
        try:
            os.remove(os.path.abspath(h5py_file_path))
        except:
            pass
        if not os.path.isdir(os.path.dirname(os.path.abspath(h5py_file_path))):
            os.mkdir(os.path.dirname(os.path.abspath(h5py_file_path)))
    with h5py.File(h5py_file_path, "a") as hdf5_record:
        event_h = hdf5_record.create_dataset('event/record',
                                             dtype='float32',
                                             shape=(0, seq_length),
                                             maxshape=(None, seq_length))
        event_length_h = hdf5_record.create_dataset('event/length',
                                                    dtype='int32',
                                                    shape=(0, ),
                                                    maxshape=(None, ),
                                                    chunks=True)
        label_h = hdf5_record.create_dataset('label/record',
                                             dtype='int32',
                                             shape=(0, 0),
                                             maxshape=(None, seq_length))
        label_length_h = hdf5_record.create_dataset('label/length',
                                                    dtype='int32',
                                                    shape=(0, ),
                                                    maxshape=(None, ))
        event = biglist(data_handle=event_h, max_len=FLAGS.MAXLEN)
        event_length = biglist(data_handle=event_length_h,
                               max_len=FLAGS.MAXLEN)
        label = biglist(data_handle=label_h, max_len=FLAGS.MAXLEN)
        label_length = biglist(data_handle=label_length_h,
                               max_len=FLAGS.MAXLEN)
        count = 0
        file_count = 0
        for root, dirs, files in os.walk(data_dir, topdown=False):
            for name in files:
                if name.endswith(".signal"):
                    file_pre = os.path.splitext(name)[0]
                    signal_f = os.path.join(root, name)
                    f_signal = read_signal(signal_f, normalize=FLAGS.sig_norm)
                    label_f = os.path.join(root, file_pre + '.label')
                    if len(f_signal) == 0:
                        continue
                    try:
                        f_label = read_label(label_f,
                                             skip_start=skip_start,
                                             window_n=int((k_mer - 1) / 2))
                    except:
                        sys.stdout.write("Read the label %s fail.Skipped." %
                                         (name))
                        continue
                    try:
                        tmp_event, tmp_event_length, tmp_label, tmp_label_length = read_raw(
                            f_signal, f_label, seq_length)
                    except Exception as e:
                        print(
                            "Extracting label from %s failed: label position exceeds max signal length."
                            % (label_f))
                        raise e
                    event += tmp_event
                    event_length += tmp_event_length
                    label += tmp_label
                    label_length += tmp_label_length
                    del tmp_event
                    del tmp_event_length
                    del tmp_label
                    del tmp_label_length
                    count = len(event)
                    if file_count % 10 == 0:
                        if max_segments_num is not None:
                            count_bar.update(0,
                                             progress=count,
                                             total=max_segments_num)
                            count_bar.update_bar()
                            if len(event) > max_segments_num:
                                event.resize(max_segments_num)
                                label.resize(max_segments_num)
                                event_length.resize(max_segments_num)

                                label_length.resize(max_segments_num)
                                break
                        else:
                            count_bar.update(0, progress=count, total=count)
                            count_bar.update_bar()
                    file_count += 1
        if event.cache:
            event.save_rest()
            event_length.save_rest()
            label.save_rest()
            label_length.save_rest()
            train = read_cache_dataset(h5py_file_path)
        else:
            event.save()
            event_length.save()
            label.save()
            label_length.save()
            train = read_cache_dataset(h5py_file_path)
        count_bar.end()
    return train
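
The directory walk above pairs every .signal file with a .label file sharing the same prefix; a standalone sketch of just that pairing logic.

import os

def list_signal_label_pairs(data_dir):
    """Yield (signal_path, label_path) pairs found under data_dir."""
    for root, _, files in os.walk(data_dir, topdown=False):
        for name in files:
            if name.endswith('.signal'):
                prefix = os.path.splitext(name)[0]
                yield (os.path.join(root, name),
                       os.path.join(root, prefix + '.label'))
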
Example #7
def evaluation():
    pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model,'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(
                                    x, 
                                    seq_length, 
                                    training=training,
                                    full_sequence_len = FLAGS.segment_len,
                                    configure = model_configure)
    config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape, logits_index.shape, seq_length.shape]
    )
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2,total = file_n)
        pbars.update(3,total = file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))
        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0,total = reads_n,progress = 0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant')
                    seq_len = np.pad(
                        seq_len, ((0, FLAGS.batch_size - len(seq_len))), mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: seq_len,
                        training: False,
                        logits_index:i,
                        logits_fname: name,
                    }
Example #8
def evaluation():
    pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model,'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(
                                    x, 
                                    seq_length, 
                                    training=training,
                                    full_sequence_len = FLAGS.segment_len,
                                    configure = model_configure)
    config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape,logits_fname.shape,logits_index.shape, seq_length.shape]
    )
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2,total = file_n)
        pbars.update(3,total = file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))
        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0,total = reads_n,progress = 0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant')
                    seq_len = np.pad(
                        seq_len, ((0, FLAGS.batch_size - len(seq_len))), mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len/ratio).astype(np.int32),
                        training: False,
                        logits_index:i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue,feed_dict=feed_dict)
                    pbars.update(0,progress=i+FLAGS.batch_size)
                    pbars.update_bar()
                pbars.update(2,progress = f_i+1)
                pbars.update_bar()
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn,args=() )
        worker.setDaemon(True)
        worker.start()

        val = defaultdict(dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            else:
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1,total = reads_n,progress = 0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()

            N = len(range(0, reads_n, FLAGS.batch_size))
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                decode_ops = [decoded_fname_op, decode_idx_op, decode_predict_op, decode_prob_op]
                decoded_fname, i, predict_val, logits_prob = sess.run(decode_ops, feed_dict={training: False})
                decoded_fname = decoded_fname.decode("UTF-8")
                val[decoded_fname][i] = (predict_val, logits_prob)               
                pbars.update(1,progress = len(val[name])*FLAGS.batch_size)
                pbars.update_bar()
                if len(val[name]) == N:
                    break

            pbars.update(3,progress = f_i+1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in range(0, reads_n, FLAGS.batch_size):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory

            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump/FLAGS.segment_len
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list,js_ratio)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads,js_ratio)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [start_time, reading_time,
                            basecall_time, assembly_time]
            write_output(bpreads, c_bpread, list_of_time, file_pre, concise=FLAGS.concise, suffix=FLAGS.extension,
                         q_score=qs_string,global_setting=FLAGS)
    pbars.end()
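
sparse2dense and index2base are project helpers that are not shown on this page; a rough standalone illustration of densifying the beam-search output with plain TensorFlow 1.13+ ops (not the project's implementation).

import numpy as np
import tensorflow as tf

num_classes, time_steps, batch = 5, 12, 2
logits = tf.placeholder(tf.float32, shape=[batch, time_steps, num_classes])
seq_length = tf.placeholder(tf.int32, shape=[batch])

# ctc_beam_search_decoder expects time-major input, hence the transpose.
decoded, log_prob = tf.nn.ctc_beam_search_decoder(
    tf.transpose(logits, perm=[1, 0, 2]),
    seq_length,
    merge_repeated=False,
    beam_width=10)
dense = tf.sparse.to_dense(decoded[0], default_value=-1)

with tf.Session() as sess:
    out = sess.run(dense, feed_dict={
        logits: np.random.randn(batch, time_steps, num_classes),
        seq_length: [time_steps] * batch})
    print(out.shape)   # (batch, longest decoded label length)
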
Example #9
def do_inference():
    """Tests PredictionService with concurrent requests.    
    Raises:
    IOError: An error occurred processing test data set.
    """
    if FLAGS.mode == 'dna':
        CONF = DNA_CONF()
    elif FLAGS.mode == 'rna':
        CONF = RNA_CONF()
    else:
        raise ValueError("Mode has to be either rna or dna.")
    make_dirs(FLAGS.output)
    FLAGS.segment_len = CONF.SEGMENT_LEN
    FLAGS.jump = CONF.JUMP
    FLAGS.start = CONF.START
    pbars = multi_pbars(["Request Submit:", "Request finished"])
    channel = grpc.insecure_channel(FLAGS.server)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'chiron'
    request.model_spec.signature_name = 'predicted_sequences'
    collector = _Result_Collection(concurrency=FLAGS.concurrency)
    file_list = gen_file_list(FLAGS.input)
    batch_iterator = data_iterator(file_list)

    def submit_fn():
        for batch_x, seq_len, i, f, N, reads_n in batch_iterator:
            seq_len = np.reshape(seq_len, (seq_len.shape[0], 1))
            combined_input = np.concatenate((batch_x, seq_len),
                                            axis=1).astype(np.float32)
            request.inputs['combined_inputs'].CopyFrom(
                tf.contrib.util.make_tensor_proto(
                    combined_input,
                    shape=[FLAGS.batch_size, CONF.SEGMENT_LEN + 1]))
            collector.throttle()
            result_future = stub.Predict.future(request, 100.0)  # 100-second timeout
            result_future.add_done_callback(_post_process(collector, i, f, N))
            pbars.update(0, total=reads_n, progress=(i + 1) * FLAGS.batch_size)
            pbars.update_bar()

    submiter = threading.Thread(target=submit_fn, args=())
    submiter.setDaemon(True)
    submiter.start()
    pbars.update(1, total=len(file_list))
    pbars.update_bar()
    while not collector.all_done():
        if len(collector._done) > 0:
            qs_string = None
            f_p = collector._done[0]
            reads, probs = collector.pop_out(f_p)
            bpreads = [index2base(read) for read in reads]
            consensus, qs_consensus = simple_assembly_qs(bpreads, probs)
            qs_string = qs(consensus, qs_consensus)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            file_pre = os.path.basename(os.path.splitext(f_p)[0])
            write_output(bpreads,
                         c_bpread, [np.NaN] * 4,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
            pbars.update(1, progress=pbars.progress[1] + 1)
            pbars.update_bar()
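
A minimal sketch of the asynchronous TensorFlow Serving request built above, assuming the tensorflow-serving-api client package; the server address, batch shape, and the printing callback are made up, while the model name, signature, and input key are taken from the snippet.

import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

channel = grpc.insecure_channel('localhost:8500')   # hypothetical server address
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'chiron'
request.model_spec.signature_name = 'predicted_sequences'

batch = np.zeros((4, 301), dtype=np.float32)   # batch_x with seq_len appended as a column
request.inputs['combined_inputs'].CopyFrom(
    tf.contrib.util.make_tensor_proto(batch, shape=list(batch.shape)))

# Asynchronous call; the callback fires once the server responds.
result_future = stub.Predict.future(request, 100.0)   # 100-second timeout
result_future.add_done_callback(lambda fut: print(fut.result().outputs.keys()))
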
Example #10
        default=30,
        help=
        "Beam width used in beam search decoder, default is 0, in which a greedy decoder is used. Recommend width:100, Large beam width give better decoding result but require longer decoding time."
    )
    parser.add_argument(
        '--concise',
        action='store_true',
        help=
        "Concisely output the result, the meta and segments files will not be output."
    )
    parser.add_argument('--mode',
                        default='dna',
                        help="Output mode, can be chosen from dna or rna.")
    FLAGS = parser.parse_args(sys.argv[1:])

    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    prorbs = tf.nn.softmax(logits)
    predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits,
                                                         perm=[1, 0, 2]),
                                            seq_length,
                                            merge_repeated=False,