def test_cudapoa_complex_batch():
    # Simulate a 500-base reference and 100 reads carrying ~2% random
    # substitutions, then check that the POA consensus recovers the reference.
    random.seed(2)
    read_len = 500
    ref = ''.join(
        [random.choice(['A', 'C', 'G', 'T']) for _ in range(read_len)])
    num_reads = 100
    mutation_rate = 0.02
    reads = []
    for _ in range(num_reads):
        new_read = ''.join([
            r if random.random() > mutation_rate else random.choice(
                ['A', 'C', 'G', 'T']) for r in ref
        ])
        reads.append(new_read)

    device = cuda.cuda_get_device()
    free, total = cuda.cuda_get_mem_info(device)
    stream = cuda.CudaStream()
    batch = CudaPoaBatch(1000,
                         1024,
                         0.9 * free,
                         stream=stream,
                         device_id=device)
    (add_status, seq_status) = batch.add_poa_group(reads)
    batch.generate_poa()

    consensus, coverage, status = batch.get_consensus()

    consensus = consensus[0]
    assert (len(consensus) == len(ref))

    match_ratio = SequenceMatcher(None, ref, consensus).ratio()
    assert (match_ratio == 1.0)
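These snippets are taken out of their source files, so their imports are omitted. A minimal preamble for the CudaPoaBatch examples on this page might look like the sketch below; the module paths are an assumption based on the pygenomeworks package layout and may differ between releases (older releases shipped the same bindings under the claragenomics name).

# Assumed imports for the CudaPoaBatch examples on this page. Module paths
# are a guess at the pygenomeworks layout and may differ per release.
import random
from difflib import SequenceMatcher

from genomeworks import cuda
from genomeworks.cudapoa import CudaPoaBatch, status_to_str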
Example 2
def poagen(groups, gpu_percent=0.8):
    """Yield a consensus sequence for each group of reads, packing as many
    POA groups as fit into each CudaPoaBatch run on the GPU."""
    free, total = cuda.cuda_get_mem_info(cuda.cuda_get_device())
    gpu_mem_per_batch = gpu_percent * free

    max_seq_sz = 0
    max_sequences_per_poa = 0

    for group in groups:
        longest_seq = len(max(group, key=len))
        max_seq_sz = max(max_seq_sz, longest_seq)
        max_sequences_per_poa = max(max_sequences_per_poa, len(group))

    batch = CudaPoaBatch(
        max_sequences_per_poa,
        max_seq_sz,
        gpu_mem_per_batch,
        output_type="consensus",
        cuda_banded_alignment=True,
        alignment_band_width=256,
    )

    poa_index = 0
    initial_count = 0

    while poa_index < len(groups):

        group = groups[poa_index]
        group_status, seq_status = batch.add_poa_group(group)

        # If group was added and more space is left in batch, continue onto next group.
        if group_status == 0:
            for seq_index, status in enumerate(seq_status):
                if status != 0:
                    print("Could not add sequence {} to POA {} - error {}".
                          format(seq_index, poa_index, status_to_str(status)),
                          file=sys.stderr)
            poa_index += 1

        # Once batch is full or no groups are left, run POA processing.
        if ((group_status == 1)
                or ((group_status == 0) and (poa_index == len(groups)))):
            batch.generate_poa()
            consensus, coverage, con_status = batch.get_consensus()
            for p, status in enumerate(con_status):
                if status != 0:
                    print(
                        "Could not get consensus for POA group {} - {}".format(
                            initial_count + p, status_to_str(status)),
                        file=sys.stderr)
            yield from consensus
            initial_count = poa_index
            batch.reset()

        # In the case where POA group wasn't processed correctly.
        elif group_status != 0:
            print("Could not add POA group {} to batch - {}".format(
                poa_index, status_to_str(group_status)),
                  file=sys.stderr)
            poa_index += 1
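A hypothetical driver for poagen() is sketched below; the read groups are made up for illustration, and in practice each group would hold the reads covering one window of the assembly.

# Hypothetical usage of poagen(): print one FASTA-style consensus record per
# read group. The groups below are toy data for illustration only.
read_groups = [
    ["ACTGACTG", "ACTTACTG", "ACGGACTG", "ATCGACTG"],
    ["ACTGAC", "ACTTAC", "ACGGAC", "ATCGAC"],
]
for i, consensus in enumerate(poagen(read_groups)):
    print(">consensus_{}".format(i))
    print(consensus)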
def test_cudapoa_valid_output_type():
    device = cuda.cuda_get_device()
    free, total = cuda.cuda_get_mem_info(device)
    try:
        CudaPoaBatch(10,
                     1024,
                     0.9 * free,
                     device_id=device,
                     output_type='consensus')
    except RuntimeError:
        assert (False)
def test_cudapoa_incorrect_output_type():
    device = cuda.cuda_get_device()
    free, total = cuda.cuda_get_mem_info(device)
    try:
        CudaPoaBatch(10,
                     1024,
                     0.9 * free,
                     device_id=device,
                     output_type='error_input')
        assert (False)
    except RuntimeError:
        pass
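Besides 'consensus', the bindings also accept output_type='msa' for multiple sequence alignment output. The sketch below assumes the batch exposes a get_msa() accessor, as in the pygenomeworks API; treat it as an illustration rather than a definitive call sequence.

def example_cudapoa_msa_output():
    # Sketch only: request MSA output instead of consensus. get_msa() is
    # assumed to return (msa, status) per POA group, mirroring get_consensus().
    device = cuda.cuda_get_device()
    free, total = cuda.cuda_get_mem_info(device)
    batch = CudaPoaBatch(10, 1024, 0.9 * free,
                         device_id=device,
                         output_type='msa')
    batch.add_poa_group(["ACTGACTG", "ACTTACTG", "ACGGACTG"])
    batch.generate_poa()
    msa, status = batch.get_msa()
    return msa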
def test_cudapoa_reset_batch():
    device = cuda.cuda_get_device()
    free, total = cuda.cuda_get_mem_info(device)
    batch = CudaPoaBatch(10, 1024, 0.9 * free, device_id=device)
    poa_1 = ["ACTGACTG", "ACTTACTG", "ACGGACTG", "ATCGACTG"]
    batch.add_poa_group(poa_1)
    batch.generate_poa()
    consensus, coverage, status = batch.get_consensus()

    assert (batch.total_poas == 1)

    batch.reset()

    assert (batch.total_poas == 0)
def test_cudapoa_simple_batch():
    device = cuda.cuda_get_device()
    free, total = cuda.cuda_get_mem_info(device)
    batch = CudaPoaBatch(10,
                         1024,
                         0.9 * free,
                         device_id=device,
                         output_type='consensus')
    poa_1 = ["ACTGACTG", "ACTTACTG", "ACGGACTG", "ATCGACTG"]
    poa_2 = ["ACTGAC", "ACTTAC", "ACGGAC", "ATCGAC"]
    batch.add_poa_group(poa_1)
    batch.add_poa_group(poa_2)
    batch.generate_poa()
    consensus, coverage, status = batch.get_consensus()

    assert (len(consensus) == 2)
    assert (batch.total_poas == 2)
def test_cudapoa_graph():
    device = cuda.cuda_get_device()
    free, total = cuda.cuda_get_mem_info(device)
    batch = CudaPoaBatch(10, 1024, 0.9 * free, device_id=device)
    poa_1 = ["ACTGACTG", "ACTTACTG", "ACTCACTG"]
    batch.add_poa_group(poa_1)
    batch.generate_poa()
    consensus, coverage, status = batch.get_consensus()

    assert (batch.total_poas == 1)

    # Expected graph
    #           - -> G -> -
    #           |         |
    # A -> C -> T -> T -> A -> C -> T -> G
    #           |         |
    #           - -> C -> -

    graphs, status = batch.get_graphs()
    assert (len(graphs) == 1)

    digraph = graphs[0]
    assert (digraph.number_of_nodes() == 10)
    assert (digraph.number_of_edges() == 11)
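The test above treats the object returned by get_graphs() as a networkx DiGraph. Under that assumption, a small helper like the one below could be used to dump the partial order graph for inspection; it is a sketch, not part of the library API.

def dump_poa_graph(digraph):
    # Print nodes (with any attached attributes) and directed edges of a POA
    # graph, assuming the networkx DiGraph interface used in the test above.
    for node, data in digraph.nodes(data=True):
        print(node, data)
    for u, v in digraph.edges():
        print("{} -> {}".format(u, v))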
Example 8
def main(args):

    sys.stderr.write("> loading model\n")
    model = load_model(args.model, args.device)

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='map-ont')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)
    else:
        aligner = None

    if args.summary:
        sys.stderr.write("> finding follow on strands\n")
        pairs = pd.read_csv(args.summary, sep='\t', low_memory=False)
        pairs = pairs[pairs.sequence_length_template.gt(0)]
        if 'filename' in pairs.columns:
            pairs = pairs.rename(columns={'filename': 'filename_fast5'})
        if 'alignment_strand_coverage' in pairs.columns:
            pairs = pairs.rename(
                columns={'alignment_strand_coverage': 'alignment_coverage'})
        valid_fast5s = [
            f for f in pairs.filename_fast5.unique()
            if ((args.reads_directory / Path(f)).exists())
        ]
        pairs = pairs[pairs.filename_fast5.isin(valid_fast5s)]
        pairs = find_follow_on(pairs)
        sys.stderr.write("> found %s follow strands in summary\n" %
                         (len(pairs) // 2))

        if args.max_reads > 0: pairs = pairs.head(args.max_reads)

        temp_reads = pairs.iloc[0::2]
        comp_reads = pairs.iloc[1::2]
    else:
        if args.index is not None:
            sys.stderr.write("> loading read index\n")
            index = json.load(open(args.index, 'r'))
        else:
            sys.stderr.write("> building read index\n")
            files = list(glob(os.path.join(args.reads_directory, '*.fast5')))
            index = build_index(files, n_proc=8)
            if args.save_index:
                with open('bonito-read-id.idx', 'w') as f:
                    json.dump(index, f)

        pairs = pd.read_csv(args.pairs,
                            sep=args.sep,
                            names=['read_1', 'read_2'])
        if args.max_reads > 0: pairs = pairs.head(args.max_reads)

        pairs['file_1'] = pairs['read_1'].apply(index.get)
        pairs['file_2'] = pairs['read_2'].apply(index.get)
        pairs = pairs.dropna().reset_index()

        temp_reads = pairs[['read_1',
                            'file_1']].rename(columns={
                                'read_1': 'read_id',
                                'file_1': 'filename_fast5'
                            })
        comp_reads = pairs[['read_2',
                            'file_2']].rename(columns={
                                'read_2': 'read_id',
                                'file_2': 'filename_fast5'
                            })

    if len(pairs) == 0:
        print("> no matched pairs found in given directory", file=sys.stderr)
        exit(1)

    # https://github.com/clara-parabricks/GenomeWorks/issues/648
    with devnull():
        CudaPoaBatch(1000, 1000, 3724032)

    basecalls = call(model,
                     args.reads_directory,
                     temp_reads,
                     comp_reads,
                     aligner=aligner)
    writer = Writer(tqdm(basecalls,
                         desc="> calling",
                         unit=" reads",
                         leave=False),
                    aligner,
                    duplex=True)

    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    print("> duration: %s" % timedelta(seconds=np.round(duration)),
          file=sys.stderr)
    print("> samples per second %.1E" % (num_samples / duration),
          file=sys.stderr)
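main() reads a number of attributes from args. A hypothetical argparse parser covering those attributes is sketched below; the flag names, types, and defaults are inferred from the code above rather than taken from the tool's real CLI.

import argparse
from pathlib import Path


def argparser():
    # Hypothetical parser matching the attributes main() accesses on args.
    parser = argparse.ArgumentParser()
    parser.add_argument("model")
    parser.add_argument("reads_directory", type=Path)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--reference", default=None)
    parser.add_argument("--summary", default=None)
    parser.add_argument("--index", default=None)
    parser.add_argument("--save-index", action="store_true")
    parser.add_argument("--pairs", default=None)
    parser.add_argument("--sep", default=" ")
    parser.add_argument("--max-reads", type=int, default=0)
    return parser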