Example #1
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits,
                                                         perm=[1, 0, 2]),
                                            seq_length,
                                            merge_repeated=False,
                                            beam_width=FLAGS.beam)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    sess = tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(config=config))
    if os.path.isdir(FLAGS.input):
        file_list = os.listdir(FLAGS.input)
        file_dir = FLAGS.input
    else:
        file_list = [os.path.basename(FLAGS.input)]
        file_dir = os.path.abspath(os.path.join(FLAGS.input, os.path.pardir))
    for file in file_list:
        file_path = os.path.join(file_dir, file)
        eval_data = read_data_for_eval(file_path,
                                       FLAGS.start,
                                       seg_length=FLAGS.segment_len,
                                       step=FLAGS.jump)
        reads_n = eval_data.reads_n
        for i in range(0, reads_n, FLAGS.batch_size):
            batch_x, seq_len, _ = eval_data.next_batch(FLAGS.batch_size,
                                                       shuffle=False,
                                                       sig_norm=False)
            batch_x = np.pad(batch_x,
                             ((0, FLAGS.batch_size - len(batch_x)), (0, 0)),
                             mode='constant')
            seq_len = np.pad(seq_len, ((0, FLAGS.batch_size - len(seq_len))),
                             mode='constant')
            feed_dict = {
                x: batch_x,
                seq_length: np.round(seq_len / ratio).astype(np.int32),
                training: False,
            }
            logits_val, predict_val = sess.run([logits, predict],
                                               feed_dict=feed_dict)
Example #2
def output_list(x, seq_length):
    training = tf.constant(False, dtype=tf.bool, name='Training')
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = read_config(config_path)
    logits, ratio = inference(x,
                              seq_length,
                              training=training,
                              full_sequence_len=FLAGS.segment_len,
                              configure=model_configure)
    ratio = tf.constant(ratio, dtype=tf.float32, shape=[])
    seq_length_r = tf.cast(
        tf.round(tf.cast(seq_length, dtype=tf.float32) / ratio), tf.int32)
    prob_logits = path_prob(logits)
    predict, log_prob = tf.nn.ctc_beam_search_decoder(
        tf.transpose(logits, perm=[1, 0, 2]),
        seq_length_r,
        merge_repeated=True,
        beam_width=FLAGS.beam_width)
    return predict[0], logits, prob_logits, log_prob
Example #3
def input_output_list():
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    model_configure = chiron_model.read_config(FLAGS.config_path)
    logits, _ = inference(x,
                          seq_length,
                          training=training,
                          full_sequence_len=FLAGS.segment_len,
                          configure=model_configure)
    predict = tf.nn.ctc_greedy_decoder(tf.transpose(logits, perm=[1, 0, 2]),
                                       seq_length,
                                       merge_repeated=True)
    input_dict = {'x': x, 'seq_length': seq_length, 'training': training}
    output_dict = {
        'decoded_indices': predict[0][0].indices,
        'decoded_values': predict[0][0].values,
        'neg_sum_logits': predict[1]
    }
    return input_dict, output_dict
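
The input and output dictionaries returned by input_output_list() are shaped like a serving signature, mapping named placeholders to the sparse CTC decoding results. As a rough sketch only (the export directory, the checkpoint restore, and the use of tf.saved_model.simple_save are assumptions, not taken from the project), they could be wired into a TF 1.x SavedModel export like this:

# Hypothetical export sketch; the export path and checkpoint restore are assumptions.
import os
import tensorflow as tf

input_dict, output_dict = input_output_list()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
    # simple_save writes a SavedModel with a default serving signature
    # built from the two dictionaries of named tensors.
    tf.saved_model.simple_save(sess,
                               os.path.join(FLAGS.output, 'saved_model'),
                               inputs=input_dict,
                               outputs=output_dict)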
Example #4
def evaluation():
    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    logits_fname = tf.placeholder(tf.string, shape=[FLAGS.batch_size])
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[
            logits.shape, logits_fname.shape, logits_index.shape,
            seq_length.shape
        ])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(
        logits_queue)
    saver = tf.train.Saver(var_list=tf.trainable_variables() +
                           tf.moving_average_variables())
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
            seq_len = np.asarray([])
            logits_idx = np.asarray([])
            logits_fn = np.asarray([])
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                i = 0
                while (eval_data.epochs_completed == 0):
                    current_batch, current_seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size - len(batch_x), shuffle=False)
                    current_n = len(current_batch)
                    batch_x = np.concatenate((batch_x, current_batch), axis=0)
                    seq_len = np.concatenate((seq_len, current_seq_len),
                                             axis=0)
                    logits_idx = np.concatenate((logits_idx, [i] * current_n),
                                                axis=0)
                    logits_fn = np.concatenate((logits_fn, [name] * current_n),
                                               axis=0)
                    i += current_n
                    if len(batch_x) < FLAGS.batch_size:
                        pbars.update(0, progress=i)
                        pbars.update_bar()
                        continue
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: logits_idx,
                        logits_fname: logits_fn,
                    }
                    #Training: set it to True as a temporary fix for the batch normalization problem: https://github.com/haotianteng/Chiron/commit/8fce3a3b4dac8e9027396bb8c9152b7b5af953ce
                    #TODO: change the training flag back to False after the new model has been trained.
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    batch_x = np.asarray([[]]).reshape(0, FLAGS.segment_len)
                    seq_len = np.asarray([])
                    logits_idx = np.asarray([])
                    logits_fn = np.asarray([])
                    pbars.update(0, progress=i)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
            ### All files have been processed.
            batch_n = len(batch_x)
            if batch_n > 0:
                pad_width = FLAGS.batch_size - batch_n
                batch_x = np.pad(batch_x, ((0, pad_width), (0, 0)),
                                 mode='wrap')
                seq_len = np.pad(seq_len, ((0, pad_width)), mode='wrap')
                logits_idx = np.pad(logits_idx, (0, pad_width),
                                    mode='constant',
                                    constant_values=-1)
                logits_fn = np.pad(logits_fn, (0, pad_width),
                                   mode='constant',
                                   constant_values='')
                sess.run(logits_enqueue,
                         feed_dict={
                             x: batch_x,
                             seq_length:
                             np.round(seq_len / ratio).astype(np.int32),
                             training: False,
                             logits_index: logits_idx,
                             logits_fname: logits_fn,
                         })
            sess.run(logits_queue_close)


#

        worker = threading.Thread(target=worker_fn, args=())
        worker.setDaemon(True)
        worker.start()

        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()
            if 'total_count' not in val[name].keys():
                val[name]['total_count'] = 0
            if 'index_list' not in val[name].keys():
                val[name]['index_list'] = []
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                if val[name]['total_count'] == reads_n:
                    pbars.update(1, progress=val[name]['total_count'])
                    break
                decode_ops = [
                    decoded_fname_op, decode_idx_op, decode_predict_op,
                    decode_prob_op
                ]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = np.asarray(
                    [x.decode("UTF-8") for x in decoded_fname])
                ##Hard to integrate this into the TensorFlow graph, as the number of file names in a batch is variable.
                ##A Python for loop can't be used in-graph either, since eager execution is disabled due to the use of the queue (see the slicing sketch after this example).
                uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                    return_index=True)
                for fn_idx, fn in enumerate(uniq_fname):
                    i = uniq_fn_idx[fn_idx]
                    if fn != '':
                        occurance = np.where(decoded_fname == fn)[0]
                        start = occurance[0]
                        end = occurance[-1] + 1
                        assert (len(occurance) == end - start)
                        if 'total_count' not in val[fn].keys():
                            val[fn]['total_count'] = 0
                        if 'index_list' not in val[fn].keys():
                            val[fn]['index_list'] = []
                        val[fn]['total_count'] += (end - start)
                        val[fn]['index_list'].append(i)
                        sliced_sparse = slice_ctc_decoding_result(
                            predict_val, start, end)
                        val[fn][i] = (sliced_sparse,
                                      logits_prob[decoded_fname == fn])
                pbars.update(1, progress=val[name]['total_count'])
                pbars.update_bar()

            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in np.sort(val[name]['index_list']):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                             qs_list,
                                                             js_ratio,
                                                             kernal=kernal)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string,
                         global_setting=FLAGS)
    pbars.end()
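
The in-line comments in this example note that a decoded batch has to be split by file name in NumPy, because the number of file names per batch is variable and eager execution is unavailable while the queue is in use. A minimal sketch of a batch-row slicer in the spirit of slice_ctc_decoding_result (hypothetical helper; the project's actual function and the exact structure of predict_val may differ):

import numpy as np
import tensorflow as tf

def slice_sparse_rows(sparse_value, start, end):
    """Keep batch rows [start, end) of a CTC SparseTensorValue, re-based to row 0."""
    keep = np.logical_and(sparse_value.indices[:, 0] >= start,
                          sparse_value.indices[:, 0] < end)
    indices = sparse_value.indices[keep].copy()
    indices[:, 0] -= start  # shift the kept rows so the slice starts at row 0
    values = sparse_value.values[keep]
    dense_shape = np.array([end - start, sparse_value.dense_shape[1]])
    return tf.SparseTensorValue(indices=indices, values=values, dense_shape=dense_shape)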
Example #5
def chiron_train():
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name,
                                  'model.json')

    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = FLAGS.configure

    config = model.read_config(config_file)

    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    FLAGS.sequence_len,
                                    configure=config)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(FLAGS.step_rate,
                          FLAGS.max_steps,
                          global_step=global_step,
                          opt_name=config['opt_method'])
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    if FLAGS.retrain == False:
        sess.run(init)
        print("Model init finished, begin loading data. \n")
    else:
        saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model loaded finished, begin loading data. \n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)

    train_ds, valid_ds = generate_train_valid_datasets()
    start = time.time()

    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            seq_length: seq_len / ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
            "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f validate_edit_distance: %5.3f Elapsed Time/step: %5.3f" \
            % (i, FLAGS.max_steps, train_ds.epochs_completed,
               train_ds.index_in_epoch, loss_val, error_val,
               (end - start) / (i + 1)))
            saver.save(sess,
                       FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    saver.save(sess,
               FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
               global_step=global_step_val)
Example #6
def train():
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name,
                                  'model.json')
    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = FLAGS.configure
    config = model.read_config(config_file)
    print("Begin training using following setting:")
    with open(os.path.join(FLAGS.log_dir, FLAGS.model_name, 'train_config'),
              'w+') as log_f:
        for pro in dir(FLAGS):
            if not pro.startswith('_'):
                print("%s:%s" % (pro, getattr(FLAGS, pro)))
                log_f.write("%s:%s\n" % (pro, getattr(FLAGS, pro)))
    net = compile_train_graph(config, FLAGS)
    sess = tf.Session(
        config=tf.ConfigProto(inter_op_parallelism_threads=FLAGS.threads,
                              intra_op_parallelism_threads=FLAGS.threads,
                              allow_soft_placement=True))
    if FLAGS.retrain == False:
        sess.run(net.init)
        print("Model init finished, begin loading data. \n")
    else:
        net.saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model loaded finished, begin loading data. \n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)
    train_ds, valid_ds = generate_train_valid_datasets(
        initial_offset=DEFAULT_OFFSET)
    start = time.time()
    resample_n = 0
    for i in range(FLAGS.max_steps):
        if FLAGS.resample_after_epoch == 0:
            pass
        elif train_ds.epochs_completed >= FLAGS.resample_after_epoch:
            train_ds, valid_ds = generate_train_valid_datasets(
                initial_offset=resample_n * FLAGS.offset_increment +
                DEFAULT_OFFSET)
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            net.x: batch_x,
            net.seq_length: seq_len / net.ratio,
            net.y_indexs: indxs,
            net.y_values: values,
            net.y_shape: shape,
            net.training: True
        }
        loss_val, _ = sess.run([net.ctc_loss, net.step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, net.global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                net.x: valid_x,
                net.seq_length: valid_len / net.ratio,
                net.y_indexs: indxs,
                net.y_values: values,
                net.y_shape: shape,
                net.training: True
            }
            error_val = sess.run(net.error, feed_dict=feed_dict)
            #            x_val,errors_val,y_predict,y = sess.run([x,errors,y_,y],feed_dict = feed_dict)
            #            predict_seq,_ = sparse2dense([y_predict,0])
            #            true_seq,_ = sparse2dense([[y],0])
            end = time.time()
            print(
            "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f validate_edit_distance: %5.3f Elapsed Time/step: %5.3f" \
            % (i, FLAGS.max_steps, train_ds.epochs_completed,
               train_ds.index_in_epoch, loss_val, error_val,
               (end - start) / (i + 1)))
            net.saver.save(sess,
                           FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                           global_step=global_step_val)
            summary_str = sess.run(net.summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, net.global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    net.saver.save(sess,
                   FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
                   global_step=global_step_val)
Example #7
0
def evaluation():
    logger = logging.getLogger(__name__)
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    tqdm.monitor_interval = 0
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[
            logits.shape, logits_fname.shape, logits_index.shape,
            seq_length.shape
        ])
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue(
        (logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(
        logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))

        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))

        def worker_fn():
            for name in tqdm(file_list, desc="Logits inferencing.",
                             position=0):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                for i in trange(0,
                                reads_n,
                                FLAGS.batch_size,
                                desc="Logits inferencing",
                                position=1):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(batch_x,
                                     ((0, FLAGS.batch_size - len(batch_x)),
                                      (0, 0)),
                                     mode='constant')
                    seq_len = np.pad(seq_len,
                                     ((0, FLAGS.batch_size - len(seq_len))),
                                     mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: seq_len,
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
            sess.run(logits_queue_close)

        def run_listener(write_lock):
            # This function works around the error raised when tqdm is used inside a thread
            # https://github.com/tqdm/tqdm/issues/323
            tqdm.set_lock(write_lock)
            worker_fn()

        write_lock = threading.Lock()
        worker = threading.Thread(target=run_listener, args=(write_lock, ))
        worker.setDaemon(True)
        worker.start()

        val = defaultdict(
            dict)  # We could read vals out of order, that's why it's a dict
        for name in tqdm(file_list, desc="CTC decoding.", position=2):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            if FLAGS.mode == 'rna':
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump,
                                               reverse=True)
            else:
                eval_data = read_data_for_eval(input_path,
                                               FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
            reads_n = eval_data.reads_n
            reading_time = time.time() - start_time
            reads = list()

            N = len(range(0, reads_n, FLAGS.batch_size))
            with tqdm(total=reads_n, desc="ctc decoding", position=3) as pbar:
                while True:
                    options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    l_sz, d_sz = sess.run(
                        [logits_queue_size, decode_queue_size],
                        options=options,
                        run_metadata=run_metadata)
                    pbar.set_postfix(logits_q=l_sz,
                                     decoded_q=d_sz,
                                     refresh=False)
                    decode_ops = [
                        decoded_fname_op, decode_idx_op, decode_predict_op,
                        decode_prob_op
                    ]
                    decoded_fname, i, predict_val, logits_prob = sess.run(
                        decode_ops,
                        feed_dict={training: False},
                        options=options,
                        run_metadata=run_metadata)
                    decoded_fname = decoded_fname.decode("UTF-8")
                    val[decoded_fname][i] = (predict_val, logits_prob)

                    fetched_timeline = timeline.Timeline(
                        run_metadata.step_stats)
                    chrome_trace = fetched_timeline.generate_chrome_trace_format(
                    )
                    with open('timeline_02_step_%d.json' % i, 'w') as f:
                        f.write(chrome_trace)

                    if decoded_fname == name:
                        decoded_cnt = len(val[name])
                        pbar.update(
                            min(reads_n, decoded_cnt * FLAGS.batch_size) -
                            (decoded_cnt - 1) * FLAGS.batch_size)
                        if decoded_cnt == N:
                            break

            qs_list = np.empty((0, 1), dtype=np.float)
            qs_string = None
            for i in trange(0,
                            reads_n,
                            FLAGS.batch_size,
                            desc="Output",
                            position=4):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory

            # tqdm.write("[%s] Segment reads base calling finished, begin to assembly. %5.2f seconds" % (name, time.time() - start_time))
            basecall_time = time.time() - start_time
            bpreads = [index2base(read) for read in reads]
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(bpreads, qs_list)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            np.set_printoptions(threshold=np.nan)
            assembly_time = time.time() - start_time
            # tqdm.write("[%s] Assembly finished, begin output. %5.2f seconds" % (name, time.time() - start_time))
            list_of_time = [
                start_time, reading_time, basecall_time, assembly_time
            ]
            write_output(bpreads,
                         c_bpread,
                         list_of_time,
                         file_pre,
                         concise=FLAGS.concise,
                         suffix=FLAGS.extension,
                         q_score=qs_string)
Example #8
def evaluation():
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    net = build_eval_graph(model_configure)
    val = defaultdict(
        dict)  # We could read vals out of order, that's why it's a dict
    for f_i, name in enumerate(net.file_list):
        start_time = time.time()
        if (not name.endswith('.signal')) and (not name.endswith('.fast5')):
            continue
        file_pre = os.path.splitext(name)[0]
        input_path = os.path.join(net.file_dir, name)
    ###Other modes (like methylation) may use a different read method.
        eval_data = read_data_for_eval(input_path,
                                       FLAGS.start,
                                       seg_length=FLAGS.segment_len,
                                       step=FLAGS.jump,
                                       reverse_fast5=FLAGS.reverse_fast5)
        reads_n = eval_data.reads_n
        net.pbars.update(1, total=reads_n, progress=0)
        net.pbars.update_bar()
        reading_time = time.time() - start_time
        reads = list()
        if 'total_count' not in val[name].keys():
            val[name]['total_count'] = 0
        if 'index_list' not in val[name].keys():
            val[name]['index_list'] = []
        while True:
            l_sz, d_sz = net.sess.run(
                [net.logits_queue_size, net.decode_queue_size])
            if val[name]['total_count'] == reads_n:
                net.pbars.update(1, progress=val[name]['total_count'])
                break
            decode_ops = [
                net.decoded_fname_op, net.decode_idx_op, net.decode_predict_op,
                net.decode_prob_op
            ]
            decoded_fname, i, predict_val, logits_prob = net.sess.run(
                decode_ops, feed_dict={net.training: False})
            decoded_fname = np.asarray(
                [x.decode("UTF-8") for x in decoded_fname])
            ##Hard to integrate this into the TensorFlow graph, as the number of file names in a batch is uncertain.
            ##A Python for loop can't be used in-graph either, since eager execution is disabled due to the use of the queue.
            uniq_fname, uniq_fn_idx = np.unique(decoded_fname,
                                                return_index=True)
            for fn_idx, fn in enumerate(uniq_fname):
                i = uniq_fn_idx[fn_idx]
                if fn != '':
                    occurance = np.where(decoded_fname == fn)[0]
                    start = occurance[0]
                    end = occurance[-1] + 1
                    assert (len(occurance) == end - start)
                    if 'total_count' not in val[fn].keys():
                        val[fn]['total_count'] = 0
                    if 'index_list' not in val[fn].keys():
                        val[fn]['index_list'] = []
                    val[fn]['total_count'] += (end - start)
                    val[fn]['index_list'].append(i)
                    sliced_sparse = slice_ctc_decoding_result(
                        predict_val, start, end)
                    val[fn][i] = (sliced_sparse,
                                  logits_prob[decoded_fname == fn])
            net.pbars.update(1, progress=val[name]['total_count'])
            net.pbars.update_bar()

        net.pbars.update(3, progress=f_i + 1)
        net.pbars.update_bar()
        qs_list = np.empty((0, 1), dtype=np.float)
        qs_string = None
        for i in np.sort(val[name]['index_list']):
            predict_val, logits_prob = val[name][i]
            predict_read, unique = sparse2dense(predict_val)
            predict_read = predict_read[0]
            unique = unique[0]

            if FLAGS.extension == 'fastq':
                logits_prob = logits_prob[unique]
            if FLAGS.extension == 'fastq':
                qs_list = np.concatenate((qs_list, logits_prob))
            reads += predict_read
        val.pop(name)  # Release the memory
        basecall_time = time.time() - start_time
        bpreads = [index2base(read) for read in reads]
        js_ratio = FLAGS.jump / FLAGS.segment_len
        kernal = get_assembler_kernal(FLAGS.jump, FLAGS.segment_len)
        if FLAGS.extension == 'fastq':
            consensus, qs_consensus = simple_assembly_qs(bpreads,
                                                         qs_list,
                                                         js_ratio,
                                                         kernal=kernal)
            qs_string = qs(consensus, qs_consensus)
        else:
            consensus = simple_assembly(bpreads, js_ratio, kernal=kernal)
        c_bpread = index2base(np.argmax(consensus, axis=0))
        assembly_time = time.time() - start_time
        list_of_time = [start_time, reading_time, basecall_time, assembly_time]
        write_output(bpreads,
                     c_bpread,
                     list_of_time,
                     file_pre,
                     concise=FLAGS.concise,
                     suffix=FLAGS.extension,
                     q_score=qs_string,
                     global_setting=FLAGS)
    net.pbars.end()
Example #9
def train(hparams):
    """Main training function.
    This will train a Neural Network with the given dataset.

    Args:
        hparams: hyperparameters for training the neural network
            data_dir: String, the path of the data (binary batch files) directory.
            log_dir: String, the path to save the trained model.
            sequence_len: Int, length of the input signal.
            batch_size: Int.
            step_rate: Float, step rate of the optimizer.
            max_steps: Int, max training steps.
            kmer: Int, size of the DNA k-mer.
            model_name: String, the model will be saved at log_dir/model_name.
            retrain: Boolean, if True, the model will be reloaded from log_dir/model_name.

    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        training = tf.placeholder(tf.bool)
        global_step = tf.get_variable('global_step',
                                      trainable=False,
                                      shape=(),
                                      dtype=tf.int32,
                                      initializer=tf.zeros_initializer())

        opt = model.train_opt(hparams.step_rate,
                              hparams.max_steps,
                              global_step=global_step)
        x, seq_length, train_labels = inputs(hparams.data_dir,
                                             int(hparams.batch_size *
                                                 hparams.ngpus),
                                             for_valid=False)
        split_y = tf.split(train_labels, hparams.ngpus, axis=0)
        split_seq_length = tf.split(seq_length, hparams.ngpus, axis=0)
        split_x = tf.split(x, hparams.ngpus, axis=0)
        tower_grads = []
        default_config = os.path.join(hparams.log_dir, hparams.model_name,
                                      'model.json')
        if hparams.retrain:
            if os.path.isfile(default_config):
                config_file = default_config
            else:
                raise ValueError(
                    "Model Json file has not been found in model log directory"
                )
        else:
            config_file = hparams.configure
        config = model.read_config(config_file)
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(hparams.ngpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('gpu_tower', i)) as scope:
                        loss, error = tower_loss(
                            scope,
                            split_x[i],
                            split_seq_length[i],
                            split_y[i],
                            full_seq_len=hparams.sequence_len,
                            config=config)
                        tf.get_variable_scope().reuse_variables()
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)
                        grads = opt.compute_gradients(loss)
                        tower_grads.append(grads)
        grads = average_gradients(tower_grads)
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        var_averages = tf.train.ExponentialMovingAverage(
            decay=model.MOVING_AVERAGE_DECAY)
        var_averages_op = var_averages.apply(tf.trainable_variables())
        train_op = tf.group(apply_gradient_op, var_averages_op)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        summary = tf.summary.merge_all()

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        model.save_model(default_config, config)
        if not hparams.retrain:
            sess.run(init)
            print("Model init finished, begin training. \n")
        else:
            saver.restore(
                sess,
                tf.train.latest_checkpoint(hparams.log_dir +
                                           hparams.model_name))
            print("Model loaded finished, begin training. \n")
        summary_writer = tf.summary.FileWriter(
            hparams.log_dir + hparams.model_name + '/summary/', sess.graph)
        _ = tf.train.start_queue_runners(sess=sess)

        start = time.time()
        for i in range(hparams.max_steps):
            feed_dict = {training: True}
            loss_val, _ = sess.run([loss, train_op], feed_dict=feed_dict)
            if i % 10 == 0:
                global_step_val = tf.train.global_step(sess, global_step)
                feed_dict = {training: True}
                error_val = sess.run(error, feed_dict=feed_dict)
                end = time.time()
                print(
                    "Step %d/%d ,  loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f" \
                    % (i, hparams.max_steps, loss_val, error_val,
                    (end - start) / (i + 1)))
                saver.save(sess,
                           hparams.log_dir + hparams.model_name +
                           '/model.ckpt',
                           global_step=global_step_val)
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str,
                                           global_step=global_step_val)
                summary_writer.flush()
        global_step_val = tf.train.global_step(sess, global_step)
        print("Model %s saved." % (hparams.log_dir + hparams.model_name))
        saver.save(sess,
                   hparams.log_dir + hparams.model_name + '/final.ckpt',
                   global_step=global_step_val)
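
The docstring above lists the hyperparameters that train(hparams) reads as attributes. A minimal driver sketch, assuming a plain namespace object is enough (the real entry point may build hparams with argparse or tf.contrib.training.HParams; the values below are placeholders, not recommended settings):

from types import SimpleNamespace

# Hypothetical values for illustration only.
hparams = SimpleNamespace(
    data_dir='/data/batches/',        # directory of binary batch files
    log_dir='/logs/',                 # model is saved under log_dir + model_name
    sequence_len=512,
    batch_size=64,
    step_rate=1e-3,
    max_steps=10000,
    kmer=1,
    model_name='chiron_multi_gpu',
    retrain=False,
    ngpus=2,                          # number of GPU towers built by train()
    configure='/path/to/model.json',  # model configuration, read when retrain is False
)
train(hparams)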
Example #10
def evaluation():
    pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model,'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(
                                    x, 
                                    seq_length, 
                                    training=training,
                                    full_sequence_len = FLAGS.segment_len,
                                    configure = model_configure)
    config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape, logits_fname.shape, logits_index.shape, seq_length.shape]
    )
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2,total = file_n)
        pbars.update(3,total = file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))
        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0,total = reads_n,progress = 0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant')
                    seq_len = np.pad(
                        seq_len, ((0, FLAGS.batch_size - len(seq_len))), mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: seq_len,
                        training: False,
                        logits_index:i,
                        logits_fname: name,
                    }
Example #11
def train(hparam):
    """Main training function.
    This will train a Neural Network with the given dataset.

    Args:
        hparam: hyperparameters for training the neural network
            data_dir: String, the path of the data (binary batch files) directory.
            log_dir: String, the path to save the trained model.
            sequence_len: Int, length of the input signal.
            batch_size: Int.
            step_rate: Float, step rate of the optimizer.
            max_steps: Int, max training steps.
            kmer: Int, size of the DNA k-mer.
            model_name: String, the model will be saved at log_dir/model_name.
            retrain: Boolean, if True, the model will be reloaded from log_dir/model_name.

    """
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())

    x, seq_length, train_labels = inputs(hparam.data_dir,
                                         hparam.batch_size,
                                         hparam.sequence_len,
                                         for_valid=False)
    y = dense2sparse(train_labels)
    default_config = os.path.join(hparam.log_dir, hparam.model_name,
                                  'model.json')
    if hparam.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = hparam.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    hparam.sequence_len,
                                    configure=config,
                                    apply_ratio=True)
    seq_length = tf.cast(tf.ceil(tf.cast(seq_length, tf.float32) / ratio),
                         tf.int32)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(hparam.step_rate,
                          hparam.max_steps,
                          global_step=global_step)
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    model.save_model(default_config, config)
    if not hparam.retrain:
        sess.run(init)
        print("Model init finished, begin training. \n")
    else:
        saver.restore(
            sess,
            tf.train.latest_checkpoint(hparam.log_dir + hparam.model_name))
        print("Model loaded finished, begin training. \n")
    summary_writer = tf.summary.FileWriter(
        hparam.log_dir + hparam.model_name + '/summary/', sess.graph)
    _ = tf.train.start_queue_runners(sess=sess)

    start = time.time()
    for i in range(hparam.max_steps):
        feed_dict = {training: True}
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            feed_dict = {training: True}
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d ,  loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f" \
                % (i, hparam.max_steps, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       hparam.log_dir + hparam.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (hparam.log_dir + hparam.model_name))
    saver.save(sess,
               hparam.log_dir + hparam.model_name + '/final.ckpt',
               global_step=global_step_val)
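
The parameters in the docstring above map naturally onto command-line flags. One way the hparam object could be produced (a sketch with assumed flag names and placeholder defaults, not the project's actual CLI):

import argparse

# Hypothetical CLI wrapper; argparse turns the hyphenated flags into the
# underscore attribute names that train(hparam) reads.
parser = argparse.ArgumentParser(description='Train a basecalling model.')
parser.add_argument('--data-dir', required=True)
parser.add_argument('--log-dir', required=True)
parser.add_argument('--sequence-len', type=int, default=512)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--step-rate', type=float, default=1e-3)
parser.add_argument('--max-steps', type=int, default=10000)
parser.add_argument('--kmer', type=int, default=1)
parser.add_argument('--model-name', default='chiron_model')
parser.add_argument('--configure', default=None)
parser.add_argument('--retrain', action='store_true')
hparam = parser.parse_args()
train(hparam)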
Example #12
def evaluation():
    pbars = multi_pbars(["Logits(batches)","ctc(batches)","logits(files)","ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model,'model.json')
    model_configure = chiron_model.read_config(config_path)

    logits, ratio = chiron_model.inference(
                                    x, 
                                    seq_length, 
                                    training=training,
                                    full_sequence_len = FLAGS.segment_len,
                                    configure = model_configure)
    config = tf.ConfigProto(allow_soft_placement=True, intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True
    logits_index = tf.placeholder(tf.int32, shape=())
    logits_fname = tf.placeholder(tf.string, shape=())
    logits_queue = tf.FIFOQueue(
        capacity=1000,
        dtypes=[tf.float32, tf.string, tf.int32, tf.int32],
        shapes=[logits.shape,logits_fname.shape,logits_index.shape, seq_length.shape]
    )
    logits_queue_size = logits_queue.size()
    logits_enqueue = logits_queue.enqueue((logits, logits_fname, logits_index, seq_length))
    logits_queue_close = logits_queue.close()
    ### Decoding logits into bases
    decode_predict_op, decode_prob_op, decoded_fname_op, decode_idx_op, decode_queue_size = decoding_queue(logits_queue)
    saver = tf.train.Saver()
    with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(config=config)) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model))
        if os.path.isdir(FLAGS.input):
            file_list = os.listdir(FLAGS.input)
            file_dir = FLAGS.input
        else:
            file_list = [os.path.basename(FLAGS.input)]
            file_dir = os.path.abspath(
                os.path.join(FLAGS.input, os.path.pardir))
        file_n = len(file_list)
        pbars.update(2, total=file_n)
        pbars.update(3, total=file_n)
        if not os.path.exists(FLAGS.output):
            os.makedirs(FLAGS.output)
        if not os.path.exists(os.path.join(FLAGS.output, 'segments')):
            os.makedirs(os.path.join(FLAGS.output, 'segments'))
        if not os.path.exists(os.path.join(FLAGS.output, 'result')):
            os.makedirs(os.path.join(FLAGS.output, 'result'))
        if not os.path.exists(os.path.join(FLAGS.output, 'meta')):
            os.makedirs(os.path.join(FLAGS.output, 'meta'))
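
        # Producer: read each .signal file, run inference batch by batch and
        # enqueue the resulting logits for the decoding side.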
        def worker_fn():
            for f_i, name in enumerate(file_list):
                if not name.endswith('.signal'):
                    continue
                input_path = os.path.join(file_dir, name)
                eval_data = read_data_for_eval(input_path, FLAGS.start,
                                               seg_length=FLAGS.segment_len,
                                               step=FLAGS.jump)
                reads_n = eval_data.reads_n
                pbars.update(0, total=reads_n, progress=0)
                pbars.update_bar()
                for i in range(0, reads_n, FLAGS.batch_size):
                    batch_x, seq_len, _ = eval_data.next_batch(
                        FLAGS.batch_size, shuffle=False, sig_norm=False)
                    batch_x = np.pad(
                        batch_x, ((0, FLAGS.batch_size - len(batch_x)), (0, 0)), mode='constant')
                    seq_len = np.pad(
                        seq_len, ((0, FLAGS.batch_size - len(seq_len))), mode='constant')
                    feed_dict = {
                        x: batch_x,
                        seq_length: np.round(seq_len / ratio).astype(np.int32),
                        training: False,
                        logits_index: i,
                        logits_fname: name,
                    }
                    sess.run(logits_enqueue, feed_dict=feed_dict)
                    pbars.update(0, progress=i + FLAGS.batch_size)
                    pbars.update_bar()
                pbars.update(2, progress=f_i + 1)
                pbars.update_bar()
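            # Closing the queue signals the decoding side that no more logits
            # will be enqueued.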
            sess.run(logits_queue_close)

        worker = threading.Thread(target=worker_fn)
        worker.daemon = True
        worker.start()

        val = defaultdict(dict)  # Decoded batches may arrive out of order, hence a dict keyed by file name.
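        # Consumer: for each file, drain decoded batches from the queue, then
        # stitch them into reads and assemble a consensus sequence.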
        for f_i, name in enumerate(file_list):
            start_time = time.time()
            if not name.endswith('.signal'):
                continue
            file_pre = os.path.splitext(name)[0]
            input_path = os.path.join(file_dir, name)
            # Both 'dna' and 'rna' modes currently read the evaluation data
            # with the same parameters.
            eval_data = read_data_for_eval(input_path, FLAGS.start,
                                           seg_length=FLAGS.segment_len,
                                           step=FLAGS.jump)
            reads_n = eval_data.reads_n
            pbars.update(1, total=reads_n, progress=0)
            pbars.update_bar()
            reading_time = time.time() - start_time
            reads = list()

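            # Total number of batches expected for this file.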
            N = len(range(0, reads_n, FLAGS.batch_size))
            while True:
                l_sz, d_sz = sess.run([logits_queue_size, decode_queue_size])
                decode_ops = [decoded_fname_op, decode_idx_op,
                              decode_predict_op, decode_prob_op]
                decoded_fname, i, predict_val, logits_prob = sess.run(
                    decode_ops, feed_dict={training: False})
                decoded_fname = decoded_fname.decode("UTF-8")
                val[decoded_fname][i] = (predict_val, logits_prob)
                pbars.update(1, progress=len(val[name]) * FLAGS.batch_size)
                pbars.update_bar()
                if len(val[name]) == N:
                    break

            pbars.update(3, progress=f_i + 1)
            pbars.update_bar()
            qs_list = np.empty((0, 1), dtype=np.float64)
            qs_string = None
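            # Stitch the per-batch sparse predictions back into per-read base
            # index sequences (per-base probabilities are kept for fastq output).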
            for i in range(0, reads_n, FLAGS.batch_size):
                predict_val, logits_prob = val[name][i]
                predict_read, unique = sparse2dense(predict_val)
                predict_read = predict_read[0]
                unique = unique[0]

                if FLAGS.extension == 'fastq':
                    logits_prob = logits_prob[unique]
                if i + FLAGS.batch_size > reads_n:
                    predict_read = predict_read[:reads_n - i]
                    if FLAGS.extension == 'fastq':
                        logits_prob = logits_prob[:reads_n - i]
                if FLAGS.extension == 'fastq':
                    qs_list = np.concatenate((qs_list, logits_prob))
                reads += predict_read
            val.pop(name)  # Release the memory

            basecall_time = time.time() - start_time
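            # Convert index sequences to base strings and assemble the
            # overlapping segment reads into a single consensus call.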
            bpreads = [index2base(read) for read in reads]
            js_ratio = FLAGS.jump / FLAGS.segment_len
            if FLAGS.extension == 'fastq':
                consensus, qs_consensus = simple_assembly_qs(
                    bpreads, qs_list, js_ratio)
                qs_string = qs(consensus, qs_consensus)
            else:
                consensus = simple_assembly(bpreads, js_ratio)
            c_bpread = index2base(np.argmax(consensus, axis=0))
            assembly_time = time.time() - start_time
            list_of_time = [start_time, reading_time,
                            basecall_time, assembly_time]
            write_output(bpreads, c_bpread, list_of_time, file_pre,
                         concise=FLAGS.concise, suffix=FLAGS.extension,
                         q_score=qs_string, global_setting=FLAGS)
    pbars.end()
Example #13
0
        action='store_true',
        help="Concisely output the result; the meta and segments "
        "files will not be output."
    )
    parser.add_argument('--mode',
                        default='dna',
                        help="Output mode, can be chosen from dna or rna.")
    FLAGS = parser.parse_args(sys.argv[1:])

    pbars = multi_pbars(
        ["Logits(batches)", "ctc(batches)", "logits(files)", "ctc(files)"])
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.segment_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    training = tf.placeholder(tf.bool)
    config_path = os.path.join(FLAGS.model, 'model.json')
    model_configure = chiron_model.read_config(config_path)
    logits, ratio = chiron_model.inference(x,
                                           seq_length,
                                           training=training,
                                           full_sequence_len=FLAGS.segment_len,
                                           configure=model_configure)
    probs = tf.nn.softmax(logits)
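    # CTC beam-search decoding over time-major logits (hence the transpose);
    # returns the decoded sparse paths and their log probabilities.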
    predict = tf.nn.ctc_beam_search_decoder(tf.transpose(logits,
                                                         perm=[1, 0, 2]),
                                            seq_length,
                                            merge_repeated=False,
                                            beam_width=FLAGS.beam)
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=FLAGS.threads,
                            inter_op_parallelism_threads=FLAGS.threads)
    config.gpu_options.allow_growth = True